torch_npu.npu_mm_all_reduce_base
功能描述
TP切分场景下,实现mm和all_reduce的融合,融合算子内部实现计算和通信流水并行。
使用该接口时,请确保驱动固件包和CANN包都为配套的8.0.RC2版本或者配套的更高版本,否则将会引发报错,比如BUS ERROR等。
接口原型
npu_mm_all_reduce_base(Tensor x1, Tensor x2, str hcom, *, str reduce_op='sum', Tensor? bias=None, Tensor? antiquant_scale=None, Tensor? antiquant_offset=None, Tensor? x3=None, Tensor? dequant_scale=None Tensor? pertoken_scale=None, Tensor? comm_quant_scale_1=None, Tensor? comm_quant_scale_2=None, int comm_turn=0, int antiquant_group_size=0) -> Tensor
参数说明
- x1:Device侧的Tensor类型,Atlas A2训练系列产品/Atlas 800I A2推理产品数据类型支持INT8、FLOAT16、BFLOAT16(仅Atlas A2训练系列产品/Atlas 800I A2推理产品支持);数据格式支持ND,输入shape支持2维或者3维。
- x2:Device侧的Tensor类型,Atlas A2训练系列产品/Atlas 800I A2推理产品数据类型支持FLOAT16、BFLOAT16、INT8,数据格式支持ND/NZ。非量化场景,数据类型需要和x1保持一致,输入shape维度第0维和x1的最后一维保持一致。
- hcom:Host侧的String类型,通信域handle名,通过get_hccl_comm_name接口获取。
- *:代表其之前的变量是位置相关,按照顺序输入,必选;之后的变量是键值对赋值的,位置无关,可选(不输入会使用默认值)。
- reduce_op:Host侧的String类型,reduce操作类型,当前版本仅支持'sum',默认值:'sum'。
- bias:Device侧的Tensor类型,可选输入,数据类型支持INT32、FLOAT16、BFLOAT16(仅Atlas A2训练系列产品/Atlas 800I A2推理产品支持),数据格式支持ND格式。bias当前仅支持一维,且维度大小与output/x2的最后一维大小相同。
- antiquant_scale:Device侧的Tensor类型,可选输入,伪量化场景对x2进行去量化的系数,数据类型支持FLOAT16、BFLOAT16(仅Atlas A2训练系列产品/Atlas 800I A2推理产品支持),数据格式支持ND格式。伪量化场景数据类型需要和x1保持一致。antiquant_scale当前per-tensor场景shape为[1],per-channel场景支持shape为[1,n]或者[n]。其中n为x2最后一维的大小。per-group场景支持shape为[ceil(k, antiquant_group_size), n](具体计算逻辑见约束说明)。其中k为x2第一维的大小,n为x2最后一维的大小,antiquant_group_size为伪量化场景对输入x2进行反量化计算的groupSize输入。
- antiquant_offset:Device侧的Tensor类型,可选输入,伪量化场景对x2进行去量化的系数,数据类型支持FLOAT16、BFLOAT16(仅Atlas A2训练系列产品/Atlas 800I A2推理产品支持),数据格式支持ND格式。数据类型需要和antiquant_scale保持一致。shape与antiquant_scale保持一致。
- x3:Device侧的Tensor类型,可选输入,matmul计算后的偏移。数据类型支持FLOAT16、BFLOAT16(仅Atlas A2训练系列产品/Atlas 800I A2推理产品支持),数据格式支持ND格式。数据类型需要和输出output保持一致。shape与output的shape相同。
- dequant_scale:Device侧的Tensor类型,可选输入,matmul计算后的去量化系数。Atlas A2训练系列产品/Atlas 800I A2推理产品支持INT64、UINT64、BFLOAT16、FLOAT32;数据格式支持ND格式。shape在per-tensor场景为[1],per-channel场景为[n]/[1,n],其中n为x2最后一维的大小。
- pertoken_scale:Device侧的Tensor类型,可选输入,matmul计算后的per-token去量化系数。Atlas A2训练系列产品/Atlas 800I A2推理产品支持FLOAT32,x1为[m,k]时shape为[m],x1为[b, s, k]时shape为[b*s]。
- comm_quant_scale_1: Device侧的Tensor类型,可选输入,alltoall通信前后的量化、去量化系数。支持FLOAT16、BFLOAT16,支持ND格式。x2为[k, n]时shape为[1, n]或[n],用户需保证每张卡上数据保持一致且正确。
- comm_quant_scale_2: Device侧的Tensor类型,可选输入,allgather通信前后的量化、去量化系数。支持FLOAT16、BFLOAT16,支持ND格式。x2为[k, n]时shape为[1, n]或[n],用户需保证每张卡上数据保持一致且正确。
- comm_turn:Host侧的int类型,表示rank间通信切分粒度,默认值:0,表示默认的切分方式。当前版本仅支持输入0。
- antiquant_group_size:Host侧的int类型,表示伪量化pre-group算法模式下,对输入x2进行反量化计算的groupSize输入,描述一组反量化参数对应的待反量化数据量在k轴方向的大小。当伪量化算法模式不为pre_group时传入0;当伪量化算法模式为pre_group时传入值的范围为[32, min(k-1, INT_MAX)]且值要求是32的倍数,其中k为x2第一维的大小。默认值0,为0则表示非per-group场景。
输出说明
Tensor类型,数据类型非量化场景以及伪量化场景与x1保持一致,全量化场景Atlas A2训练系列产品/Atlas 800I A2推理产品支持输出为FLOAT16或者BFLOAT16。shape第0维度和x1的0维保持一致,若x1为2维,shape第1维度和x2的1维保持一致,若x1为3维,shape第1维度和x1的1维保持一致,shape第2维度和x2的1维保持一致。
约束说明
- 该融合算子仅在推理场景使用。
- BFLOAT16数据类型仅Atlas A2训练系列产品/Atlas 800I A2推理产品支持。
- 输入x1可为2维或者3维、x2必须是2维,分别为(b, s, k)/(m, k), (k, n),k轴满足mm算子入参要求,k轴相等。bias当前仅支持一维,且维度大小与output的最后一维大小相同。x3的shape与output的shape相同。
- Atlas A2训练系列产品/Atlas 800I A2推理产品x1、x2不能为空tensor。
- Atlas A2训练系列产品/Atlas 800I A2推理产品的非量化场景:m、k、n的取值范围均为[1, 2147483647]。
- 全量化场景:m取值范围均为[1, 2147483647],x1、x2的最后一维范围为[1, 65535],即k的取值范围为[1, 65535]、仅当x2(shape=[n,k])为转置时n可以大于65535。
- 伪量化场景:m取值范围均为[1, 2147483647],k、n的取值范围为[1, 65535]。
- antiquant_scale当前per-tensor场景shape为[1],per-channel场景支持shape为[1,n]或者[n],per-group场景支持shape为(ceil(k, antiquant_group_size), n)。antiquant_offset的shape与antiquant_scale一致。dequant_scale的shape在per-tensor场景为[1],per-channel场景为[n]/[1,n]。
- per-token场景下pertoken_scale的shape在x1二维时为[m],x1三维时为[b*s]。
- [ceil(k, antiquant_group_size), n]中的ceil(k, antiquant_group_size)计算逻辑为:(k + antiquant_group_size - 1) / antiquant_group_size,并对计算结果取整数部分。
- 不同场景数据类型支持情况:
- 非量化场景:
- Atlas A2训练系列产品/Atlas 800I A2推理产品中x1为FLOAT16、x2为FLOAT16、bias为FLOAT16、x3为FLOAT16、output为FLOAT16,antiquant_scale、antiquant_offset、dequant_scale为None。
- Atlas A2训练系列产品/Atlas 800I A2推理产品中x1为BFLOAT16、x2为BFLOAT16、bias为BFLOAT16、x3为BFLOAT16、output为BFLOAT16,antiquant_scale、antiquant_offset、dequant_scale为None。
- 伪量化场景:
- Atlas A2训练系列产品/Atlas 800I A2推理产品中x1为FLOAT16、x2为INT8、bias为FLOAT16、x3为FLOAT16、output为FLOAT16,antiquant_scale为FLOAT16、antiquant_offset为FLOAT16、dequant_scale为None。
- Atlas A2训练系列产品/Atlas 800I A2推理产品中x1为BFLOAT16、x2为INT8、bias为BFLOAT16、x3为BFLOAT16、output为BFLOAT16,antiquant_scale为BFLOAT16、antiquant_offset为BFLOAT16、dequant_scale为None。
- 全量化场景:
- Atlas A2训练系列产品/Atlas 800I A2推理产品中x1为INT8、x2为INT8、bias为INT32、x3为FLOAT16、output为FLOAT16,antiquant_scale为None、antiquant_offset为None、dequant_scale为UINT64或INT64,pertoken_scale为None。
- Atlas A2训练系列产品/Atlas 800I A2推理产品中x1为INT8、x2为INT8、bias为INT32、x3为BFLOAT16、output为BFLOAT16,antiquant_scale为None、antiquant_offset为None、dequant_scale为BFLOAT16,pertoken_scale为None。
- Atlas A2训练系列产品/Atlas 800I A2推理产品中x1为INT8、x2为INT8、bias为INT32、x3为FLOAT16、output为FLOAT16,antiquant_scale为None、antiquant_offset为None、dequant_scale为FLOAT32、pertoken_scale为FLOAT32。
- Atlas A2训练系列产品/Atlas 800I A2推理产品中x1为INT8、x2为INT8、bias为INT32、x3为BFLOAT16、output为BFLOAT16,antiquant_scale为None、antiquant_offset为None、dequant_scale为BFLOAT16、pertoken_scale为FLOAT32。
- 若dequant_scale需要以FP32类型传入,在调用torch_npu.npu_mm_all_reduce_base()前,需通过torch_npu.npu_trans_quant_param()接口对dequant_scale进行处理为INT64类型(处理方法见对应的接口使用说明)。
- 非量化场景:
- antiquant_group_size中k值的范围与matmul一致,为[1,65535],INT_MAX大于(k-1)。
- x1不支持输入转置后的tensor,x2转置后输入,需要满足shape的第一维大小与x1的最后一维相同,满足matmul的计算条件。
- Atlas A2训练系列产品/Atlas 800I A2推理产品支持1、2、4、8卡,并且仅支持hccs链路all mesh组网。
- 增量场景不使能该融合算子,全量场景使能该融合算子。
- 一个模型中的通算融合MC2算子,仅支持相同通信域。
- 在长序列场景,随着b/s或者m的增大,可能出现内存不足或者计算超时。
- comm_quant_scale_1,comm_quant_scale_2的shape应保持一致,dtype与输出的dtype保持一致,且只在Atlas A2训练系列产品/Atlas 800I A2推理产品全量化场景支持。
支持的型号
- Atlas A2训练系列产品/Atlas 800I A2推理产品
调用示例
- 单算子调用示例:
import torch import torch_npu import torch.distributed as dist import torch.multiprocessing as mp def run_mm_all_reduce_base(rank, world_size, master_ip, master_port, x1_shape, x2_shape, dtype): torch_npu.npu.set_device(rank) init_method = 'tcp://' + master_ip + ':' + master_port dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method) from torch.distributed.distributed_c10d import _get_default_group default_pg = _get_default_group() if torch.__version__ > '2.0.1': hcom_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank) else: hcom_info = default_pg.get_hccl_comm_name(rank) input_ = torch.randn(x1_shape, dtype=dtype).npu() weight = torch.randn(x2_shape, dtype=dtype).npu() output = torch_npu.npu_mm_all_reduce_base(input_, weight, hcom_info, reduce_op='sum') print("output: ", output) if __name__ == "__main__": worksize = 8 master_ip = '127.0.0.1' master_port = '50001' x1_shape = [128, 512] x2_shape = [512, 64] dtype = torch.float16 mp.spawn(run_mm_all_reduce_base, args=(worksize, master_ip, master_port, x1_shape, x2_shape, dtype), nprocs=worksize)
父主题: torch_npu