Quantizes updates, then writes the quantized values into input along the given axis at the positions specified by indices. input is modified in place.
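The NumPy sketch below shows these semantics on the CPU. It is not the NPU kernel. It covers only the layout used in the example at the end of this section: input of shape [B, S, D], indices of shape [B], updates of shape [B, U, D], quant_scales and quant_zero_points of shape [1, 1, D], axis=-2, quant_axis=-1, reduce='update'. The function name quant_scatter_reference is hypothetical.

The sketch makes two assumptions that the text above does not state:

- the quantization formula is round(updates / quant_scales + quant_zero_points), saturated to int8;
- batch b writes U consecutive rows along axis -2, starting at indices[b].

If the kernel multiplies by the scale or rounds differently, adjust the marked lines.

```python
import numpy as np


def quant_scatter_reference(var, indices, updates, quant_scales, quant_zero_points=None):
    """Hypothetical CPU reference: quantize updates, scatter into var in place, return var."""
    # ASSUMPTION: quantized = round(x / scale + zero_point), saturated to int8.
    # quant_scales / quant_zero_points broadcast along the last axis (quant_axis=-1).
    x = updates.astype(np.float32) / quant_scales.astype(np.float32)
    if quant_zero_points is not None:
        x = x + quant_zero_points.astype(np.float32)
    q = np.clip(np.rint(x), -128, 127).astype(np.int8)

    # ASSUMPTION: with axis=-2, batch b overwrites rows indices[b] .. indices[b] + U - 1
    # (reduce='update' means plain overwrite).
    u = updates.shape[-2]
    for b, start in enumerate(indices):
        var[b, start:start + u, :] = q[b]
    return var  # same object as the input, mirroring the in-place API


# Small usage example: 2 batches, 4 rows each, 3 channels.
var = np.zeros([2, 4, 3], dtype=np.int8)
indices = np.array([1, 3], dtype=np.int32)
updates = np.full([2, 1, 3], 2.0, dtype=np.float32)
scales = np.full([1, 1, 3], 0.5, dtype=np.float32)
zero_points = np.ones([1, 1, 3], dtype=np.float32)
out = quant_scatter_reference(var, indices, updates, scales, zero_points)
# Under the assumed formula each written value is 2.0 / 0.5 + 1 = 5.
```

Returning the same array mirrors the `Tensor(a!)` annotation in the prototype, where the result aliases input.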
torch_npu.npu_quant_scatter_(Tensor(a!) input, Tensor indices, Tensor updates, Tensor quant_scales, Tensor? quant_zero_points=None, int axis=0, int quant_axis=1, str reduce='update') -> Tensor(a!)
Returns the updated input (the same tensor that was modified in place).
Supported products: Atlas A2 Training Series Products
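import torch
import torch_npu
import numpy as np

# input: int8 tensor of shape [24, 4096, 128], updated in place.
# Note: uniform values in [0, 1) truncate to 0 when cast to int8, so var starts as all zeros.
data_var = np.random.uniform(0, 1, [24, 4096, 128]).astype(np.int8)
var = torch.from_numpy(data_var).to(torch.int8).npu()

# indices: one int32 position per batch along axis (all 0 here, for the same truncation reason).
data_indices = np.random.uniform(0, 1, [24]).astype(np.int32)
indices = torch.from_numpy(data_indices).to(torch.int32).npu()

# updates: bfloat16 values of shape [24, 1, 128], quantized before being written.
data_updates = np.random.uniform(1, 2, [24, 1, 128]).astype(np.float16)
updates = torch.from_numpy(data_updates).to(torch.bfloat16).npu()

# Quantization parameters: one scale and one zero point per element of the last dimension.
data_quant_scales = np.random.uniform(0, 1, [1, 1, 128]).astype(np.float16)
quant_scales = torch.from_numpy(data_quant_scales).to(torch.bfloat16).npu()
data_quant_zero_points = np.random.uniform(0, 1, [1, 1, 128]).astype(np.float16)
quant_zero_points = torch.from_numpy(data_quant_zero_points).to(torch.bfloat16).npu()

axis = -2        # scatter along the second-to-last dimension of var
quant_axis = -1  # quantization parameters vary along the last dimension
reduce = "update"

# var is modified in place; the returned tensor is the updated var.
out = torch_npu.npu_quant_scatter_(var, indices, updates, quant_scales, quant_zero_points,
                                   axis=axis, quant_axis=quant_axis, reduce=reduce)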