先将updates进行量化,然后将updates中的值按指定的轴axis和索引indices更新input中的值,并将结果保存到输出tensor,input本身的数据不变。
torch_npu.npu_quant_scatter(Tensor input, Tensor indices, Tensor updates, Tensor quant_scales, Tensor? quant_zero_points=None, int axis=0, int quant_axis=1, str reduce='update') -> Tensor
一个Tensor类型的输出,代表input被更新后的结果。
# Single-operator (eager mode) example for torch_npu.npu_quant_scatter.
import torch
import torch_npu
import numpy as np

# int8 tensor whose values are updated by the scatter.
# NOTE: np.random.uniform(0, 1) samples from [0, 1), so casting to an integer
# dtype truncates every sample to 0 (this applies to data_var and data_indices).
data_var = np.random.uniform(0, 1, [24, 4096, 128]).astype(np.int8)
var = torch.from_numpy(data_var).to(torch.int8).npu()

# int32 indices selecting where the updates are written along `axis`.
data_indices = np.random.uniform(0, 1, [24]).astype(np.int32)
indices = torch.from_numpy(data_indices).to(torch.int32).npu()

# Values to be quantized and written; generated as float16, passed as bfloat16.
data_updates = np.random.uniform(1, 2, [24, 1, 128]).astype(np.float16)
updates = torch.from_numpy(data_updates).to(torch.bfloat16).npu()

# Quantization scales and zero points, both passed as bfloat16.
data_quant_scales = np.random.uniform(0, 1, [1, 1, 128]).astype(np.float16)
quant_scales = torch.from_numpy(data_quant_scales).to(torch.bfloat16).npu()
data_quant_zero_points = np.random.uniform(0, 1, [1, 1, 128]).astype(np.float16)
quant_zero_points = torch.from_numpy(data_quant_zero_points).to(torch.bfloat16).npu()

axis = -2        # axis of `var` along which the updates are written
quant_axis = -1  # presumably the axis along which quantization is applied -- confirm in the parameter description
reduce = "update"

# The result is returned as a new tensor.
out = torch_npu.npu_quant_scatter(
    var,
    indices,
    updates,
    quant_scales,
    quant_zero_points,
    axis=axis,
    quant_axis=quant_axis,
    reduce=reduce,
)
# Graph-mode example: run torch_npu.npu_quant_scatter through torch.compile
# with the torchair NPU backend and print it next to the eager result.
import logging
import math
import os

# These Dynamo logging switches are environment variables; assigning Python
# variables with the same names has no effect. They are set before torch is
# imported so that they are visible when torch initializes its logging.
# setdefault() keeps any value already exported by the caller's shell.
os.environ.setdefault("TORCHDYNAMO_VERBOSE", "1")
os.environ.setdefault("TORCH_LOGS", "+dynamo")

import torch
import torch_npu
import torchair as tng
import numpy as np
import torch._dynamo
from torch.library import Library, impl
from torchair.configs.compiler_config import CompilerConfig
from torchair.core.utils import logger

# Enable torchair debug logging for graph capture.
logger.setLevel(logging.DEBUG)

config = CompilerConfig()
config.debug.graph_dump.type = "pbtxt"
npu_backend = tng.get_npu_backend(compiler_config=config)

# Data generation
updates_shape = [1, 11, 1, 32]
var_shape = [1, 11, 1, 32]
indices_shape = [1, 2]
quant_scales_shape = [1, 1, 1, 32]
quant_zero_points_shape = [1, 1, 1, 32]
axis = -2
quant_axis = -1
reduce = "update"

updates_data = np.random.uniform(-1, 1, size=updates_shape)
# uniform(1, 2) samples lie in [1, 2), so the int8 cast yields all ones.
var_data = np.random.uniform(1, 2, size=var_shape).astype(np.int8)
quant_scales_data = np.random.uniform(1, 2, size=quant_scales_shape)
# uniform(0, 1) samples lie in [0, 1), so the int32 cast yields all zeros.
indices_data = np.random.uniform(0, 1, size=indices_shape).astype(np.int32)
quant_zero_points_data = np.random.uniform(0, 1, size=quant_zero_points_shape)

# Cast to bfloat16 on the host, then transfer to the NPU once.
updates_npu = torch.from_numpy(updates_data).to(torch.bfloat16).npu()
var_npu = torch.from_numpy(var_data).npu()
quant_scales_npu = torch.from_numpy(quant_scales_data).to(torch.bfloat16).npu()
quant_zero_points_npu = torch.from_numpy(quant_zero_points_data).to(torch.bfloat16).npu()
indices_npu = torch.from_numpy(indices_data).npu()


class Model(torch.nn.Module):
    """Module whose forward pass is a single npu_quant_scatter call on the
    module-level example tensors, so torch.compile can capture it."""

    def forward(self):
        return torch_npu.npu_quant_scatter(
            var_npu,
            indices_npu,
            updates_npu,
            quant_scales_npu,
            quant_zero_points_npu,
            axis=axis,
            quant_axis=quant_axis,
            reduce=reduce,
        )


def MetaInfershape():
    """Run the operator in graph mode and in eager mode and print both results."""
    with torch.no_grad():
        model = Model()
        # fullgraph=True makes compilation fail instead of silently falling
        # back to eager execution on a graph break.
        model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
        graph_output = model()
        single_op = torch_npu.npu_quant_scatter(
            var_npu,
            indices_npu,
            updates_npu,
            quant_scales_npu,
            quant_zero_points_npu,
            axis=axis,
            quant_axis=quant_axis,
            reduce=reduce,
        )
    print("single op output with mask:", single_op[0], single_op[0].shape)
    print("graph output with mask:", graph_output[0], graph_output[0].shape)


if __name__ == "__main__":
    MetaInfershape()

# Running the code above prints output similar to the following:
#
# single op output with mask: tensor([[[ 1, 1, 0, 1, 0, -1, 0, 0, 0, 1, 0, 1, 0, -1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 2, 1, 0, 0]],
#         [[ 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0]],
#         [[ 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, -1, 1, 1, 1, 1, 0, 1, 0, 2, 0, 0, 0, 1, 0, 1, 1, 2, 0, 1, 1]],
#         [[ 1, 1, 0, 1, 0, -1, 0, 1, 0, 1, 0, 0, -1, 0, 1, 0, 0, 1, 0, 2, 2, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0]],
#         [[ 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 2, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1]],
#         [[ 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, -1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0]],
#         [[ 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, -1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1]],
#         [[ 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1]],
#         [[ 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, -1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1]],
#         [[ 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, -1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0]],
#         [[ 1, 0, -1, 1, 0, 0, 1, 0, 1, 2, 0, 1, 0, -1, 1, 1, 1, 1, 0, 0, 2, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0]]],
#        device='npu:0', dtype=torch.int8) torch.Size([11, 1, 32])
# graph output with mask: tensor([[[ 1, 1, 0, 1, 0, -1, 0, 0, 0, 1, 0, 1, 0, -1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 2, 1, 0, 0]],
#         [[ 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0]],
#         [[ 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, -1, 1, 1, 1, 1, 0, 1, 0, 2, 0, 0, 0, 1, 0, 1, 1, 2, 0, 1, 1]],
#         [[ 1, 1, 0, 1, 0, -1, 0, 1, 0, 1, 0, 0, -1, 0, 1, 0, 0, 1, 0, 2, 2, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0]],
#         [[ 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 2, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1]],
#         [[ 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, -1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0]],
#         [[ 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, -1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1]],
#         [[ 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1]],
#         [[ 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, -1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1]],
#         [[ 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, -1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0]],
#         [[ 1, 0, -1, 1, 0, 0, 1, 0, 1, 2, 0, 1, 0, -1, 1, 1, 1, 1, 0, 0, 2, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0]]],
#        device='npu:0', dtype=torch.int8) torch.Size([11, 1, 32])