下载
中文
注册

torch_npu.npu_mm_all_reduce_base

功能描述

TP切分场景下,实现mm和all_reduce的融合,融合算子内部实现计算和通信流水并行。

使用该接口时,请确保驱动固件包和CANN包都为配套的8.0.RC2版本或者配套的更高版本,否则将会引发报错,比如BUS ERROR等。

接口原型

1
npu_mm_all_reduce_base(Tensor x1, Tensor x2, str hcom, *, str reduce_op='sum', Tensor? bias=None, Tensor? antiquant_scale=None, Tensor? antiquant_offset=None, Tensor? x3=None, Tensor? dequant_scale=None Tensor? pertoken_scale=None, Tensor? comm_quant_scale_1=None, Tensor? comm_quant_scale_2=None, int comm_turn=0, int antiquant_group_size=0) -> Tensor

参数说明

  • x1:Device侧的Tensor类型,数据类型支持INT8、FLOAT16、BFLOAT16。数据格式支持ND,输入shape支持2维或者3维。
  • x2:Device侧的Tensor类型,数据类型支持FLOAT16、INT8、BFLOAT16,数据格式支持NZ(昇腾亲和排布格式)、ND。非量化场景,数据类型需要和x1保持一致,输入shape维度第0维和x1的最后一维保持一致。
  • hcom:Host侧的String类型,通信域handle名,通过get_hccl_comm_name接口获取。
  • *:代表其之前的变量是位置相关,按照顺序输入,必选;之后的变量是键值对赋值的,位置无关,可选(不输入会使用默认值)。
  • reduce_op:Host侧的String类型,reduce操作类型,当前版本仅支持'sum',默认值:'sum'。
  • bias:Device侧的Tensor类型,可选输入,数据类型支持INT32、FLOAT16、BFLOAT16,数据格式支持ND。bias当前仅支持一维,且维度大小与output/x2的最后一维大小相同。
  • antiquant_scale:Device侧的Tensor类型,可选输入,伪量化场景对x2进行去量化的系数,数据类型支持FLOAT16、BFLOAT16,数据格式支持ND。伪量化场景数据类型需要和x1保持一致。
    • per-tensor场景:shape为[1]。
    • per-channel场景:shape为[1,n]或者[n],n为x2最后一维的大小。
    • per-group场景:shape为[ceil(k, antiquant_group_size), n]。其中k为x2第一维的大小,n为x2最后一维的大小,antiquant_group_size为伪量化场景对输入x2进行反量化计算的groupSize输入

      ceil(k, antiquant_group_size)的计算逻辑为:(k + antiquant_group_size - 1) / antiquant_group_size,并对计算结果取整数部分。

  • antiquant_offset:Device侧的Tensor类型,可选输入,伪量化场景对x2进行去量化的系数,数据类型支持FLOAT16、BFLOAT16,数据格式支持ND。数据类型、shape需要和antiquant_scale保持一致。
  • x3:Device侧的Tensor类型,可选输入,matmul计算后的偏移。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品:数据类型支持FLOAT16、BFLOAT16,数据格式支持ND。数据类型、shape需要和输出output保持一致。
  • dequant_scale:Device侧的Tensor类型,可选输入,matmul计算后的去量化系数。数据类型支持INT64、UINT64、BFLOAT16、FLOAT32;数据格式支持ND。
    • per-tensor场景:shape为[1]。
    • per-channel场景:shape为[n]/[1,n],n为x2最后一维的大小。
  • pertoken_scale:Device侧的Tensor类型,可选输入,matmul计算后的per-token去量化系数。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品:数据类型支持FLOAT32。当x1为[m,k]时pertoken_scale shape为[m];当x1为[b, s, k]时pertoken_scale shape为[b*s]。
  • comm_quant_scale_1: Device侧的Tensor类型,可选输入,alltoall通信前后的量化、去量化系数。支持FLOAT16、BFLOAT16,支持ND格式。x2为[k, n]时shape为[1, n]或[n],用户需保证每张卡上数据保持一致且正确。
  • comm_quant_scale_2: Device侧的Tensor类型,可选输入,allgather通信前后的量化、去量化系数。支持FLOAT16、BFLOAT16,支持ND格式。x2为[k, n]时shape为[1, n]或[n],用户需保证每张卡上数据保持一致且正确。
  • comm_turn:Host侧的int类型,表示rank间通信切分粒度,默认值:0,表示默认的切分方式。当前版本仅支持输入0。
  • antiquant_group_size:Host侧的int类型,表示伪量化pre-group算法模式下,对输入x2进行反量化计算的groupSize输入,描述一组反量化参数对应的待反量化数据量在k轴方向的大小。当伪量化算法模式不为pre_group时传入0;当伪量化算法模式为pre_group时传入值的范围为[32, min(k-1, INT_MAX)]且值要求是32的倍数,其中k为x2第一维的大小。默认值0,为0则表示非per-group场景。

输出说明

Tensor类型,数据类型非量化场景以及伪量化场景与x1保持一致,全量化场景输出数据类型为FLOAT16或BFLOAT16。shape第0维度和x1的0维保持一致,若x1为2维,shape第1维度和x2的1维保持一致,若x1为3维,shape第1维度和x1的1维保持一致,shape第2维度和x2的1维保持一致。

约束说明

  • 该接口仅在推理场景下使用。
  • 增量场景不使能该融合算子,全量场景使能该融合算子。
  • 该接口支持图模式(目前仅支持PyTorch 2.1版本)。
  • 输入x1可为2维或者3维、x2必须是2维,分别为(b, s, k)/(m, k), (k, n),k轴满足mm算子入参要求,k轴相等。bias当前仅支持一维,且维度大小与output的最后一维大小相同。x3的shape与output的shape相同。
  • x1不支持输入转置后的tensor,x2转置后输入,需要满足shape的第一维大小与x1的最后一维相同,满足matmul的计算条件。
  • antiquant_group_size中k值的范围与matmul一致,为[1,65535],INT_MAX大于(k-1)。
  • Atlas A2 训练系列产品/Atlas 800I A2 推理产品
    • 数据类型支持BFLOAT16。
    • x1、x2不支持为空tensor。
    • 支持1、2、4、8卡,并且仅支持hccs链路all mesh组网。
    • 非量化场景下,m、k、n的取值范围均为[1, 2147483647]。
    • comm_quant_scale_1,comm_quant_scale_2的shape应保持一致,dtype与输出的dtype保持一致,且只在全量化场景支持。
  • 全量化场景:m取值范围均为[1, 2147483647],x1、x2的最后一维范围为[1, 65535],即k的取值范围为[1, 65535]、仅当x2(shape=[n,k])为转置时n可以大于65535。
  • 伪量化场景:m取值范围均为[1, 2147483647],k、n的取值范围为[1, 65535]。
  • 一个模型中的通算融合MC2算子,仅支持相同通信域。
  • 在长序列场景,随着b/s或者m的增大,可能出现内存不足或者计算超时。
  • 不同场景下数据类型支持情况:
    表1 非量化场景

    产品型号

    x1

    x2

    bias

    x3

    output(输出)

    antiquant_scale

    antiquant_offset

    dequant_scale

    Atlas A2 训练系列产品/Atlas 800I A2 推理产品

    FLOAT16

    FLOAT16

    FLOAT16

    FLOAT16

    FLOAT16

    None

    None

    None

    Atlas A2 训练系列产品/Atlas 800I A2 推理产品

    BFLOAT16

    BFLOAT16

    BFLOAT16

    BFLOAT16

    BFLOAT16

    None

    None

    None

    表2 伪量化场景

    产品型号

    x1

    x2

    bias

    x3

    output(输出)

    antiquant_scale

    antiquant_offset

    dequant_scale

    Atlas A2 训练系列产品/Atlas 800I A2 推理产品

    FLOAT16

    INT8

    FLOAT16

    FLOAT16

    FLOAT16

    FLOAT16

    FLOAT16

    None

    Atlas A2 训练系列产品/Atlas 800I A2 推理产品

    BFLOAT16

    INT8

    BFLOAT16

    BFLOAT16

    BFLOAT16

    BFLOAT16

    BFLOAT16

    None

    表3 全量化场景

    产品型号

    x1

    x2

    bias

    x3

    output(输出)

    antiquant_scale

    antiquant_offset

    dequant_scale

    pertoken_scale

    Atlas A2 训练系列产品/Atlas 800I A2 推理产品

    INT8

    INT8

    INT32

    FLOAT16

    FLOAT16

    None

    None

    UINT64或INT64

    None

    Atlas A2 训练系列产品/Atlas 800I A2 推理产品

    INT8

    INT8

    INT32

    BFLOAT16

    BFLOAT16

    None

    None

    BFLOAT16

    None

    Atlas A2 训练系列产品/Atlas 800I A2 推理产品

    INT8

    INT8

    INT32

    FLOAT16

    FLOAT16

    None

    None

    FLOAT32

    FLOAT32

    Atlas A2 训练系列产品/Atlas 800I A2 推理产品

    INT8

    INT8

    INT32

    BFLOAT16

    BFLOAT16

    None

    None

    BFLOAT16

    FLOAT32

    全量化场景:若dequant_scale需要以FP32类型传入,在调用torch_npu.npu_mm_all_reduce_base前,需通过torch_npu.npu_trans_quant_param接口对dequant_scale进行处理为INT64类型(处理方法见对应的接口使用说明)。

支持的型号

  • Atlas A2 训练系列产品/Atlas 800I A2 推理产品

调用示例

  • 单算子模式调用
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    import torch
    import torch_npu
    import torch.distributed as dist
    import torch.multiprocessing as mp
    def run_mm_all_reduce_base(rank, world_size, master_ip, master_port, x1_shape, x2_shape, dtype):
        torch_npu.npu.set_device(rank)
        init_method = 'tcp://' + master_ip + ':' + master_port
        dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
        from torch.distributed.distributed_c10d import _get_default_group
        default_pg = _get_default_group()
        if torch.__version__ > '2.0.1':
            hcom_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
        else:
            hcom_info = default_pg.get_hccl_comm_name(rank)
    
        input_ = torch.randn(x1_shape, dtype=dtype).npu()
        weight = torch.randn(x2_shape, dtype=dtype).npu()
        output = torch_npu.npu_mm_all_reduce_base(input_, weight, hcom_info, reduce_op='sum')
        print("output: ", output)
    
    if __name__ == "__main__":
        worksize = 8
        master_ip = '127.0.0.1'
        master_port = '50001'
        x1_shape = [128, 512]
        x2_shape = [512, 64]
        dtype = torch.float16
    
        mp.spawn(run_mm_all_reduce_base, args=(worksize, master_ip, master_port, x1_shape, x2_shape, dtype), nprocs=worksize)
    
  • 图模式调用

    非量化、伪量化、全量化使能NZ调用示例如下:

      1
      2
      3
      4
      5
      6
      7
      8
      9
     10
     11
     12
     13
     14
     15
     16
     17
     18
     19
     20
     21
     22
     23
     24
     25
     26
     27
     28
     29
     30
     31
     32
     33
     34
     35
     36
     37
     38
     39
     40
     41
     42
     43
     44
     45
     46
     47
     48
     49
     50
     51
     52
     53
     54
     55
     56
     57
     58
     59
     60
     61
     62
     63
     64
     65
     66
     67
     68
     69
     70
     71
     72
     73
     74
     75
     76
     77
     78
     79
     80
     81
     82
     83
     84
     85
     86
     87
     88
     89
     90
     91
     92
     93
     94
     95
     96
     97
     98
     99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    import torch
    import torch_npu
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import numpy as np
    class MM_ALLREDUCE_GRAPH_Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
        def forward(self, x1, x2, hcom, reduce_op, bias, antiquant_scale, antiquant_offset, x3, dequant_scale):
            output_npu = torch_npu.npu_mm_all_reduce_base(x1=x1,
                                                          x2=x2,
                                                          hcom=hcom,
                                                          reduce_op=reduce_op,
                                                          bias=bias,
                                                          antiquant_scale=antiquant_scale,
                                                          antiquant_offset=antiquant_offset,
                                                          x3=x3,
                                                          dequant_scale=dequant_scale
                                                          )
            return output_npu
    
    class MM_ALLREDUCE_A8W8_GRAPH_Model(MM_ALLREDUCE_GRAPH_Model):
        def __init__(self):
            super().__init__()
        def forward(self, x1, x2, hcom, reduce_op, bias, antiquant_scale, antiquant_offset, x3, dequant_scale):
            output_npu = torch_npu.npu_mm_all_reduce_base(x1=x1,
                                                          x2=x2.t(),
                                                          hcom=hcom,
                                                          reduce_op=reduce_op,
                                                          bias=bias,
                                                          antiquant_scale=antiquant_scale,
                                                          antiquant_offset=antiquant_offset,
                                                          x3=x3,
                                                          dequant_scale=dequant_scale
                                                          )
            return output_npu
    
    def define_model(model, graph_type):
        import torchair
        if graph_type == 1:  # 传统入图模式,静态shape+在线编译场景
            npu_backend = torchair.get_npu_backend(compiler_config=None)
            model = torch.compile(model, backend=npu_backend, dynamic=False)
        elif graph_type == 2:  # ACLNN入图模式,动态shape+二进制
            npu_backend = torchair.get_npu_backend(compiler_config=None)
            model = torch.compile(model, backend=npu_backend, dynamic=True)
        else:
            print("Error type")
        return model
    
    def get_graph(input, weight, hcomm_info, dequant_scale, bias, antiquant_scale, antiquant_offset, x3):
        model = MM_ALLREDUCE_A8W8_GRAPH_Model()
        model = define_model(model, 2) # 1:静态入图;2:动态入图;
        output = model(x1=input, x2=weight, hcom=hcomm_info, reduce_op="sum", bias=bias, antiquant_scale=antiquant_scale,
                       antiquant_offset=antiquant_offset, x3=x3, dequant_scale=dequant_scale)
        return output
    
    def run_mc2_a16w16(x1_shape, x2_shape, hcom_info):
        np_input = np.random.uniform(float(-3), float(3), size=x1_shape).astype(np.float16)
        np_weight = np.random.uniform(float(-3), float(3), size=x2_shape).astype(np.float16)
        input = torch.tensor(np_input).npu()
        weight = torch.tensor(np_weight).npu()
        output_a16w16 = get_graph(input, weight, hcom_info, None, None, None, None, None)
        return output_a16w16
    
    def run_mc2_a8w8(x1_shape, x2_shape, hcom_info):
        np_input = np.random.uniform(float(-3), float(3), size=x1_shape).astype(np.int8)
        np_weight = np.random.uniform(float(-3), float(3), size=x2_shape).astype(np.int8)
        input = torch.tensor(np_input).npu()
        weight = torch.tensor(np_weight).npu()
        weight_nz = torch_npu.npu_format_cast(weight.contiguous(), 29)
        dequant_scale = torch.randn(x2_shape[0], dtype=torch.float32).uniform_(float(-10), float(10)).npu()
        dequant_scale = torch_npu.npu_trans_quant_param(dequant_scale)
        output_a8w8 = get_graph(input, weight_nz, hcom_info, dequant_scale, None, None, None, None)
        return output_a8w8
    
    def run_mc2_a16w8(x1_shape, x2_shape, hcom_info):
        np_input = np.random.uniform(float(-3), float(3), size=x1_shape).astype(np.float16)
        np_weight = np.random.uniform(float(-3), float(3), size=x2_shape).astype(np.int8)
        input = torch.tensor(np_input).npu()
        weight = torch.tensor(np_weight).npu()
        weight_nz = torch_npu.npu_format_cast(weight.contiguous(), 29)
        antiquant_scale = torch.randn(x2_shape[0], dtype=torch.float16).uniform_(float(-1), float(1)).npu()
        antiquant_offset = torch.ones(x2_shape[0], dtype=torch.float16).npu()
        output_a16w8 = get_graph(input, weight_nz, hcom_info, None, None, antiquant_scale, antiquant_offset, None)
        return output_a16w8
    
    def run_mm_all_reduce_base(rank, world_size, master_ip, master_port, x1_shape, x2_shape, op_type):
        torch_npu.npu.set_device(rank)
        init_method = 'tcp://' + master_ip + ':' + master_port
        dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
        from torch.distributed.distributed_c10d import _get_default_group
        default_pg = _get_default_group()
        if torch.__version__ > '2.0.1':
            hcom_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
        else:
            hcom_info = default_pg.get_hccl_comm_name(rank)
        output = None
        # 非量化调用
        if op_type == "a16w16":
            output = run_mc2_a16w16(x1_shape, x2_shape, hcom_info)
        # 伪量化调用
        if op_type == "a16w8":
            output = run_mc2_a16w8(x1_shape, x2_shape, hcom_info)
        # 全量化调用
        if op_type == "a8w8":
            output = run_mc2_a8w8(x1_shape, x2_shape, hcom_info)
        print("output:", output)
    if __name__ == "__main__":
        worksize = 2
        master_ip = '127.0.0.1'
        master_port = '50001'
        x1_shape = [1280, 5120]
        x2_shape = [640, 5120] 
        op_type = "a16w8" # Options: a16w16, a16w8, a8w8
        mp.spawn(run_mm_all_reduce_base, args=(worksize, master_ip, master_port, x1_shape, x2_shape, op_type), nprocs=worksize)