下载
中文
注册

torch_npu.npu_quant_matmul

功能描述

完成量化的矩阵乘计算,最小支持输入维度为2维,最大支持输入维度为6维。

接口原型

1
npu_quant_matmul(Tensor x1, Tensor x2, Tensor scale, *, Tensor? offset=None, Tensor? pertoken_scale=None, Tensor? bias=None, ScalarType? output_dtype=None) -> Tensor

参数说明

  • x1(计算输入):Device侧的Tensor类型,数据类型支持INT8和INT32,其中INT32类型表示使用本接口进行INT4类型矩阵乘计算,INT32类型承载的是INT4数据,每个INT32数据存放8个INT4数据。数据格式支持ND,shape需要在2-6维范围。
  • x2(计算输入):Device侧的Tensor类型(weight),数据类型支持INT8和INT32(INT32类型含义同x1,表示INT4的数据计算),与x1的数据类型须保持一致。数据格式支持ND,shape需要在2-6维范围。
  • scale(计算输入):Device侧的Tensor类型,数据格式支持ND,shape需要是1维(t, ),t=1或n,其中n与x2的n一致。如需传入INT64数据类型的scale,需要提前调用torch_npu.npu_trans_quant_param来获取INT64数据类型的scale。
    • Atlas 推理系列加速卡产品:数据类型支持FLOAT32、INT64。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品:数据类型支持FLOAT32、INT64、BFLOAT16。
    • Atlas A3 训练系列产品:数据类型支持FLOAT32、INT64、BFLOAT16。
  • offset( 计算输入):Device侧的Tensor类型,可选参数。数据类型支持FLOAT32,数据格式支持ND,shape需要是1维(t,),t=1或n,其中n与x2的n一致。
  • pertoken_scale(计算输入):Device侧的Tensor类型,可选参数。数据类型支持FLOAT32,数据格式支持ND,shape需要是1维(m,),其中m与x1的m一致。
  • bias( 计算输入):Device侧的Tensor类型,可选参数,数据格式支持ND,shape支持1维(n,)或3维(batch,1,n),n与x2的n一致,同时batch值需要等于x1和x2 boardcast后推导出的batch值。当输出是2、4、5、6维情况下,bias的shape必须为1维。当输出是3维情况下,bias的shape可以为1维或3维。
    • Atlas 推理系列加速卡产品:数据类型支持INT32。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品:数据类型支持INT32、BFLOAT16、FLOAT16、FLOAT32。
    • Atlas A3 训练系列产品:数据类型支持INT32、BFLOAT16、FLOAT16、FLOAT32。
  • output_dtype( 计算输入):Device侧的ScalarType类型,可选参数。表示输出Tensor的数据类型。默认值为None,代表输出Tensor数据类型为INT8。
    • Atlas 推理系列加速卡产品:支持输入torch.int8、torch.float16。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品:支持输入torch.int8、torch.float16、torch.bfloat16。
    • Atlas A3 训练系列产品:支持输入torch.int8、torch.float16、torch.bfloat16。

输出说明

一个Tensor类型的输出,代表量化matmul的计算结果。

  • 如果output_dtype为torch.float16,输出的数据类型为FLOAT16。
  • 如果output_dtype为torch.bfloat16,输出的数据类型为BFLOAT16。
  • 如果output_dtype为torch.int8或者None,输出的数据类型为INT8。

约束说明

  • 该接口仅在推理场景下使用。
  • 该接口支持图模式(目前仅支持PyTorch 2.1版本)。
  • 传入的x1、x2、scale不能是空。
  • x1、x2、bias、scale、offset、pertoken_scale、output_dtype的数据类型和数据格式需要在支持的范围之内。
  • x1与x2最后一维的shape大小不能超过65535。
  • 目前输出INT8/FLOAT16且无pertoken_scale情况下,图模式不支持scale直接传入FLOAT32数据类型。
  • 如果在PyTorch图模式中使用本接口,且ENABLE_ACLNN =false,则在调用接口前需要对shape为(n,k//8)的x2数据进行转置,转置过程应写在图中。
  • 支持x2为昇腾亲和的数据排布以提高搬运效率。需要调用torch_npu.npu_format_cast完成输入x2(weight)为昇腾亲和的数据排布功能。
    • Atlas 推理系列加速卡产品:必须将x2转置后转format。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品:推荐将x2非转置后转format。
    • Atlas A3 训练系列产品:推荐将x2非转置后转format。
  • INT4类型计算的额外约束:
    • x1,x2的数据类型均为INT32,每个INT32类型的数据存放8个INT4数据。输入的INT32shape需要将数据原本INT4类型时shape的最后一维缩小8倍。INT4数据的shape最后一维应为8的倍数。例如:进行(m,k)乘(k,n)的INT4类型矩阵乘计算时,需要输入INT32类型,shape为(m,k//8)(k,n//8)的数据,其中k与n都应是8的倍数。
    • x1只能接受shape为(m,k//8)且数据排布连续的数据,x2可以接受(k,n[g1] //8)且数据排布连续的数据或shape为(k//8,n)且是由数据连续排布的(n,k//8)转置而来的数据。数据排布连续指数组中所有相邻的数,包括换行时内存地址连续。使用Tensor.is_contiguous返回值为true则表明tensor数据排布连续。
  • 输入参数间支持的数据类型组合情况如下:
    表1 Atlas 推理系列产品

    x1

    x2

    scale

    offset

    bias

    pertoken_scale

    output_dtype

    int8

    int8

    int64/float32

    None

    int32/None

    None

    float16

    int8

    int8

    int64/float32

    float32/None

    int32/None

    None

    int8

    表2 Atlas A2 训练系列产品/Atlas 800I A2 推理产品Atlas A3 训练系列产品

    x1

    x2

    scale

    offset

    bias

    pertoken_scale

    output_dtype

    int8

    int8

    int64/float32

    None

    int32/None

    None

    float16

    int8

    int8

    int64/float32

    float32/None

    int32/None

    None

    int8

    int8

    int8

    float32/bfloat16

    None

    int32/bfloat16/float32/None

    float32/None

    bfloat16

    int8

    int8

    float32

    None

    int32/float16/float32/None

    float32

    float16

    int32

    int32

    int64/float32

    None

    int32/None

    None

    float16

支持的型号

  • Atlas A2 训练系列产品/Atlas 800I A2 推理产品
  • Atlas 推理系列加速卡产品
  • Atlas A3 训练系列产品

调用示例

  • 单算子调用
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    import torch
    import torch_npu
    import logging
    import os
    
    cpu_x1 = torch.randint(-5, 5, (1, 256, 768), dtype=torch.int8)
    cpu_x2 = torch.randint(-5, 5, (31, 768, 16), dtype=torch.int8)
    scale = torch.randn(16, dtype=torch.float32)
    offset = torch.randn(16, dtype=torch.float32)
    bias = torch.randint(-5, 5, (31, 1, 16), dtype=torch.int32)
    # Method 1:You can directly call npu_quant_matmul
    npu_out = torch_npu.npu_quant_matmul(cpu_x1.npu(), cpu_x2.npu(), scale.npu(), offset=offset.npu(), bias=bias.npu())
    
    # Method 2: You can first call npu_trans_quant_param to convert scale and offset from float32 to int64 when output dtype is torch.int8 or torch.float16
    scale_1 = torch_npu.npu_trans_quant_param(scale.npu(), offset.npu())
    npu_out = torch_npu.npu_quant_matmul(cpu_x1.npu(), cpu_x2.npu(), scale_1,  bias=bias.npu())
    
  • 图模式调用

    输出int8/fp16且无pertoken情况下,必须先调用npu_trans_quant_param。

    • 通用
      • 输出float16
         1
         2
         3
         4
         5
         6
         7
         8
         9
        10
        11
        12
        13
        14
        15
        16
        17
        18
        19
        20
        21
        22
        23
        24
        25
        26
        27
        28
        29
        30
        31
        import torch
        import torch_npu
        import torchair as tng
        from torchair.ge_concrete_graph import ge_apis as ge
        from torchair.configs.compiler_config import CompilerConfig
        import logging
        from torchair.core.utils import logger
        logger.setLevel(logging.DEBUG)
        import os
        import numpy as np
        # "ENABLE_ACLNN"是否使能走aclnn, true: 回调走aclnn, false: 在线编译
        os.environ["ENABLE_ACLNN"] = "true"
        config = CompilerConfig()
        npu_backend = tng.get_npu_backend(compiler_config=config)
        
        class MyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
            def forward(self, x1, x2, scale, offset, bias):
                return torch_npu.npu_quant_matmul(x1, x2, scale, offset=offset, bias=bias, output_dtype=torch.float16)
        cpu_model = MyModel()
        model = cpu_model.npu()
        cpu_x1 = torch.randint(-1, 1, (15, 1, 512), dtype=torch.int8)
        cpu_x2 = torch.randint(-1, 1, (15, 512, 128), dtype=torch.int8)
        scale = torch.randn(1, dtype=torch.float32)  
        # pertoken_scale为空时,输出fp16必须先调用npu_trans_quant_param,将scale(offset)从float转为int64.
        scale_1 = torch_npu.npu_trans_quant_param(scale.npu(), None)
        bias = torch.randint(-1,1, (15, 1, 128), dtype=torch.int32)
        # dynamic=True: 动态图模式, dynamic=False: 静态图模式
        model = torch.compile(cpu_model, backend=npu_backend, dynamic=True)
        npu_out = model(cpu_x1.npu(), cpu_x2.npu(), scale_1, None, bias.npu())
        
      • 输出bfloat16
         1
         2
         3
         4
         5
         6
         7
         8
         9
        10
        11
        12
        13
        14
        15
        16
        17
        18
        19
        20
        21
        22
        23
        24
        25
        26
        27
        28
        29
        30
        31
        32
        33
        34
        35
        36
        import torch
        import torch_npu
        import torchair as tng
        from torchair.ge_concrete_graph import ge_apis as ge
        from torchair.configs.compiler_config import CompilerConfig
        import logging
        from torchair.core.utils import logger
        logger.setLevel(logging.DEBUG)
        import os
        import numpy as np
        os.environ["ENABLE_ACLNN"] = "true"
        config = CompilerConfig()
        npu_backend = tng.get_npu_backend(compiler_config=config)
        
        class MyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
            def forward(self, x1, x2, scale, offset, bias, pertoken_scale):
                return torch_npu.npu_quant_matmul(x1, x2.t(), scale, offset=offset, bias=bias, pertoken_scale=pertoken_scale,output_dtype=torch.bfloat16)
        cpu_model = MyModel()
        model = cpu_model.npu()
        m = 15
        k = 11264
        n = 6912
        bias_flag = True
        cpu_x1 = torch.randint(-1, 1, (m, k), dtype=torch.int8)
        cpu_x2 = torch.randint(-1, 1, (n, k), dtype=torch.int8)
        scale = torch.randint(-1,1, (n,), dtype=torch.bfloat16)
        pertoken_scale = torch.randint(-1,1, (m,), dtype=torch.float32)
        
        bias = torch.randint(-1,1, (n,), dtype=torch.bfloat16)
        model = torch.compile(cpu_model, backend=npu_backend, dynamic=True)
        if bias_flag:
            npu_out = model(cpu_x1.npu(), x2_notranspose_29, scale.npu(), None, bias.npu(), pertoken_scale.npu())
        else:
            npu_out = model(cpu_x1.npu(), x2_notranspose_29, scale.npu(), None, None, pertoken_scale.npu())
        
    • 高性能数据排布调用方式
      • Atlas 推理系列加速卡产品将x2转置(batch,n,k)后转format
         1
         2
         3
         4
         5
         6
         7
         8
         9
        10
        11
        12
        13
        14
        15
        16
        17
        18
        19
        20
        21
        22
        23
        24
        25
        26
        27
        28
        29
        30
        31
        32
        import torch
        import torch_npu
        import torchair as tng
        from torchair.ge_concrete_graph import ge_apis as ge
        from torchair.configs.compiler_config import CompilerConfig
        import logging
        from torchair.core.utils import logger
        logger.setLevel(logging.DEBUG)
        import os
        import numpy as np
        os.environ["ENABLE_ACLNN"] = "true"
        config = CompilerConfig()
        npu_backend = tng.get_npu_backend(compiler_config=config)
        
        class MyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
            def forward(self, x1, x2, scale, offset, bias):
                return torch_npu.npu_quant_matmul(x1, x2.transpose(2,1), scale, offset=offset, bias=bias)
        cpu_model = MyModel()
        model = cpu_model.npu()
        cpu_x1 = torch.randint(-1, 1, (15,1, 512), dtype=torch.int8).npu()
        cpu_x2 = torch.randint(-1, 1, (15,512, 128), dtype=torch.int8).npu()
        # Process x2 into a high-bandwidth format(29) offline to improve performance, please ensure that the input is continuous with (batch,n,k) layout
        cpu_x2_t_29 = torch_npu.npu_format_cast(cpu_x2.transpose(2,1).contiguous(), 29)
        scale = torch.randn(1, dtype=torch.float32).npu()
        offset = torch.randn(1, dtype=torch.float32).npu()
        bias = torch.randint(-1,1, (128,), dtype=torch.int32).npu()
        # Process scale from float32 to int64 offline to improve performance
        scale_1 = torch_npu.npu_trans_quant_param(scale, offset)
        model = torch.compile(cpu_model, backend=npu_backend, dynamic=False)
        npu_out = model(cpu_x1, cpu_x2_t_29, scale_1, offset, bias)
        
      • Atlas A2 训练系列产品将x2非转置(batch,k,n)后转format
         1
         2
         3
         4
         5
         6
         7
         8
         9
        10
        11
        12
        13
        14
        15
        16
        17
        18
        19
        20
        21
        22
        23
        24
        25
        26
        27
        28
        29
        30
        31
        32
        33
        34
        35
        36
        37
        import torch
        import torch_npu
        import torchair as tng
        from torchair.ge_concrete_graph import ge_apis as ge
        from torchair.configs.compiler_config import CompilerConfig
        import logging
        from torchair.core.utils import logger
        logger.setLevel(logging.DEBUG)
        import os
        import numpy as np
        config = CompilerConfig()
        npu_backend = tng.get_npu_backend(compiler_config=config)
        
        class MyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
            def forward(self, x1, x2, scale, offset, bias, pertoken_scale):
                return torch_npu.npu_quant_matmul(x1, x2, scale, offset=offset, bias=bias, pertoken_scale=pertoken_scale,output_dtype=torch.bfloat16)
        cpu_model = MyModel()
        model = cpu_model.npu()
        m = 15
        k = 11264
        n = 6912
        bias_flag = True
        cpu_x1 = torch.randint(-1, 1, (m, k), dtype=torch.int8)
        cpu_x2 = torch.randint(-1, 1, (n, k), dtype=torch.int8)
        # Process x2 into a high-bandwidth format(29) offline to improve performance, please ensure that the input is continuous with (batch,k,n) layout
        x2_notranspose_29 = torch_npu.npu_format_cast(cpu_x2.npu().transpose(1,0).contiguous(), 29)
        scale = torch.randint(-1,1, (n,), dtype=torch.bfloat16)
        pertoken_scale = torch.randint(-1,1, (m,), dtype=torch.float32)
        
        bias = torch.randint(-1,1, (n,), dtype=torch.bfloat16)
        model = torch.compile(cpu_model, backend=npu_backend, dynamic=True)
        if bias_flag:
            npu_out = model(cpu_x1.npu(), x2_notranspose_29, scale.npu(), None, bias.npu(), pertoken_scale.npu())
        else:
            npu_out = model(cpu_x1.npu(), x2_notranspose_29, scale.npu(), None, None, pertoken_scale.npu())
        
      • Atlas A3 训练系列产品将x2非转置(batch,k,n)后转format
         1
         2
         3
         4
         5
         6
         7
         8
         9
        10
        11
        12
        13
        14
        15
        16
        17
        18
        19
        20
        21
        22
        23
        24
        25
        26
        27
        28
        29
        30
        31
        32
        33
        34
        35
        36
        37
        import torch
        import torch_npu
        import torchair as tng
        from torchair.ge_concrete_graph import ge_apis as ge
        from torchair.configs.compiler_config import CompilerConfig
        import logging
        from torchair.core.utils import logger
        logger.setLevel(logging.DEBUG)
        import os
        import numpy as np
        config = CompilerConfig()
        npu_backend = tng.get_npu_backend(compiler_config=config)
        
        class MyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
            def forward(self, x1, x2, scale, offset, bias, pertoken_scale):
                return torch_npu.npu_quant_matmul(x1, x2, scale, offset=offset, bias=bias, pertoken_scale=pertoken_scale,output_dtype=torch.bfloat16)
        cpu_model = MyModel()
        model = cpu_model.npu()
        m = 15
        k = 11264
        n = 6912
        bias_flag = True
        cpu_x1 = torch.randint(-1, 1, (m, k), dtype=torch.int8)
        cpu_x2 = torch.randint(-1, 1, (n, k), dtype=torch.int8)
        # Process x2 into a high-bandwidth format(29) offline to improve performance, please ensure that the input is continuous with (batch,k,n) layout
        x2_notranspose_29 = torch_npu.npu_format_cast(cpu_x2.npu().transpose(1,0).contiguous(), 29)
        scale = torch.randint(-1,1, (n,), dtype=torch.bfloat16)
        pertoken_scale = torch.randint(-1,1, (m,), dtype=torch.float32)
        
        bias = torch.randint(-1,1, (n,), dtype=torch.bfloat16)
        model = torch.compile(cpu_model, backend=npu_backend, dynamic=True)
        if bias_flag:
            npu_out = model(cpu_x1.npu(), x2_notranspose_29, scale.npu(), None, bias.npu(), pertoken_scale.npu())
        else:
            npu_out = model(cpu_x1.npu(), x2_notranspose_29, scale.npu(), None, None, pertoken_scale.npu())