torch_npu.npu_quantize
功能描述
- 算子功能:对输入的张量进行量化处理。
- 计算公式:
- 如果div_mode为True:
- 如果div_mode为False:
- 如果div_mode为True:
接口原型
1 | npu_quantize(Tensor self, Tensor scales, Tensor? zero_points, ScalarType dtype, int axis=1, bool div_mode=True) -> Tensor |
参数说明
- self:Device侧的Tensor类型,需要进行量化的源数据张量,必选输入,数据类型支持FLOAT、FLOAT16、BFLOAT16,数据格式支持ND,支持非连续的Tensor。div_mode为False且dtype为torch.quint4x2时,最后一维需要能被8整除。
- scales:Device侧的Tensor类型,对self进行scales的张量,必选输入:
- div_mode为True时
Atlas 推理系列产品 :数据类型支持FLOAT。Atlas A2 训练系列产品/Atlas 800I A2 推理产品 :数据类型支持FLOAT、BFLOAT16。
- div_mode为False时,数据格式支持ND,支持非连续的Tensor。支持1维或多维(1维时,对应轴的大小需要与self中第axis维相等或等于1;多维时,scales的shape需要与self的shape维度相等,除axis指定的维度,其他维度为1,axis指定的维度必须和self对应的维度相等或等于1)。
Atlas 推理系列产品 :数据类型支持FLOAT、FLOAT16。Atlas A2 训练系列产品/Atlas 800I A2 推理产品 :数据类型支持FLOAT、FLOAT16、BFLOAT16。
- div_mode为True时
- zero_points:Device侧的Tensor类型,对self进行offset的张量,可选输入:
- div_mode为True时
Atlas 推理系列产品 :数据类型支持INT8、UINT8、INT32。Atlas A2 训练系列产品/Atlas 800I A2 推理产品 :数据类型支持INT8、UINT8、INT32、BFLOAT16。
- div_mode为False时,数据格式支持ND,支持非连续的Tensor。支持1维或多维(1维时,对应轴的大小需要与self中第axis维相等或等于1;多维时,scales的shape需要与self维度相等,除axis指定的维度,其他维度为1,axis指定的维度必须和self对应的维度相等)。zero_points的shape和dtype需要和scales一致。
Atlas 推理系列产品 :数据类型支持FLOAT、FLOAT16。Atlas A2 训练系列产品/Atlas 800I A2 推理产品 :数据类型支持FLOAT、FLOAT16、BFLOAT16。
- div_mode为True时
- dtype:指定Device侧输出Tensor的类型:
- div_mode为True时,格式支持torch.qint8、torch.quint8、torch.int32。
- div_mode为False时,格式支持torch.qint8、torch.quint4x2。如果dtype为torch.quint4x2时,输出tensor类型为int32,由8个int4拼接。
- axis:量化的elemwise轴, 其他的轴做broadcast,默认值为1。
- div_mode为False时,axis取值范围是[-2, +∞)且指定的轴不能超过输入self的维度数。如果axis=-2,代表量化的elemwise轴是输入self的倒数第二根轴;如果axis大于-2,量化的elemwise轴是输入的最后一根轴。
- div_mode:div_mode为True时,表示用除法计算scales;div_mode为False时,表示用乘法计算scales,默认值为True。
输出说明
y:Device侧的aclTensor,公式中的输出,输出大小与self一致。如果参数dtype为torch.quint4x2,输出的dtype是torch.int32,shape的最后一维是输入shape最后一维的1/8,shape其他维度和输入一致。
约束说明
- 该接口仅在推理场景下使用。
- 该接口支持图模式(目前仅支持PyTorch 2.1版本)。
- div_mode为False时:
- 如下场景,不支持
Atlas 推理系列产品 。- dtype为torch.quint4x2的场景。
- axis为-2的场景。
- 支持
Atlas A2 训练系列产品/Atlas 800I A2 推理产品 。
- 如下场景,不支持
支持的型号
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas 推理系列产品
调用示例
- 单算子模式调用(
Atlas A2 训练系列产品/Atlas 800I A2 推理产品 )1 2 3 4 5 6
import torch import torch_npu x = torch.randn(1, 1, 12).bfloat16().npu() scale = torch.tensor([0.1] * 12).bfloat16().npu() out = torch_npu.npu_quantize(x, scale, None, torch.qint8, -1, False) print(out)
- 单算子模式调用(
Atlas 推理系列产品 )1 2 3 4 5 6
import torch import torch_npu x = torch.randn((2, 3, 12), dtype=torch.float).npu() scale = torch.tensor(([3] * 12),dtype=torch.float).npu() out = torch_npu.npu_quantize(x, scale, None, torch.qint8, -1, False) print(out)
- 图模式调用
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
import torch import torch_npu import torchair as tng from torchair.ge_concrete_graph import ge_apis as ge from torchair.configs.compiler_config import CompilerConfig x = torch.randn((2, 3, 12), dtype=torch.float16).npu() scale = torch.tensor(([3] * 12),dtype=torch.float16).npu() axis =1 div_mode = False class Network(torch.nn.Module): def __init__(self): super(Network, self).__init__() def forward(self, x, scale,zero_points, dst_type,div_mode): return torch_npu.npu_quantize(x, scale,zero_points=zero_points,dtype=dst_type,div_mode=div_mode) model = Network() config = CompilerConfig() npu_backend = tng.get_npu_backend(compiler_config=config) config.debug.graph_dump.type = 'pbtxt' model = torch.compile(model, fullgraph=True, backend=npu_backend, dynamic=True) output_data = model(x, scale,None,dst_type=torch.qint8, div_mode=div_mode) print(output_data)
父主题: torch_npu