torch_npu.npu_quant_matmul
功能描述
完成量化的矩阵乘计算,最小支持输入维度为2维,最大支持输入维度为6维。
接口原型
npu_quant_matmul(Tensor x1, Tensor x2, Tensor scale, *, Tensor? offset=None, Tensor? pertoken_scale=None, Tensor? bias=None, ScalarType? output_dtype=None) -> Tensor
参数说明
- x1(计算输入):Device侧的Tensor类型,数据类型支持INT8,数据格式支持ND,shape最少是2维,最多是6维。
- x2(计算输入):Device侧的Tensor类型,数据类型支持INT8,数据格式支持ND,shape最少是2维,最多是6维。
- scale(计算输入):Device侧的Tensor类型,数据类型支持FLOAT32、INT64、BFLOAT16。数据格式支持ND,shape是1维(t,),t = 1或n,其中n与x2的n一致。如需传入INT64数据类型的scale, 需要提前调用torch_npu.npu_trans_quant_param接口来获取INT64数据类型的scale。
- offset( 计算输入):Device侧的Tensor类型,可选参数。数据类型支持FLOAT32,数据格式支持ND,shape是1维(t,),t = 1或n,其中n与x2的n一致。
- pertoken_scale(计算输入):Device侧的Tensor类型,可选参数。数据类型支持FLOAT32,数据格式支持ND,shape是1维(m,),其中m与x1的m一致。Atlas 推理系列加速卡产品当前不支持pertoken_scale。
- bias( 计算输入):Device侧的Tensor类型,可选参数。数据类型支持INT32、BFLOAT16,数据格式支持ND,shape支持1维(n,)或3维(batch,1,n),n与x2的n一致。bias 3维(batch,1,n)只出现在out为3维的场景下,同时batch值需要等于x1, x2 boardcast后推导出的batch值。
- output_dtype( 计算输入):Device侧的ScalarType类型,可选参数。表示输出Tensor的数据类型,支持输入torch.int8,torch.float16,torch.bfloat16。默认值为None,代表输出Tensor数据类型为INT8。Atlas 推理系列加速卡产品只支持output_dtype为torch.int8(含None,下同)和torch.float16。
输出说明
一个Tensor类型的输出,代表量化matmul的计算结果。
- 如果output_dtype为torch.float16,输出的数据类型为FLOAT16。
- 如果output_dtype为torch.bfloat16,输出的数据类型为BFLOAT16。
- 如果output_dtype为torch.int8或者None,输出的数据类型为INT8。
如果output_dtype非以上数据类型,返回错误码。
约束说明
- 传入的x1、x2、scale不能是空。
- x1、x2、bias、scale、offset、pertoken_scale、output_dtype的数据类型和数据格式需要在支持的范围之内。
- x1、x2的shape需要在2-6维范围。
- x1与x2最后一维的shape大小不能超过65535。
- scale、offset的shape需要为1维(t,),t = 1或n,n与x2的n一致。
- pertoken_scale的shape需要为1维(m, ),m与x1的m一致。Atlas 推理系列加速卡产品当前不支持pertoken_scale。
- bias的shape支持1维(n,)或3维(batch,1,n),n与x2的n一致, batch值需要等于x1, x2 boardcast后推导出的batch值。
- bias的shape在out 是2,4,5,6维情况下需要为1维,在out 是3维情况下可以为1维或3维。
- output_dtype不同取值对scale的约束:
- output_dtype为torch.bfloat16时,scale需要为BFLOAT16数据类型的Tensor。
- output_dtype为torch.float16或torch.int8,并且在pertoken_scale为空时,scale可为FLOAT32或INT64数据类型的Tensor。
- output_dtype为torch.float16且pertoken_scale不为空时,scale必须为FLOAT32。
- bias为BFLOAT16数据类型时,output_dtype需要为torch.bfloat16。
- 目前输INT8/FLOAT16且无pertoken_scale情况下,图模式不支持scale直接传入FLOAT32数据类型。
- pertoken_scale仅支持FLOAT32,目前仅在输出FLOAT16和BFLOAT16场景下可不为空。
- offset不为空时,output_dtype仅支持torch.int8。
- 图模式可支持x2为昇腾亲和的数据排布以提高搬运效率。单算子模式则不支持,请使用图模式获取更优性能。需要调用npu_format_cast可以完成输入x2(weight)为昇腾亲和的数据排布功能。
- Atlas 推理系列加速卡产品必须将x2转置后转format。
- Atlas A2 训练系列产品推荐将x2非转置后转format。
支持的PyTorch版本
- PyTorch 2.3
- PyTorch 2.2
- PyTorch 2.1
- PyTorch 1.11.0
支持的型号
- Atlas A2 训练系列产品
- Atlas 推理系列加速卡产品
调用示例
- 单算子调用
# 在单算子模式下不支持使能X2为昇腾亲和的数据排布,如果想追求极致性能,请使用图模式 import torch import torch_npu import logging import os cpu_x1 = torch.randint(-5, 5, (1, 256, 768), dtype=torch.int8) cpu_x2 = torch.randint(-5, 5, (31, 768, 16), dtype=torch.int8) scale = torch.randn(16, dtype=torch.float32) offset = torch.randn(16, dtype=torch.float32) bias = torch.randint(-5, 5, (31, 1, 16), dtype=torch.int32) # Method 1:You can directly call npu_quant_matmul npu_out = torch_npu.npu_quant_matmul(cpu_x1.npu(), cpu_x2.npu(), scale.npu(), offset=offset.npu(), bias=bias.npu()) # Method 2: You can first call npu_trans_quant_param to convert scale and offset from float32 to int64 when output dtype is torch.int8 or torch.float16 scale_1 = torch_npu.npu_trans_quant_param(scale.npu(), offset.npu()) npu_out = torch_npu.npu_quant_matmul(cpu_x1.npu(), cpu_x2.npu(), scale_1, bias=bias.npu())
- 图模式调用(图模式目前仅支持PyTorch 2.1版本)
输出int8/fp16且无pertoken情况下,必须先调用npu_trans_quant_param。
- 通用
- 示例一:输出float16
import torch import torch_npu import torchair as tng from torchair.ge_concrete_graph import ge_apis as ge from torchair.configs.compiler_config import CompilerConfig import logging from torchair.core.utils import logger logger.setLevel(logging.DEBUG) import os import numpy as np # "ENABLE_ACLNN"是否使能走aclnn, true: 回调走aclnn, false: 在线编译 os.environ["ENABLE_ACLNN"] = "true" config = CompilerConfig() npu_backend = tng.get_npu_backend(compiler_config=config) class MyModel(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x1, x2, scale, offset, bias): return torch_npu.npu_quant_matmul(x1, x2, scale, offset=offset, bias=bias, output_dtype=torch.float16) cpu_model = MyModel() model = cpu_model.npu() cpu_x1 = torch.randint(-1, 1, (15, 1, 512), dtype=torch.int8) cpu_x2 = torch.randint(-1, 1, (15, 512, 128), dtype=torch.int8) scale = torch.randn(1, dtype=torch.float32) # pertoken_scale为空时,输出fp16必须先调用npu_trans_quant_param,将scale(offset)从float转为int64. scale_1 = torch_npu.npu_trans_quant_param(scale.npu(), None) bias = torch.randint(-1,1, (15, 1, 128), dtype=torch.int32) # dynamic=True: 动态图模式, dynamic=False: 静态图模式 model = torch.compile(cpu_model, backend=npu_backend, dynamic=True) npu_out = model(cpu_x1.npu(), cpu_x2.npu(), scale_1, None, bias.npu())
- 示例2:输出bfloat16
import torch import torch_npu import torchair as tng from torchair.ge_concrete_graph import ge_apis as ge from torchair.configs.compiler_config import CompilerConfig import logging from torchair.core.utils import logger logger.setLevel(logging.DEBUG) import os import numpy as np os.environ["ENABLE_ACLNN"] = "true" config = CompilerConfig() npu_backend = tng.get_npu_backend(compiler_config=config) class MyModel(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x1, x2, scale, offset, bias, pertoken_scale): return torch_npu.npu_quant_matmul(x1, x2.t(), scale, offset=offset, bias=bias, pertoken_scale=pertoken_scale,output_dtype=torch.bfloat16) cpu_model = MyModel() model = cpu_model.npu() m = 15 k = 11264 n = 6912 bias_flag = True cpu_x1 = torch.randint(-1, 1, (m, k), dtype=torch.int8) cpu_x2 = torch.randint(-1, 1, (n, k), dtype=torch.int8) scale = torch.randint(-1,1, (n,), dtype=torch.bfloat16) pertoken_scale = torch.randint(-1,1, (m,), dtype=torch.float32) bias = torch.randint(-1,1, (n,), dtype=torch.bfloat16) model = torch.compile(cpu_model, backend=npu_backend, dynamic=True) if bias_flag: npu_out = model(cpu_x1.npu(), cpu_x2.npu(), scale.npu(), None, None, pertoken_scale.npu()) else: npu_out = model(cpu_x1.npu(), cpu_x2.npu(), scale.npu(), None, bias.npu(), pertoken_scale.npu())
- 示例一:输出float16
- 2.2 高性能数据排布调用方式
- Atlas 推理系列加速卡产品将x2转置(batch,n,k)后转format
import torch import torch_npu import torchair as tng from torchair.ge_concrete_graph import ge_apis as ge from torchair.configs.compiler_config import CompilerConfig import logging from torchair.core.utils import logger logger.setLevel(logging.DEBUG) import os import numpy as np os.environ["ENABLE_ACLNN"] = "true" config = CompilerConfig() npu_backend = tng.get_npu_backend(compiler_config=config) class MyModel(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x1, x2, scale, offset, bias): return torch_npu.npu_quant_matmul(x1, x2.transpose(2,1), scale, offset=offset, bias=bias) cpu_model = MyModel() model = cpu_model.npu() cpu_x1 = torch.randint(-1, 1, (15,1, 512), dtype=torch.int8).npu() cpu_x2 = torch.randint(-1, 1, (15,512, 128), dtype=torch.int8).npu() # Process x2 into a high-bandwidth format(29) offline to improve performance, please ensure that the input is continuous with (batch,n,k) layout cpu_x2_t_29 = torch_npu.npu_format_cast(cpu_x2.transpose(2,1).contiguous(), 29) scale = torch.randn(1, dtype=torch.float32).npu() offset = torch.randn(1, dtype=torch.float32).npu() bias = torch.randint(-1,1, (128,), dtype=torch.int32).npu() # Process scale from float32 to int64 offline to improve performance scale_1 = torch_npu.npu_trans_quant_param(scale, offset) model = torch.compile(cpu_model, backend=npu_backend, dynamic=False) npu_out = model(cpu_x1, cpu_x2_t_29, scale_1, offset, bias)
- Atlas A2 训练系列产品将x2非转置(batch,k,n)后转format
import torch import torch_npu import torchair as tng from torchair.ge_concrete_graph import ge_apis as ge from torchair.configs.compiler_config import CompilerConfig import logging from torchair.core.utils import logger logger.setLevel(logging.DEBUG) import os import numpy as np config = CompilerConfig() npu_backend = tng.get_npu_backend(compiler_config=config) class MyModel(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x1, x2, scale, offset, bias, pertoken_scale): return torch_npu.npu_quant_matmul(x1, x2, scale, offset=offset, bias=bias, pertoken_scale=pertoken_scale,output_dtype=torch.bfloat16) cpu_model = MyModel() model = cpu_model.npu() m = 15 k = 11264 n = 6912 bias_flag = True cpu_x1 = torch.randint(-1, 1, (m, k), dtype=torch.int8) cpu_x2 = torch.randint(-1, 1, (n, k), dtype=torch.int8) # Process x2 into a high-bandwidth format(29) offline to improve performance, please ensure that the input is continuous with (batch,k,n) layout x2_notranspose_29 = torch_npu.npu_format_cast(cpu_x2.npu().transpose(1,0).contiguous(), 29) scale = torch.randint(-1,1, (n,), dtype=torch.bfloat16) pertoken_scale = torch.randint(-1,1, (m,), dtype=torch.float32) bias = torch.randint(-1,1, (n,), dtype=torch.bfloat16) model = torch.compile(cpu_model, backend=npu_backend, dynamic=True) if bias_flag: npu_out = model(cpu_x1.npu(), x2_notranspose_29, scale.npu(), None, None, pertoken_scale.npu()) else: npu_out = model(cpu_x1.npu(), x2_notranspose_29, scale.npu(), None, bias.npu(), pertoken_scale.npu())
- Atlas 推理系列加速卡产品将x2转置(batch,n,k)后转format
- 通用
父主题: torch_npu