In tensor-parallel (TP) sharding scenarios, this operator fuses the matrix multiplication (mm) with the subsequent all_reduce; inside the fused operator, computation and communication are pipelined so that they overlap.
When using this interface, make sure the driver/firmware package and the CANN package are a matched 8.0.RC2 version or a matched later version; otherwise errors such as BUS ERROR will be raised.
npu_mm_all_reduce_base(Tensor x1, Tensor x2, str hcom, *, str reduce_op='sum', Tensor? bias=None, Tensor? antiquant_scale=None, Tensor? antiquant_offset=None, Tensor? x3=None, Tensor? dequant_scale=None, Tensor? pertoken_scale=None, Tensor? comm_quant_scale_1=None, Tensor? comm_quant_scale_2=None, int comm_turn=0, int antiquant_group_size=0) -> Tensor
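For orientation, a minimal non-quantized call can be sketched as follows. This is a hypothetical helper, not part of the interface: `hcom_info` is the HCCL communicator name obtained from the process group, as shown in the complete examples at the end of this section.

```python
import torch_npu

def tp_linear_forward(x1, x2, hcom_info):
    # Minimal sketch: x1 [m, k] and x2 [k, n] are FLOAT16 local shards, hcom_info is the
    # HCCL communicator name of the TP group. The result [m, n] is already summed
    # (reduce_op='sum') across all ranks of that group.
    return torch_npu.npu_mm_all_reduce_base(x1, x2, hcom_info, reduce_op='sum')
```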
ceil(k, antiquant_group_size) is computed as (k + antiquant_group_size - 1) / antiquant_group_size, keeping only the integer part of the result.
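As a worked example (illustrative numbers, not from the source):

```python
k, antiquant_group_size = 100, 32   # illustrative values
groups = (k + antiquant_group_size - 1) // antiquant_group_size
# (100 + 32 - 1) // 32 = 131 // 32 = 4, i.e. ceil(100 / 32)
```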
Tensor. In the non-quantized and pseudo-quantized scenarios the output data type matches x1; in the full-quantization scenario the output data type is FLOAT16 or BFLOAT16. Dimension 0 of the output shape matches dimension 0 of x1. If x1 is 2-dimensional, dimension 1 of the output matches dimension 1 of x2. If x1 is 3-dimensional, dimension 1 of the output matches dimension 1 of x1 and dimension 2 of the output matches dimension 1 of x2.
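The shape rules above can be restated in a short sketch (a hypothetical helper, written only to illustrate the rules):

```python
def infer_output_shape(x1_shape, x2_shape):
    # Hypothetical helper for illustration only; mirrors the shape rules above.
    if len(x1_shape) == 2:    # x1: [m, k], x2: [k, n] -> output: [m, n]
        return [x1_shape[0], x2_shape[1]]
    if len(x1_shape) == 3:    # x1: [b, s, k], x2: [k, n] -> output: [b, s, n]
        return [x1_shape[0], x1_shape[1], x2_shape[1]]
    raise ValueError("x1 is expected to be 2-D or 3-D")

assert infer_output_shape([128, 512], [512, 64]) == [128, 64]
assert infer_output_shape([4, 128, 512], [512, 64]) == [4, 128, 64]
```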
Non-quantized scenario:

| Product model | x1 | x2 | bias | x3 | output | antiquant_scale | antiquant_offset | dequant_scale |
|---|---|---|---|---|---|---|---|---|
|  | FLOAT16 | FLOAT16 | FLOAT16 | FLOAT16 | FLOAT16 | None | None | None |
|  | BFLOAT16 | BFLOAT16 | BFLOAT16 | BFLOAT16 | BFLOAT16 | None | None | None |
Pseudo-quantized (weight-only quantization) scenario:

| Product model | x1 | x2 | bias | x3 | output | antiquant_scale | antiquant_offset | dequant_scale |
|---|---|---|---|---|---|---|---|---|
|  | FLOAT16 | INT8 | FLOAT16 | FLOAT16 | FLOAT16 | FLOAT16 | FLOAT16 | None |
|  | BFLOAT16 | INT8 | BFLOAT16 | BFLOAT16 | BFLOAT16 | BFLOAT16 | BFLOAT16 | None |
Full-quantization scenario:

| Product model | x1 | x2 | bias | x3 | output | antiquant_scale | antiquant_offset | dequant_scale | pertoken_scale |
|---|---|---|---|---|---|---|---|---|---|
|  | INT8 | INT8 | INT32 | FLOAT16 | FLOAT16 | None | None | UINT64 or INT64 | None |
|  | INT8 | INT8 | INT32 | BFLOAT16 | BFLOAT16 | None | None | BFLOAT16 | None |
|  | INT8 | INT8 | INT32 | FLOAT16 | FLOAT16 | None | None | FLOAT32 | FLOAT32 |
|  | INT8 | INT8 | INT32 | BFLOAT16 | BFLOAT16 | None | None | BFLOAT16 | FLOAT32 |
Full-quantization scenario: if dequant_scale is to be passed in as FP32, it must first be converted to INT64 with the torch_npu.npu_trans_quant_param interface before torch_npu.npu_mm_all_reduce_base is called (see the documentation of that interface for the conversion details).
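A minimal sketch of that conversion, assuming `x1_int8`/`x2_int8` are INT8 shards laid out as in the table above and `hcom_info` is the HCCL communicator name (the function name and arguments are placeholders for illustration):

```python
import torch_npu

def mm_all_reduce_a8w8(x1_int8, x2_int8, hcom_info, dequant_scale_fp32):
    # Hypothetical wrapper for the full-quantization path: the FP32 per-channel
    # dequant scale is first converted to the INT64 form expected by the operator.
    dequant_scale_int64 = torch_npu.npu_trans_quant_param(dequant_scale_fp32)
    return torch_npu.npu_mm_all_reduce_base(x1_int8, x2_int8, hcom_info,
                                            reduce_op='sum',
                                            dequant_scale=dequant_scale_int64)
```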
```python
import torch
import torch_npu
import torch.distributed as dist
import torch.multiprocessing as mp

def run_mm_all_reduce_base(rank, world_size, master_ip, master_port, x1_shape, x2_shape, dtype):
    torch_npu.npu.set_device(rank)
    init_method = 'tcp://' + master_ip + ':' + master_port
    dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
    from torch.distributed.distributed_c10d import _get_default_group
    default_pg = _get_default_group()
    if torch.__version__ > '2.0.1':
        hcom_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
    else:
        hcom_info = default_pg.get_hccl_comm_name(rank)

    input_ = torch.randn(x1_shape, dtype=dtype).npu()
    weight = torch.randn(x2_shape, dtype=dtype).npu()
    output = torch_npu.npu_mm_all_reduce_base(input_, weight, hcom_info, reduce_op='sum')
    print("output: ", output)

if __name__ == "__main__":
    worksize = 8
    master_ip = '127.0.0.1'
    master_port = '50001'
    x1_shape = [128, 512]
    x2_shape = [512, 64]
    dtype = torch.float16

    mp.spawn(run_mm_all_reduce_base, args=(worksize, master_ip, master_port, x1_shape, x2_shape, dtype), nprocs=worksize)
```
```python
import torch
import torch_npu
import torch.distributed as dist
import torch.multiprocessing as mp
import numpy as np

class MM_ALLREDUCE_GRAPH_Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x1, x2, hcom, reduce_op, bias, antiquant_scale, antiquant_offset, x3, dequant_scale):
        output_npu = torch_npu.npu_mm_all_reduce_base(x1=x1,
                                                      x2=x2,
                                                      hcom=hcom,
                                                      reduce_op=reduce_op,
                                                      bias=bias,
                                                      antiquant_scale=antiquant_scale,
                                                      antiquant_offset=antiquant_offset,
                                                      x3=x3,
                                                      dequant_scale=dequant_scale)
        return output_npu

class MM_ALLREDUCE_A8W8_GRAPH_Model(MM_ALLREDUCE_GRAPH_Model):
    def __init__(self):
        super().__init__()

    def forward(self, x1, x2, hcom, reduce_op, bias, antiquant_scale, antiquant_offset, x3, dequant_scale):
        output_npu = torch_npu.npu_mm_all_reduce_base(x1=x1,
                                                      x2=x2.t(),
                                                      hcom=hcom,
                                                      reduce_op=reduce_op,
                                                      bias=bias,
                                                      antiquant_scale=antiquant_scale,
                                                      antiquant_offset=antiquant_offset,
                                                      x3=x3,
                                                      dequant_scale=dequant_scale)
        return output_npu

def define_model(model, graph_type):
    import torchair
    if graph_type == 1:  # Traditional graph mode: static shape + online compilation
        npu_backend = torchair.get_npu_backend(compiler_config=None)
        model = torch.compile(model, backend=npu_backend, dynamic=False)
    elif graph_type == 2:  # ACLNN graph mode: dynamic shape + binary kernels
        npu_backend = torchair.get_npu_backend(compiler_config=None)
        model = torch.compile(model, backend=npu_backend, dynamic=True)
    else:
        print("Error type")
    return model

def get_graph(input, weight, hcomm_info, dequant_scale, bias, antiquant_scale, antiquant_offset, x3):
    model = MM_ALLREDUCE_A8W8_GRAPH_Model()
    model = define_model(model, 2)  # 1: static graph capture; 2: dynamic graph capture
    output = model(x1=input,
                   x2=weight,
                   hcom=hcomm_info,
                   reduce_op="sum",
                   bias=bias,
                   antiquant_scale=antiquant_scale,
                   antiquant_offset=antiquant_offset,
                   x3=x3,
                   dequant_scale=dequant_scale)
    return output

def run_mc2_a16w16(x1_shape, x2_shape, hcom_info):
    np_input = np.random.uniform(float(-3), float(3), size=x1_shape).astype(np.float16)
    np_weight = np.random.uniform(float(-3), float(3), size=x2_shape).astype(np.float16)
    input = torch.tensor(np_input).npu()
    weight = torch.tensor(np_weight).npu()
    output_a16w16 = get_graph(input, weight, hcom_info, None, None, None, None, None)
    return output_a16w16

def run_mc2_a8w8(x1_shape, x2_shape, hcom_info):
    np_input = np.random.uniform(float(-3), float(3), size=x1_shape).astype(np.int8)
    np_weight = np.random.uniform(float(-3), float(3), size=x2_shape).astype(np.int8)
    input = torch.tensor(np_input).npu()
    weight = torch.tensor(np_weight).npu()
    weight_nz = torch_npu.npu_format_cast(weight.contiguous(), 29)
    dequant_scale = torch.randn(x2_shape[0], dtype=torch.float32).uniform_(float(-10), float(10)).npu()
    dequant_scale = torch_npu.npu_trans_quant_param(dequant_scale)
    output_a8w8 = get_graph(input, weight_nz, hcom_info, dequant_scale, None, None, None, None)
    return output_a8w8

def run_mc2_a16w8(x1_shape, x2_shape, hcom_info):
    np_input = np.random.uniform(float(-3), float(3), size=x1_shape).astype(np.float16)
    np_weight = np.random.uniform(float(-3), float(3), size=x2_shape).astype(np.int8)
    input = torch.tensor(np_input).npu()
    weight = torch.tensor(np_weight).npu()
    weight_nz = torch_npu.npu_format_cast(weight.contiguous(), 29)
    antiquant_scale = torch.randn(x2_shape[0], dtype=torch.float16).uniform_(float(-1), float(1)).npu()
    antiquant_offset = torch.ones(x2_shape[0], dtype=torch.float16).npu()
    output_a16w8 = get_graph(input, weight_nz, hcom_info, None, None, antiquant_scale, antiquant_offset, None)
    return output_a16w8

def run_mm_all_reduce_base(rank, world_size, master_ip, master_port, x1_shape, x2_shape, op_type):
    torch_npu.npu.set_device(rank)
    init_method = 'tcp://' + master_ip + ':' + master_port
    dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
    from torch.distributed.distributed_c10d import _get_default_group
    default_pg = _get_default_group()
    if torch.__version__ > '2.0.1':
        hcom_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
    else:
        hcom_info = default_pg.get_hccl_comm_name(rank)
    output = None
    # Non-quantized (a16w16) call
    if op_type == "a16w16":
        output = run_mc2_a16w16(x1_shape, x2_shape, hcom_info)
    # Pseudo-quantized (a16w8) call
    if op_type == "a16w8":
        output = run_mc2_a16w8(x1_shape, x2_shape, hcom_info)
    # Fully quantized (a8w8) call
    if op_type == "a8w8":
        output = run_mc2_a8w8(x1_shape, x2_shape, hcom_info)
    print("output:", output)

if __name__ == "__main__":
    worksize = 2
    master_ip = '127.0.0.1'
    master_port = '50001'
    x1_shape = [1280, 5120]
    x2_shape = [640, 5120]
    op_type = "a16w8"  # Options: a16w16, a16w8, a8w8
    mp.spawn(run_mm_all_reduce_base, args=(worksize, master_ip, master_port, x1_shape, x2_shape, op_type), nprocs=worksize)
```