下载
中文
注册

torch_npu.npu_anti_quant

功能描述

  • 算子功能:将INT4或者INT8数据反量化为FP16或者BF16,其中输入是INT4类型时,将每8个数据看作是一个INT32数据。
  • 计算公式为:

接口原型

1
torch_npu.npu_anti_quant(Tensor x, Tensor scale, *, Tensor? offset=None, ScalarType? dst_dtype=None, ScalarType? src_dtype=None) -> Tensor

参数说明

  • x:Tensor类型,即输入参数中的x。数据格式支持ND,支持非连续的Tensor。输入最大支持8维。
    • Atlas 推理系列产品数据类型支持INT8。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品数据类型支持INT8、INT32,其中INT32类型数据的每个值是由8个INT4数值拼成的。
    • Atlas A3 训练系列产品数据类型支持INT8、INT32,其中INT32类型数据的每个值是由8个INT4数值拼成的。
  • scale:Tensor类型,数据类型支持FLOAT32、BFLOAT16,数据格式支持ND,支持非连续的Tensor,仅支持1维Tensor,shape为(n,),其中n可以为1,如果n不为1,则在输入x为INT8类型时,必须与输入x的尾轴维度的大小相同,在输入x为INT32类型时,必须为输入x的尾轴维度大小的8倍。
  • offset:Tensor类型,可选参数,数据类型支持FLOAT32、BFLOAT16,且数据类型与scale的数据类型一致。数据格式支持ND,支持非连续的Tensor,仅支持1维Tensor,且shape必须与scale的shape大小一致,数据类型必须与scale的数据类型一致。
  • dst_dtype:ScalarType类型,可选参数,输入值允许为torch.float16或torch.bfloat16,默认值为torch.float16。
  • src_dtype:ScalarType类型,可选参数。
    • Atlas 推理系列产品输入值允许为torch.int8,默认为torch.int8。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品输入值允许为torch.quint4x2或torch.int8,默认为torch.int8。
    • Atlas A3 训练系列产品输入值允许为torch.quint4x2或torch.int8,默认为torch.int8。

输出说明

一个Tensor类型的输出,代表antiquant的计算结果。

约束说明

  • 该接口支持图模式(目前仅支持PyTorch 2.1版本)。
  • x、scale这两个输入中不能含有空指针。
  • 如果输入scale的shape值不为1,则输入x的最后一维shape值必须与scale的shape一致。
  • Atlas A2 训练系列产品/Atlas 800I A2 推理产品数据类型支持BFLOAT16,输入x支持数据类型为INT32。
  • Atlas A3 训练系列产品数据类型支持BFLOAT16,输入x支持数据类型为INT32。

支持的型号

  • Atlas A2 训练系列产品/Atlas 800I A2 推理产品
  • Atlas A3 训练系列产品
  • Atlas 推理系列产品

调用示例

  • 单算子模式调用
    1
    2
    3
    4
    5
    6
    import torch
    import torch_npu
    x_tensor = torch.tensor([1,2,3,4], dtype=torch.int8).npu()
    scale = torch.tensor([2.0], dtype=torch.float).npu()
    offset = torch.tensor([2.0], dtype=torch.float).npu()
    out = torch_npu.npu_anti_quant(x_tensor, scale, offset=offset, dst_dtype=torch.float16)
    
  • 图模式调用
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    import torch
    import torch_npu
    import torchair as tng
    from torchair.ge_concrete_graph import ge_apis as ge
    from torchair.configs.compiler_config import CompilerConfig
    
    config = CompilerConfig()
    config.debug.graph_dump.type = 'pbtxt'
    npu_backend = tng.get_npu_backend(compiler_config=config)
    x_tensor = torch.tensor([1,2,3,4], dtype=torch.int8).npu()
    scale = torch.tensor([2.0], dtype=torch.float).npu()
    offset = torch.tensor([2.0], dtype=torch.float).npu()
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
        def forward(self,x,scale,offset):
            return torch_npu.npu_anti_quant(x, scale, offset=offset, dst_dtype=torch.float16)
    cpu_model = Model()
    model = torch.compile(cpu_model, backend=npu_backend, dynamic=False, fullgraph=True)
    output = model(x_tensor,scale,offset)