torch_npu.npu_convert_weight_to_int4pack

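Description

Packs an INT32 input tensor into int4 storage. Every 8 int4 values are carried in one int32 element, and the packed values are laid out in an interleaved order.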
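The minimal pure-Python sketch below makes the "8 int4 values per int32" idea concrete. It packs nibbles sequentially, lowest nibble first. That order is an assumption made for illustration only. It does not reproduce the interleaved layout the NPU operator produces, so its output is not a substitute for npu_convert_weight_to_int4pack. The helper names pack_int4_row and unpack_int4_row are hypothetical.

```python
def pack_int4_row(row):
    """Pack int4 values (each in [-8, 7]) into signed int32 words, 8 values per word.

    Hypothetical helper: nibbles are stored sequentially, lowest nibble first.
    The real operator uses an interleaved layout that this sketch does not model.
    """
    if len(row) % 8 != 0:
        raise ValueError("row length must be a multiple of 8")
    words = []
    for start in range(0, len(row), 8):
        word = 0
        for i, v in enumerate(row[start:start + 8]):
            if not -8 <= v <= 7:
                raise ValueError(f"value {v} is outside the int4 range [-8, 7]")
            # Keep the 4-bit two's-complement pattern and shift it into slot i
            word |= (v & 0xF) << (4 * i)
        # Reinterpret the 32-bit pattern as a signed int32
        if word >= 1 << 31:
            word -= 1 << 32
        words.append(word)
    return words


def unpack_int4_row(words):
    """Inverse of pack_int4_row: recover the signed int4 values."""
    row = []
    for w in words:
        w &= 0xFFFFFFFF  # view the word as an unsigned 32-bit pattern
        for i in range(8):
            nibble = (w >> (4 * i)) & 0xF
            # Sign-extend the 4-bit value back to [-8, 7]
            row.append(nibble - 16 if nibble >= 8 else nibble)
    return row


packed = pack_int4_row([1, 2, 3, 4, 5, 6, 7, -8])
print(hex(packed[0] & 0xFFFFFFFF))  # 0x87654321
```

A row of length n packs into n/8 words. This is why the last dimension of the output is divided by 8.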

Prototype

torch_npu.npu_convert_weight_to_int4pack(Tensor weight, int inner_k_tiles=0) -> Tensor

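Parameters

  • weight: a Tensor on the device holding the input weight. Supported formats are ND and FRACTAL_NZ, and the only supported data type is INT32. Non-contiguous tensors are not supported. The tensor must be 2-D with shape (k, n) or (n, k). Its last dimension must be a multiple of 8, and every element must lie in the int4 range [-8, 7].
  • inner_k_tiles: int. Specifies how many K-tiles are packed together in the internal packing format. Defaults to 0. This is a reserved parameter and is currently unused.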

Output

A Tensor holding the int4-packed result, with data type INT32 and format ND. Its shape is (k, n/8) for a (k, n) input, or (n, k/8) for an (n, k) input.
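The input rules in the parameter list and the output shape above can be checked on the host before the tensor is sent to the device. The sketch below uses plain Python tuples and lists instead of tensors. The helper names expected_int4pack_shape and values_fit_int4 are hypothetical and not part of torch_npu. The sketch does not check format or contiguity.

```python
def expected_int4pack_shape(shape):
    """Return the packed shape for a 2-D int32 weight of the given shape.

    Hypothetical host-side check mirroring the documented rules: the weight
    must be 2-D, its last dimension must be a multiple of 8, and packing
    divides the last dimension by 8.
    """
    if len(shape) != 2:
        raise ValueError(f"weight must be 2-D, got {len(shape)}-D")
    rows, cols = shape
    if cols % 8 != 0:
        raise ValueError(f"last dimension {cols} is not a multiple of 8")
    return (rows, cols // 8)


def values_fit_int4(values):
    """Check that every element lies in the int4 range [-8, 7]."""
    return all(-8 <= v <= 7 for v in values)


print(expected_int4pack_shape((64, 32)))  # (64, 4)
print(expected_int4pack_shape((17, 72)))  # (17, 9)
```

The shapes (64, 32) and (17, 72) are the (k, n) weights used in the examples below. A transposed (72, 17) weight would be rejected because 17 is not a multiple of 8.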

Constraints

  • This API is intended for inference only.
  • This API supports graph mode (currently only with PyTorch 2.1).
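Graph mode is documented as working only with PyTorch 2.1, so a script may want to fall back to single-operator mode on other versions. Below is a minimal sketch of such a guard. graph_mode_supported is a hypothetical helper, and the fallback policy is an assumption rather than something torch_npu enforces. In practice you would pass it torch.__version__.

```python
def graph_mode_supported(torch_version):
    """Return True if the version string is a PyTorch 2.1.x release.

    Mirrors the documented constraint that graph mode currently works only
    with PyTorch 2.1. Local build suffixes such as "+cpu" are ignored.
    """
    release = torch_version.split("+", 1)[0]
    parts = release.split(".")
    if len(parts) < 2:
        return False
    return parts[0] == "2" and parts[1] == "1"


print(graph_mode_supported("2.1.0"))      # True
print(graph_mode_supported("2.1.0+cpu"))  # True
print(graph_mode_supported("2.3.1"))      # False
```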

Supported Products

  • Atlas A2 Training Series products / Atlas 800I A2 inference products
  • Atlas A3 Training Series products

Examples

  • Single-operator mode
    import torch
    import torch_npu
    
    m = 128
    k = 64
    n = 32
    trans_weight = False  # True: the weight is created as (n, k) and transposed after packing
    
    cpu_x = torch.randn((m, k), dtype=torch.float16)
    if trans_weight:
        # The int4 range is [-8, 7]; randint's upper bound is exclusive
        cpu_weight = torch.randint(low=-8, high=8, size=(n, k), dtype=torch.int32)
        cpu_antiquantscale = torch.randn((n, 1), dtype=torch.float16)
        cpu_antiquantoffset = torch.randn((n, 1), dtype=torch.float16)
    else:
        cpu_weight = torch.randint(low=-8, high=8, size=(k, n), dtype=torch.int32)
        cpu_antiquantscale = torch.randn((1, n), dtype=torch.float16)
        cpu_antiquantoffset = torch.randn((1, n), dtype=torch.float16)
    
    # Pack on the device: (k, n) -> (k, n/8), or (n, k) -> (n, k/8)
    weight_int4 = torch_npu.npu_convert_weight_to_int4pack(cpu_weight.npu())
    
    if trans_weight:
        # Return to the (k, n) orientation expected by the matmul (as transposed views)
        cpu_weight = cpu_weight.transpose(-1, -2)
        weight_int4 = weight_int4.transpose(-1, -2)
        cpu_antiquantscale = cpu_antiquantscale.transpose(-1, -2)
        cpu_antiquantoffset = cpu_antiquantoffset.transpose(-1, -2)
    
    # weight_int4 is already on the NPU, so it needs no further .npu() call
    npu_out = torch_npu.npu_weight_quant_batchmatmul(
        cpu_x.npu(), weight_int4, cpu_antiquantscale.npu(), cpu_antiquantoffset.npu())
    
  • Graph mode
    import torch
    import torch_npu
    import torchair
    from torchair.configs.compiler_config import CompilerConfig
    
    # Build the torchair NPU backend used by torch.compile
    config = CompilerConfig()
    npu_backend = torchair.get_npu_backend(compiler_config=config)
    
    m = 16
    k = 17
    n = 72
    
    trans_weight = False  # True: the weight is created as (n, k), so k must then also be a multiple of 8
    is_weight_nz = False  # True: cast the int32 weight to FRACTAL_NZ before packing
    
    cpu_x = torch.randn((m, k), dtype=torch.float16)
    if trans_weight:
        cpu_weight = torch.randint(low=-8, high=8, size=(n, k), dtype=torch.int32)
        cpu_antiquantscale = torch.ones((n, 1), dtype=torch.float16)
        cpu_antiquantoffset = torch.zeros((n, 1), dtype=torch.float16)
    else:
        cpu_weight = torch.randint(low=-8, high=8, size=(k, n), dtype=torch.int32)
        cpu_antiquantscale = torch.ones((1, n), dtype=torch.float16)
        cpu_antiquantoffset = torch.zeros((1, n), dtype=torch.float16)
    
    npu_weight = cpu_weight.npu()
    if is_weight_nz:
        # ND -> FRACTAL_NZ (format id 29)
        npu_weight = torch_npu.npu_format_cast(npu_weight, 29)
    # int32 -> int4pack, run eagerly outside the compiled graph
    weight_int4pack = torch_npu.npu_convert_weight_to_int4pack(npu_weight)
    
    class MyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
    
        def forward(self, x, weight, antiquant_scale, antiquant_offset, quant_scale, quant_offset, bias, antiquant_group_size):
            if trans_weight:
                weight = weight.transpose(-1, -2)
                antiquant_scale = antiquant_scale.transpose(-1, -2)
                antiquant_offset = antiquant_offset.transpose(-1, -2)
            return torch_npu.npu_weight_quant_batchmatmul(x, weight, antiquant_scale, antiquant_offset, quant_scale, quant_offset, bias, antiquant_group_size)
    
    model = MyModel().npu()
    model = torch.compile(model, backend=npu_backend, dynamic=True, fullgraph=True)
    
    # The last four arguments are quant_scale, quant_offset, bias and antiquant_group_size
    npu_out = model(cpu_x.npu(), weight_int4pack, cpu_antiquantscale.npu(), cpu_antiquantoffset.npu(), None, None, None, 0)
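To sanity-check npu_out from either example, you can compute a CPU reference from the unpacked int32 weight. The sketch below assumes that npu_weight_quant_batchmatmul applies per-channel antiquantization of the form out = x @ ((weight + antiquant_offset) * antiquant_scale). It also assumes the weight is in the (k, n) orientation with one scale and one offset per output column. Confirm this against the npu_weight_quant_batchmatmul documentation before relying on it. The sketch uses plain Python lists, and reference_antiquant_matmul is a hypothetical name.

```python
def reference_antiquant_matmul(x, weight, scale, offset):
    """CPU reference for a per-channel weight-antiquantized matmul.

    Assumed formula: out[i][j] = sum_p x[i][p] * (weight[p][j] + offset[j]) * scale[j]
    x: m x k floats; weight: k x n int4 values stored as Python ints;
    scale / offset: length-n lists, one entry per output column.
    """
    k = len(weight)
    n = len(weight[0])
    if any(len(row) != k for row in x):
        raise ValueError("inner dimensions of x and weight do not match")
    # Dequantize the weight once: (w + offset) * scale, column by column
    deq = [[(weight[p][j] + offset[j]) * scale[j] for j in range(n)] for p in range(k)]
    return [[sum(row[p] * deq[p][j] for p in range(k)) for j in range(n)] for row in x]


x = [[1.0, 2.0], [0.5, -1.0]]
w = [[1, -2], [3, 4]]
print(reference_antiquant_matmul(x, w, scale=[1.0, 0.5], offset=[0.0, 2.0]))
# [[7.0, 6.0], [-2.5, -3.0]]
```

To apply it to the examples, convert the inputs with .tolist(). Use cpu_x, cpu_weight (already in the (k, n) orientation when trans_weight is True) and the flattened scale and offset tensors. Then compare the result with npu_out.cpu() using a tolerance suited to float16.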