下载
中文
注册
我要评分
文档获取效率
文档正确性
内容完整性
文档易理解
在线提单
论坛求助
昇腾小AI

torch_npu.npu_quantize

功能描述

  • 算子功能:对输入的张量进行量化处理。
  • 计算公式:
    • 如果div_mode为True:

    • 如果div_mode为False:

接口原型

npu_quantize(Tensor self, Tensor scales, Tensor? zero_points, ScalarType dtype, int axis=1, bool div_mode=True) -> Tensor

参数说明

  • self:Device侧的Tensor类型,需要进行量化的源数据张量,必选输入,数据类型支持FLOAT、FLOAT16、BFLOAT16,数据格式支持ND,支持非连续的Tensor。div_mode为False且dtype为torch.quint4x2时,最后一维需要能被8整除。
  • scales:Device侧的Tensor类型,对self进行scales的张量,必选输入:
    • div_mode为True时,数据类型支持FLOAT、BFLOAT16。
    • div_mode为False时,数据类型支持FLOAT、FLOAT16、BFLOAT16,数据格式支持ND,支持非连续的Tensor。支持1维或多维(1维时,对应轴的大小需要与self中第axis维相等或等于1;多维时,scales的shape需要与self的shape维度相等,除axis指定的维度,其他维度为1,axis指定的维度必须和self对应的维度相等或等于1)。
  • zero_points:Device侧的Tensor类型,对self进行offset的张量,可选输入:
    • div_mode为True时,数据类型支持INT8、UINT8、INT32、BFLOAT16。
    • div_mode为False时,数据类型支持FLOAT、FLOAT16、BFLOAT16,数据格式支持ND,支持非连续的Tensor。支持1维或多维(1维时,对应轴的大小需要与self中第axis维相等或等于1;多维时,scales的shape需要与self维度相等,除axis指定的维度,其他维度为1,axis指定的维度必须和self对应的维度相等)。zero_points的shape和dtype需要和scales一致。
  • dtype:指定Device侧输出Tensor的类型:
    • div_mode为True时,格式支持torch.qint8、torch.quint8、torch.int32。
    • div_mode为False时,格式支持torch.qint8、torch.quint4x2。如果dtype为torch.quint4x2时,输出tensor类型为int32,由8个int4拼接。
  • axis:量化的elemwise轴, 其他的轴做broadcast,默认值为1。
    • div_mode为False时,axis取值范围是[-2, +∞)且指定的轴不能超过输入self的维度数。如果axis=-2,代表量化的elemwise轴是输入self的倒数第二根轴;如果axis大于-2,量化的elemwise轴是输入的最后一根轴。
  • div_mode:div_mode为True时,表示用除法计算scales;div_mode为False时,表示用乘法计算scales,默认值为True。

输出说明

y:Device侧的aclTensor,公式中的输出,输出大小与self一致。如果参数dtype为torch.quint4x2,输出的dtype是torch.int32,shape的最后一维是输入shape最后一维的1/8,shape其他维度和输入一致。

约束说明

  • 该融合算子仅在推理场景使用。
  • BFLOAT16数据类型仅在Atlas A2训练系列产品/Atlas 800I A2推理产品支持。
  • div_mode为False时,支持Atlas 推理系列产品,但是如下场景仅在Atlas A2训练系列产品/Atlas 800I A2推理产品支持:dtype为torch.quint4x2的场景;axis为-2的场景。

支持的型号

  • Atlas A2训练系列产品/Atlas 800I A2推理产品
  • Atlas 推理系列产品

调用示例

import torch
import torch_npu
x = torch.randn(1, 1, 12).bfloat16().npu()
scale = torch.tensor([0.1] * 12).bfloat16().npu()
out = torch_npu.npu_quantize(x, scale, None, torch.qint8, -1, False)
print(out)
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词