In tensor-parallel (TP) sharding scenarios, this operator fuses allgather and matmul into a single fused operator, which internally pipelines communication and computation in parallel.
When using this interface, make sure the driver/firmware package and the CANN package are the matching 8.0.RC2 version, or a matching later version; otherwise errors such as BUS ERROR will be raised.
npu_all_gather_base_mm(Tensor input, Tensor x2, str hcom, int world_size, *, Tensor? bias=None, int gather_index=0, bool gather_output=True, int comm_turn=0) -> (Tensor, Tensor)
Two outputs, both of type Tensor: (Tensor, Tensor). The first output is the matmul result computed on the gathered input; the second output is the allgather result of input (controlled by the gather_output parameter).
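For intuition about what the two outputs contain, the fused operator is semantically an allgather of input along dim 0 across the communication domain, followed by a matmul with x2. Below is a minimal single-device sketch of that equivalence; it only simulates the gather with torch.cat (every "rank" reuses the same shard) and the helper name simulate_all_gather_mm is illustrative, not part of the API.

import torch

def simulate_all_gather_mm(local_x1, x2, world_size):
    # Illustrative stand-in for the semantics of npu_all_gather_base_mm:
    # here every "rank" is faked with an identical copy of local_x1.
    shards = [local_x1 for _ in range(world_size)]
    gathered = torch.cat(shards, dim=0)   # [m, k] built from [m // world_size, k] shards
    output = gathered @ x2                # [m, n]: corresponds to the first returned Tensor
    return output, gathered               # gathered: corresponds to the second returned Tensor

local_x1 = torch.randn(16, 512)           # one rank's shard of x1
x2 = torch.randn(512, 64)
out, gathered = simulate_all_gather_mm(local_x1, x2, world_size=8)
print(out.shape, gathered.shape)          # torch.Size([128, 64]) torch.Size([128, 512])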
# Eager (single-operator) mode call example.
import torch
import torch_npu
import torch.distributed as dist
import torch.multiprocessing as mp

def run_all_gather_base_mm(rank, world_size, master_ip, master_port, x1_shape, x2_shape, dtype):
    torch_npu.npu.set_device(rank)
    init_method = 'tcp://' + master_ip + ':' + master_port
    dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
    from torch.distributed.distributed_c10d import _get_default_group
    default_pg = _get_default_group()
    # Obtain the HCCL communication domain name required by the fused operator.
    if torch.__version__ > '2.0.1':
        hcomm_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
    else:
        hcomm_info = default_pg.get_hccl_comm_name(rank)

    tensor_allgather_shape = x1_shape  # full x1 shape after the allgather
    single_shape = [x1_shape[0] // world_size, x1_shape[1]]  # per-rank shard of x1

    input_ = torch.randn(single_shape, dtype=dtype).npu()
    weight = torch.randn(x2_shape, dtype=dtype).npu()
    output, gather_out = torch_npu.npu_all_gather_base_mm(input_, weight, hcomm_info, world_size)

if __name__ == "__main__":
    world_size = 8
    master_ip = '127.0.0.1'
    master_port = '50001'
    x1_shape = [128, 512]
    x2_shape = [512, 64]
    dtype = torch.float16

    mp.spawn(run_all_gather_base_mm,
             args=(world_size, master_ip, master_port, x1_shape, x2_shape, dtype),
             nprocs=world_size)
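If you want to sanity-check the fused result inside the worker above, one option is to rebuild a reference result with a regular all_gather plus torch.matmul. This is only a sketch: it assumes it is appended at the end of run_all_gather_base_mm, reusing input_, weight, output and the already initialized default process group from the example.

# Optional check, appended at the end of run_all_gather_base_mm (sketch).
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
dist.all_gather(gather_list, input_)            # plain HCCL allgather of the local shard
ref_output = torch.matmul(torch.cat(gather_list, dim=0), weight)
# The fused kernel and torch.matmul may accumulate float16 differently,
# so compare with a loose tolerance rather than exact equality.
print("rank", rank, "max diff:", (output - ref_output).abs().max().item())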
# Graph mode (torch.compile + torchair) call example.
import torch
import torch_npu
import torch.distributed as dist
import torch.multiprocessing as mp

class ALLGATHER_MM_GRAPH_Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, weight, hcomm_info, world_size, gather_output):
        output, gather_output = torch_npu.npu_all_gather_base_mm(
            input, weight, hcomm_info, world_size, gather_output=gather_output)
        return output, gather_output

def define_model(model, graph_type):
    import torchair
    if graph_type == 1:  # Traditional graph mode: static shape + online compilation
        npu_backend = torchair.get_npu_backend(compiler_config=None)
        model = torch.compile(model, backend=npu_backend, dynamic=False)
    elif graph_type == 2:  # ACLNN graph mode: dynamic shape + binary kernels
        npu_backend = torchair.get_npu_backend(compiler_config=None)
        model = torch.compile(model, backend=npu_backend, dynamic=True)
    else:
        print("Unsupported graph_type")
    return model

def get_graph(input, weight, hcomm_info, world_size, gather_output):
    model = ALLGATHER_MM_GRAPH_Model()
    model = define_model(model, 2)
    model_output = model(input, weight, hcomm_info, world_size, gather_output=gather_output)
    output_npu = model_output[0]
    gather_output_npu = model_output[1]
    return output_npu, gather_output_npu

def run_all_gather_base_mm(rank, world_size, master_ip, master_port, x1_shape, x2_shape, dtype):
    torch_npu.npu.set_device(rank)
    init_method = 'tcp://' + master_ip + ':' + master_port
    dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
    from torch.distributed.distributed_c10d import _get_default_group
    default_pg = _get_default_group()
    # Obtain the HCCL communication domain name required by the fused operator.
    if torch.__version__ > '2.0.1':
        hcomm_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
    else:
        hcomm_info = default_pg.get_hccl_comm_name(rank)

    single_shape = [x1_shape[0] // world_size, x1_shape[1]]  # per-rank shard of x1

    input = torch.randn(single_shape, dtype=dtype).npu()
    weight = torch.randn(x2_shape, dtype=dtype).npu()
    is_gather_out = True
    output, gather_out = get_graph(input, weight, hcomm_info, world_size, is_gather_out)
    print("output:", output)

if __name__ == "__main__":
    world_size = 8
    master_ip = '127.0.0.1'
    master_port = '50001'
    x1_shape = [128, 512]
    x2_shape = [512, 64]
    dtype = torch.float16

    mp.spawn(run_all_gather_base_mm,
             args=(world_size, master_ip, master_port, x1_shape, x2_shape, dtype),
             nprocs=world_size)
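Both branches in define_model above pass compiler_config=None, which lets torchair fall back to its default configuration. If you need to adjust graph-mode options, you can build a CompilerConfig explicitly and hand it to get_npu_backend; the following is only a minimal sketch that leaves every option at its default and reuses ALLGATHER_MM_GRAPH_Model from the example.

import torch
import torchair
from torchair.configs.compiler_config import CompilerConfig

config = CompilerConfig()   # graph-mode options live on this object; defaults are kept in this sketch
npu_backend = torchair.get_npu_backend(compiler_config=config)
compiled_model = torch.compile(ALLGATHER_MM_GRAPH_Model(), backend=npu_backend, dynamic=True)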