下载
中文
注册

torch_npu.npu_group_quant

功能描述

  • 算子功能:对输入的张量进行分组量化操作。
  • 计算公式为:

接口原型

1
torch_npu.npu_group_quant(Tensor x, Tensor scale, Tensor group_index, *, Tensor? offset=None, ScalarType? dst_dtype=None) -> Tensor

参数说明

  • x:Device侧的Tensor类型,需要做量化的源数据张量,必选输入。数据类型支持FLOAT32、FLOAT16、BFLOAT16。数据格式支持ND。x为2维张量,如果dst_dtype为torch.quint4x2,Shape的最后一维需要能被8整除。
  • scale:Device侧的Tensor类型,量化中的scale值,必选输入。数据类型支持FLOAT32、FLOAT16、BFLOAT16。数据格式支持ND。scale为2维张量,第0维大小不支持为0,并且scale的第1维与x的第1维相等。
  • group_index:Device侧的Tensor类型,分组量化中的group编号值,必选输入。数据类型支持INT32、INT64。数据格式支持ND。group_index为1维张量,并且group_index的第0维与scale的第0维相等。
  • offset:Device侧的Tensor类型,量化中的offset值,可选输入。数据类型支持FLOAT32、FLOAT16、BFLOAT16,且数据类型与scale的数据类型一致。数据格式支持ND。offset为一个数。
  • dst_dtype:ScalarType类型,可选参数,输入值允许为torch.int8或torch.quint4x2,默认值为torch.int8。

输出说明

一个Tensor类型的输出,代表group_quant的计算结果。如果参数dst_dtype为torch.int8,输出大小与输入x的大小一致。如果参数dst_dtype为torch.quint4x2,输出的数据类型是torch.int32,Shape的第0维大小与输入x的第0维大小一致,最后一维是输入x的最后一维的1/8。

约束说明

  • 如果属性dst_dtype为torch.quint4x2,则输入x的Shape的最后一维需要能被8整除。
  • 输入group_index必须是非递减序列,最小值不能小于0,最大值必须与输入x的shape的第0维大小相等。
  • 该接口支持图模式(目前仅支持PyTorch 2.1版本)。

支持的型号

  • Atlas A2 训练系列产品/Atlas 800I A2 推理产品
  • Atlas A3 训练系列产品

调用示例

  • 单算子模式调用
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    import torch
    import torch_npu
    
    x = torch.randn(6, 4).to(torch.float16).npu()
    scale = torch.randn(4, 4).to(torch.float32).npu()
    group_index = torch.tensor([1, 4, 6, 6], dtype=torch.int32).npu()
    offset = torch.randn(1).to(torch.float32).npu()
    
    y = torch_npu.npu_group_quant(x, scale, group_index, offset=offset, dst_dtype=torch.qint8)
    print(y)
    
  • 图模式调用
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    import torch
    import torch_npu
    import torchair as tng
    from torchair.ge_concrete_graph import ge_apis as ge
    from torchair.configs.compiler_config import CompilerConfig
    attr_dst_type = 2
    attr_dst_type_torch = torch.qint8 if attr_dst_type == 2 else torch.quint4x2
    x= torch.randn(6, 4).to(torch.float16).npu()
    scale = torch.randn(4, 4).to(torch.float32).npu()
    group_index = torch.tensor([1, 4, 6, 6], dtype=torch.int32).npu()
    offset = torch.randn(1).to(torch.float32).npu()
    class Network(torch.nn.Module):
        def __init__(self):
            super(Network, self).__init__()
        def forward(self, x, scale, group_index, offset, dst_type):
            return torch_npu.npu_group_quant(x, scale, group_index, offset=offset, dst_dtype=dst_type)
    model = Network()
    config = CompilerConfig()
    npu_backend = tng.get_npu_backend(compiler_config=config)
    config.debug.graph_dump.type = 'pbtxt'
    model = torch.compile(model, fullgraph=True, backend=npu_backend, dynamic=True)
    output_data = model(x, scale, group_index, offset=offset, dst_type=attr_dst_type_torch)