torch_npu.npu_mm_all_reduce_base
功能描述
TP切分场景下,实现mm和all_reduce的融合,融合算子内部实现计算和通信流水并行。
接口原型
npu_mm_all_reduce_base(Tensor x1, Tensor x2, str hcom, *, str reduce_op='sum', Tensor? bias=None, Tensor? antiquant_scale=None, Tensor? antiquant_offset=None, Tensor? x3=None, Tensor? dequant_scale=None, int comm_turn=0, int antiquant_group_size=0) -> Tensor
参数说明
- x1:Device侧的Tensor类型,支持float16、bfloat16、int8,支持ND,输入shape支持2维或者3维。
- x2:Device侧的Tensor类型,支持float16、bfloat16、int8,支持ND,非量化场景,数据类型需要和x1保持一致,输入shape维度第0维和x1的最后一维保持一致。
- hcom:Host侧的String类型,通信域handle名,通过get_hccl_comm_name接口获取。
- *:代表其之前的变量是位置相关,按照顺序输入,必选;之后的变量是键值对赋值的,位置无关,可选(不输入会使用默认值)。
- reduce_op:Host侧的String类型,reduce操作类型,当前版本仅支持'sum',默认值:'sum'。
- bias:Device侧的Tensor类型,可选输入,支持int32、float16、bfloat16,支持ND格式。bias当前仅支持一维,且维度大小与output/x2的最后一维大小相同。
- antiquant_scale:Device侧的Tensor类型,可选输入,伪量化场景对x2进行去量化的系数,支持float16、bfloat16,支持ND格式。伪量化场景数据类型需要和x1保持一致。antiquant_scale当前per-tensor场景shape为[1],per-channel场景支持shape为[1,n]或者[n]。其中n为x2最后一维的大小。per-group场景支持shape为[ceil(k, antiquant_group_size), n](具体计算逻辑见约束说明)。其中k为x2第一维的大小,n为x2最后一维的大小,antiquant_group_size为伪量化场景对输入x2进行反量化计算的groupSize输入。
- antiquant_offset:Device侧的Tensor类型,可选输入,伪量化场景对x2进行去量化的系数,支持float16、bfloat16,支持ND格式。数据类型需要和antiquant_scale保持一致。shape与antiquant_scale保持一致。
- x3:Device侧的Tensor类型,可选输入,matmul计算后的偏移。支持float16、bfloat16。支持ND格式。数据类型需要和输出output保持一致。shape与output的shape相同。伪量化场景暂不支持处理x3。
- dequant_scale:Device侧的Tensor类型,可选输入,matmul计算后的去量化系数。支持int64、uint64、bfloat16,支持ND格式。shape在per-tensor场景为[1],per-channel场景为[n]/[1,n],其中n为x2最后一维的大小。
- comm_turn:Host侧的int类型,表示rank间通信切分粒度,默认值:0,表示默认的切分方式。当前版本仅支持输入0。
- antiquant_group_size:Host侧的int类型,表示伪量化pre-group算法模式下,对输入x2进行反量化计算的groupSize输入,描述一组反量化参数对应的待反量化数据量在k轴方向的大小。当伪量化算法模式不为pre_group时传入0;当伪量化算法模式为pre_group时传入值的范围为[32, min(k-1, INT_MAX)]且值要求是32的倍数,其中k为x2第一维的大小。默认值0,为0则表示非per-group场景。
输出说明
Tensor类型,数据类型非量化场景以及伪量化场景与x1保持一致,全量化场景为float16或者bfloat16。shape第0维度和x1的0维保持一致,若x1为2维,shape第1维度和x2的1维保持一致,若x1为3维,shape第1维度和x1的1维保持一致,shape第2维度和x2的1维保持一致。
约束说明
- 输入x1可为2维或者3维、x2必须是2维,分别为(b, s, k)/(m, k), (k, n),k轴满足mm算子入参要求,k轴相等。bias当前仅支持一维,且维度大小与output的最后一维大小相同。x3的shape与output的shape相同。antiquant_scale当前per-tensor场景shape为[1],per-channel场景支持shape为[1,n]或者[n],per-group场景支持shape为(ceil(k, antiquant_group_size), n)。antiquant_offset的shape与antiquant_scale一致。dequant_scale的shape在per-tensor场景为[1],per-channel场景为[n]/[1,n]。伪量化场景:m、k、n的范围为[1, 65535]。
- [ceil(k, antiquant_group_size), n]中的ceil(k, antiquant_group_size)计算逻辑为:(k + antiquant_group_size - 1) / antiquant_group_size,并对计算结果取整数部分。
- x1、x2不能为空tensor。
- 非量化场景,x1、x2、bias、output的数据类型保持一致,可为float16或者bfloat16,antiquant_scale、antiquant_offset、dequant_scale为None。
- 伪量化场景,x1、bias、x3、antiquant_scale、antiquant_offset,output的数据类型保持一致,可为float16或者bfloat16,x2的数据类型为int8,dequant_scale为None。
- 全量化场景,x1、x2的数据类型为int8,dequant_scale的数据类型为int64、uint64或者bfloat16。dequant_scale类型为int64、uint64时,output数据类型为float16;dequant_scale类型为bfloat16时,output数据类型为bfloat16;bias数据类型为int32;antiquant_scale、antiquant_offset为None。仅输出为bfloat16时,支持传入x3。另外,若dequant_scale需要以int64类型传入,在调用torch_npu.npu_mm_all_reduce_base()前,需通过torch_npu.npu_trans_quant_param()接口对dequant_scale进行处理(处理方法见对应的接口使用说明)。
- antiquant_group_size中k值的范围与matmul一致,为[1,65535],INT_MAX大于(k-1)。
- x1不支持输入转置后的tensor,x2转置后输入,需要满足shape的第一维大小与x1的最后一维相同,满足matmul的计算条件。
- Atlas A2 训练系列产品支持2、4、8卡。
- 增量场景不使能MC2,全量场景使能MC2。
支持的PyTorch版本
- PyTorch 2.1
- PyTorch 1.11.0
支持的型号
Atlas A2 训练系列产品
调用示例
import torch import torch_npu import torch.distributed as dist import torch.multiprocessing as mp def run_mm_all_reduce_base(rank, world_size, master_ip, master_port, x1_shape, x2_shape, dtype): torch_npu.npu.set_device(rank) init_method = 'tcp://' + master_ip + ':' + master_port dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method) from torch.distributed.distributed_c10d import _get_default_group default_pg = _get_default_group() if torch.__version__ > '2.0.1': hcom_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank) else: hcom_info = default_pg.get_hccl_comm_name(rank) input_ = torch.randn(x1_shape, dtype=dtype).npu() weight = torch.randn(x2_shape, dtype=dtype).npu() output = torch_npu.npu_mm_all_reduce_base(input_, weight, hcom_info, reduce_op='sum') print("output: ", output) if __name__ == "__main__": worksize = 8 master_ip = '127.0.0.1' master_port = '50001' x1_shape = [128, 512] x2_shape = [512, 64] dtype = torch.float16 mp.spawn(run_mm_all_reduce_base, args=(worksize, master_ip, master_port, x1_shape, x2_shape, dtype), nprocs=worksize)
父主题: torch_npu