torch_npu.npu_trans_quant_param
功能描述
完成量化计算参数scale数据类型的转换。
接口原型
npu_trans_quant_param(Tensor scale, Tensor? offset=None) -> Tensor
参数说明
- scale(计算输入):Device侧的Tensor类型,数据类型支持FLOAT32,数据格式支持ND,shape是1维(t,)或者2维(1,n)。其中t = 1或n,其中n与x2的n一致。
- offset( 计算输入):Device侧的Tensor类型,可选参数。数据类型支持FLOAT32,数据格式支持ND,shape是1维(t,)或者2维(1,n)。t = 1或n,其中n与x2的n一致。
输出说明
一个Tensor类型的输出,代表trans_quant_param的计算结果。
约束说明
- 传入的scale或out不能为空。
- scale、offset或out的数据类型和数据格式需要在支持的范围之内。
- scale、offset的shape需要为1维(t,)或者2维(1,n)。其中t = 1或n,其中n与x2的n一致。
- 当scale的shape为两维(1, n)时,scale和offset的shape需要保持一致,且输出shape也为(1, n)。
支持的PyTorch版本
- PyTorch 2.2
- PyTorch 2.1
- PyTorch 1.11.0
支持的型号
Atlas A2 训练系列产品
调用示例
单算子调用: import torch import torch_npu import logging import os scale = torch.randn(16, dtype=torch.float32) offset = torch.randn(16, dtype=torch.float32) npu_out = torch_npu.npu_trans_quant_param(scale.npu(), offset.npu()) 图模式: 说明:图模式下,npu_trans_quant_param计算出的结果tensor为uint64数据类型。由于torch不支持该数据类型,需要搭配其他接口使用,如示例代码中的npu_quant_matmul。 import torch import torch_npu import torchair as tng from torchair.ge_concrete_graph import ge_apis as ge from torchair.configs.compiler_config import CompilerConfig import logging from torchair.core.utils import logger logger.setLevel(logging.DEBUG) import os import numpy as np os.environ["ENABLE_ACLNN"] = "true" config = CompilerConfig() npu_backend = tng.get_npu_backend(compiler_config=config) class MyModel(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x1, x2, scale, offset, bias): scale_1 = torch_npu.npu_trans_quant_param(scale, offset) return torch_npu.npu_quant_matmul(x1, x2, scale_1, offset, bias) cpu_model = MyModel() model = cpu_model.npu() cpu_x1 = torch.randint(-1, 1, (15, 1, 512), dtype=torch.int8) cpu_x2 = torch.randint(-1, 1, (15, 512, 128), dtype=torch.int8) scale = torch.randn(1, dtype=torch.float32) offset = torch.randn(1, dtype=torch.float32) bias = torch.randint(-1,1, (15, 1, 128), dtype=torch.int32) model = torch.compile(cpu_model, backend=npu_backend, dynamic=True) npu_out = model(cpu_x1.npu(), cpu_x2.npu(), scale.npu(), offset.npu(), bias.npu())
父主题: torch_npu