文档
注册

torch_npu.npu_quant_scatter_

功能描述

先将updates进行量化,然后将updates中的值按指定的轴axis和索引indices更新self中的值,self中的数据被改变。

接口原型

torch_npu.npu_quant_scatter_(Tensor(a!) self, Tensor indices, Tensor updates, Tensor quant_scales, Tensor? quant_zero_points=None, int axis=0, int quant_axis=1, str reduce='update') -> Tensor(a!)

参数说明

  • self:Device侧的Tensor类型,必选输入,源数据张量,数据类型支持INT8,数据格式支持ND,支持非连续的Tensor。
  • indices:Device侧的Tensor类型,必选输入,索引张量,数据类型支持INT32,数据格式支持ND,支持非连续的Tensor。
  • updates:Device侧的Tensor类型,必选输入,更新数据张量,数据类型支持BFLOAT16,数据格式支持ND,支持非连续的Tensor。
  • quant_scales:Device侧的Tensor类型,必选输入,量化缩放张量,数据类型支持BFLOAT16,数据格式支持ND,支持非连续的Tensor。
  • quant_zero_points:Device侧的Tensor类型,可选输入,量化偏移张量,数据类型支持BFLOAT16,数据格式支持ND,支持非连续的Tensor。
  • axis:Host侧的int类型,可选参数,updates上用来更新的轴。
  • quant_axis:Host侧的int类型,可选参数,updates上用来量化的轴。
  • reduce:Host侧的str类型,可选参数,表示数据操作方式。

输出说明

返回被更新后的self。

约束说明

  • self的维数只能是3~8维。
  • indices的维数只能是1维或者2维;如果是2维,其第2维的大小必须是2;不支持索引越界,索引越界不校验;indices映射的self数据段不能重合,若重合则会因为多核并发原因导致多次执行结果不一样。
  • updates的维数需要与self的维数一样;其第1维的大小等于indices的第1维的大小,且不大于self的第1维的大小;其axis轴的大小不大于self的axis轴的大小;其余维度的大小要跟self对应维度的大小相等;其最后一维的大小必须32B对齐。
  • quant_scales的元素个数需要等于updates在quant_axis轴的大小。
  • quant_zero_points的元素个数需要等于updates在quant_axis轴的大小。
  • axis不能为updates的第1维或最后1维。
  • quant_axis只能为updates的最后1维。
  • reduce当前只支持‘update’,即更新操作。

支持的PyTorch版本

  • PyTorch 2.2
  • PyTorch 2.1
  • PyTorch 1.11.0

支持的型号

Atlas A2 训练系列产品

调用示例

import torch
import torch_npu
import numpy as np

data_var = np.random.uniform(0, 1, [24, 4096, 128]).astype(np.int8)
var = torch.from_numpy(data_var).to(torch.int8).npu()

data_indices = np.random.uniform(0, 1, [24]).astype(np.int32)
indices = torch.from_numpy(data_indices).to(torch.int32).npu()

data_updates = np.random.uniform(1, 2, [24, 1, 128]).astype(np.float16)
updates = torch.from_numpy(data_updates).to(torch.bfloat16).npu()

data_quant_scales = np.random.uniform(0, 1, [1, 1, 128]).astype(np.float16)
quant_scales = torch.from_numpy(data_quant_scales).to(torch.bfloat16).npu()

data_quant_zero_points = np.random.uniform(0, 1, [1, 1, 128]).astype(np.float16)
quant_zero_points = torch.from_numpy(data_quant_zero_points).to(torch.bfloat16).npu()

axis = -2
quant_axis = -1
reduce = "update"

torch_npu.npu_quant_scatter_(var, indices, updates, quant_scales, quant_zero_points, axis=axis, quant_axis=quant_axis, reduce=reduce)
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词