文档
注册

torch_npu.npu_quant_matmul

功能描述

完成量化的矩阵乘计算,最小支持输入维度为2维,最大支持输入维度为6维。

接口原型

npu_quant_matmul(Tensor x1, Tensor x2, Tensor scale, *, Tensor? offset=None, Tensor? pertoken_scale=None, Tensor? bias=None, ScalarType? output_dtype=None) -> Tensor

参数说明

  • x1(计算输入):Device侧的Tensor类型,数据类型支持INT8。数据格式支持ND,shape最少是2维,最多是6维。
  • x2(计算输入):Device侧的Tensor类型,数据类型支持INT8。数据格式支持ND,shape最少是2维,最多是6维。
  • scale(计算输入):Device侧的Tensor类型,数据类型支持FLOAT32, INT64, BFLOAT16。数据格式支持ND,shape是1维(t,),t = 1或n,其中n与x2的n一致。如需传入INT64数据类型的scale, 需要提前调用torch_npu.npu_trans_quant_param接口来获取INT64数据类型的scale。
  • offset( 计算输入):Device侧的Tensor类型,可选参数。数据类型支持FLOAT32,数据格式支持ND,shape是1维(t,),t = 1或n,其中n与x2的n一致。
  • pertoken_scale(计算输入):Device侧的Tensor类型,可选参数。数据类型支持FLOAT32,数据格式支持ND,shape是1维(m,),其中m与x1的m一致。目前仅为接口预留参数,pertoken功能当前不可用。
  • bias( 计算输入):Device侧的Tensor类型,可选参数。数据类型支持INT32,数据格式支持ND,shape支持1维(n,)或3维(batch,1,n),n与x2的n一致。bias 3维(batch,1,n)只出现在out为3维的场景下,同时batch值需要等于x1, x2 boardcast后推导出的batch值。
  • output_dtype( 计算输入):Device侧的ScalarType类型,可选参数。表示输出Tensor的数据类型,支持输入torch.int8,torch.float16,torch.bfloat16。默认值为None,代表输出Tensor数据类型为INT8。Atlas 推理系列产品(Ascend 310P处理器)只支持output_dtype为torch.int8(含None,下同)和torch.float16。

输出说明

一个Tensor类型的输出,代表量化matmul的计算结果。如果output_dtype为torch.float16,输出的数据类型为FLOAT16;如果output_dtype为torch.bfloat16,输出的数据类型为BFLOAT16;如果output_dtype为torch.int8或者None,输出的数据类型为INT8;如果output_dtype非以上数据类型,返回错误码。

约束说明

  • 传入的x1、x2、scale不能是空。
  • x1、x2、bias、scale、offset、pertoken_scale、output_dtype的数据类型和数据格式需要在支持的范围之内。
  • x1、x2的shape需要在2-6维范围。
  • scale, offset的shape需要为1维(t,),t = 1或n,n与x2的n一致。
  • pertoken_scale的shape需要为1维(m, ),m与x1的m一致。目前仅为接口预留参数,pertoken功能不可用。
  • bias的shape支持1维(n,)或3维(batch,1,n),n与x2的n一致, batch值需要等于x1, x2 boardcast后推导出的batch值。
  • bias的shape在out 是2,4,5,6维情况下需要为1维,在out 是3维情况下可以为1维或3维。
  • output_dtype为torch.bfloat16时,scale需要为BFLOAT16数据类型的Tensor。output_dtype为torch.float16或torch.int8,并且在pertoken_scale为空时,scale可为FLOAT32或INT64数据类型的Tensor。output_dtype为torch.float16且pertoken_scale不为空时,scale必须为FLOAT32。
  • pertoken_scale仅支持FLOAT32,目前仅为接口预留参数,pertoken功能不可用。
  • offset不为空时,output_dtype仅支持torch.int8。
  • x1与x2最后一维的shape大小不能超过65535。

支持的PyTorch版本

  • PyTorch 2.2
  • PyTorch 2.1
  • PyTorch 1.11.0

支持的型号

  • Atlas A2 训练系列产品
  • Atlas 推理系列产品(Ascend 310P处理器)

调用示例

1.单算子调用:
# Atlas 推理系列产品(Ascend 310P处理器):在单算子模式下不支持使能高带宽的x2数据排布,如果想追求极致性能,请使用图模式
import torch
import torch_npu
import logging
import os

cpu_x1 = torch.randint(-5, 5, (1, 256, 768), dtype=torch.int8)
cpu_x2 = torch.randint(-5, 5, (31, 768, 16), dtype=torch.int8)
scale = torch.randn(16, dtype=torch.float32)
offset = torch.randn(16, dtype=torch.float32)
bias = torch.randint(-5, 5, (31, 1, 16), dtype=torch.int32)
# Method 1:You can directly call npu_quant_matmul
npu_out = torch_npu.npu_quant_matmul(cpu_x1.npu(), cpu_x2.npu(), scale.npu(), offset=offset.npu(), bias=bias.npu())

# Method 2: You can first call npu_trans_quant_param to convert scale and offset from float32 to int64 when output dtype is torch.int8 or torch.float16
scale_1 = torch_npu.npu_trans_quant_param(scale.npu(), offset.npu())
npu_out = torch_npu.npu_quant_matmul(cpu_x1.npu(), cpu_x2.npu(), scale_1,  bias=bias.npu())


2.图模式:
2.1 通用
import torch
import torch_npu
import torchair as tng
from torchair.ge_concrete_graph import ge_apis as ge
from torchair.configs.compiler_config import CompilerConfig
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
import os
import numpy as np
os.environ["ENABLE_ACLNN"] = "true"
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x1, x2, scale, offset, bias):
        scale_1 = torch_npu.npu_trans_quant_param(scale, offset)
        return torch_npu.npu_quant_matmul(x1, x2, scale_1, offset=offset, bias=bias)
cpu_model = MyModel()
model = cpu_model.npu()
cpu_x1 = torch.randint(-1, 1, (15, 1, 512), dtype=torch.int8)
cpu_x2 = torch.randint(-1, 1, (15, 512, 128), dtype=torch.int8)
scale = torch.randn(1, dtype=torch.float32)
offset = torch.randn(1, dtype=torch.float32)
bias = torch.randint(-1,1, (15, 1, 128), dtype=torch.int32)
model = torch.compile(cpu_model, backend=npu_backend, dynamic=True)
npu_out = model(cpu_x1.npu(), cpu_x2.npu(), scale.npu(), offset.npu(), bias.npu())

2.2 Atlas 推理系列产品(Ascend 310P处理器)Atlas A2 训练系列产品不可用)高性能调用方式
import torch
import torch_npu
import torchair as tng
from torchair.ge_concrete_graph import ge_apis as ge
from torchair.configs.compiler_config import CompilerConfig
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
import os
import numpy as np
os.environ["ENABLE_ACLNN"] = "true"
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x1, x2, scale, offset, bias):
        return torch_npu.npu_quant_matmul(x1, x2.transpose(2,1), scale, offset=offset, bias=bias)
cpu_model = MyModel()
model = cpu_model.npu()
cpu_x1 = torch.randint(-1, 1, (15,1, 512), dtype=torch.int8).npu()
cpu_x2 = torch.randint(-1, 1, (15,512, 128), dtype=torch.int8).npu()
# Process x2 into a high-bandwidth format(29) offline to improve performance, please ensure that the input is continuous with (batch,n,k) layout
cpu_x2_t_29 = torch_npu.npu_format_cast(cpu_x2.transpose(2,1).contiguous(), 29)
scale = torch.randn(1, dtype=torch.float32).npu()
offset = torch.randn(1, dtype=torch.float32).npu()
bias = torch.randint(-1,1, (128,), dtype=torch.int32).npu()
# Process scale from float32 to int64 offline to improve performance
scale_1 = torch_npu.npu_trans_quant_param(scale, offset)
model = torch.compile(cpu_model, backend=npu_backend, dynamic=False)
npu_out = model(cpu_x1, cpu_x2_t_29, scale_1, offset, bias)
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词