torch_npu.npu_all_gather_base_mm
功能说明
TP切分场景下,实现allgather和matmul的融合,融合算子内部实现通信和计算流水并行。
使用该接口时,请确保驱动固件包和CANN包都为配套的8.0.RC2版本或者配套的更高版本,否则将会引发报错,比如BUS ERROR等。
接口原型
npu_all_gather_base_mm(Tensor input, Tensor x2, str hcom, int world_size, *, Tensor? bias=None, int gather_index=0, bool gather_output=True, int comm_turn=0) -> (Tensor, Tensor)
参数说明
- input:Device侧的Tensor类型,数据类型支持FLOAT16、BFLOAT16,数据格式支持ND格式,输入shape支持2维。
- x2:Device侧的Tensor类型,数据类型支持FLOAT16、BFLOAT16,数据格式支持ND格式,数据类型需要和input保持一致,输入shape维度和input保持一致。
- hcom:Host侧的String类型,通信域handle名,通过get_hccl_comm_name接口获取。
- world_size:Host侧的int类型,通信域内的rank总数,仅支持为2、4、8。
- *:代表其之前的变量是位置相关,按照顺序输入,必选;之后的变量是键值对赋值的,位置无关,可选(不输入会使用默认值)。
- bias:Device侧的Tensor类型,可选输入,数据类型支持FLOAT16、BFLOAT16,数据格式支持ND格式。数据类型需要和input保持一致。bias仅支持一维,且维度大小与output的第1维大小相同。当前版本暂不支持bias输入为非0的场景。
- gather_index:Host侧的int类型,表示gather操作对象,0:对input做gather,1:对x2做gather。默认值0。当前版本仅支持输入0。
- gather_output:Host侧的bool类型,表示是否需要gather输出。默认值true。
- comm_turn:Host侧的int类型,表示rank间通信切分粒度,默认值:0,表示默认的切分方式。当前版本仅支持输入0。
输出说明
两个输出,均为Tensor类型:(Tensor, Tensor)
- 第一个输出是allgather+matmul的结果。
- 第二个输出是allgather的结果。
约束说明
- 输入input、x2必须是2维,分别为(m, k),(k, n),轴满足matmul算子入参要求,k轴相等,且k轴取值范围为[256, 65535)。
- x1不支持输入转置后的tensor,x2转置后输入,需要满足shape的第一维大小与x1的最后一维相同,满足matmul的计算条件。
- Atlas A2 训练系列产品:支持2、4、8卡, 并且仅支持hccs链路all mesh组网。
- 一个模型中的通算融合算子(AllGatherMatmul、MatmulReduceScatter、MatmulAllReduce),仅支持相同通信域
支持的型号
- Atlas A2 训练系列产品
调用示例
import torch import torch_npu import torch.distributed as dist import torch.multiprocessing as mp def run_all_gather_base_mm(rank, world_size, master_ip, master_port, x1_shape, x2_shape, dtype): torch_npu.npu.set_device(rank) init_method = 'tcp://' + master_ip + ':' + master_port dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method) from torch.distributed.distributed_c10d import _get_default_group default_pg = _get_default_group() if torch.__version__ > '2.0.1': hcomm_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank) else: hcomm_info = default_pg.get_hccl_comm_name(rank) tensor_allgather_shape = x1_shape single_shape = [x1_shape[0] // world_size, x1_shape[1]] input_ = torch.randn(single_shape, dtype=dtype).npu() weight = torch.randn(x2_shape, dtype=dtype).npu() output, gather_out = torch_npu.npu_all_gather_base_mm(input_, weight, hcomm_info, world_size) if __name__ == "__main__": worksize = 8 master_ip = '127.0.0.1' master_port = '50001' x1_shape = [128, 512] x2_shape = [512, 64] dtype = torch.float16 mp.spawn(run_all_gather_base_mm, args=(worksize, master_ip, master_port, x1_shape, x2_shape, dtype), nprocs=worksize)
父主题: torch_npu