文档
注册
评分
提单
论坛
小AI

torch_npu.npu_weight_quant_batchmatmul

功能描述

该接口用于实现矩阵乘计算中的weight输入和输出的量化操作,支持pertensor,perchannel,pergroup多场景量化(Atlas 推理系列加速卡产品当前仅支持perchannel)。

接口原型

npu_weight_quant_batchmatmul(Tensor x, Tensor weight, Tensor antiquant_scale, Tensor? antiquant_offset=None, Tensor? quant_scale=None, Tensor? quant_offset=None, Tensor? bias=None, int antiquant_group_size=0) -> Tensor

参数说明

  • x : Device侧Tensor类型,即矩阵乘中的x。数据格式支持ND,数据类型支持FLOAT16/BFLOAT16, 支持带transpose的非连续的Tensor,支持输入维度为两维(M,K) ;Atlas 推理系列加速卡产品上数据类型仅支持FLOAT16,支持输入维度为2-6维,支持batch轴但不支持broadcast。
  • weight:Device侧Tensor类型,即矩阵乘中的weight。数据格式支持ND,数据类型支持INT8, 支持带transpose的非连续的Tensor,支持输入维度为两维(K,N);Atlas 推理系列加速卡产品上数据类型仅支持FLOAT16,支持输入维度为2-6维,支持batch轴但不支持broadcast,维度需与x保持一致。
  • antiquant_scale:Device侧Tensor类型,反量化的scale,用于weight矩阵反量化 。数据格式支持ND,数据类型支持FLOAT16/BFLOAT16,支持带transpose的非连续的Tensor,支持输入维度为两维(1, N)或 一维(N, )、(1, );Atlas 推理系列加速卡产品上数据类型仅支持FLOAT16。
  • antiquant_offset:Device侧Tensor类型,反量化的offset,用于weight矩阵反量化 。数据格式支持ND,数据类型支持FLOAT16/BFLOAT16,支持带transpose的非连续的Tensor,支持输入维度为两维(1, N)或 一维(N, )、(1, );Atlas 推理系列加速卡产品上数据类型仅支持FLOAT16。
  • quant_scale:Device侧Tensor类型,量化的scale,用于输出矩阵的量化 。数据格式支持ND,数据类型支持FLOAT32/INT64,支持输入维度为两维(1, N) 或 一维(N, )、(1, );Atlas 推理系列加速卡产品暂未使用此参数。
  • quant_offset:Device侧Tensor类型,量化的offset,用于输出矩阵的量化 。数据格式支持ND,数据类型支持FLOAT32,支持输入维度为两维(1, N) 或 一维(N, )、(1, );Atlas 推理系列加速卡产品暂未使用此参数。
  • bias:Device侧Tensor类型, 即矩阵乘中的bias,数据格式支持ND,数据类型支持FLOAT16/FLOAT32, 不支持非连续的Tensor,支持输入维度为两维(1, N) 或 一维(N, )、(1, )。
  • antiquant_group_size:int类型, 用于控制pergroup场景下的group大小,当前默认为0,预留参数,暂未使用

输出说明

输出为Tensor类型,代表计算结果。当输入存在quant_scale时输出数据类型为INT8,当输入不存在quant_scale时输出数据类型和输入x一致。

约束说明

  • x和weight后两维必须为(M,K)和(K,N)格式,M、K、N的范围为[1, 65535];在x为非转置时,M不受65535的限制。
  • 不支持空Tensor输入。
  • antiquant_scale和antiquant_offset的输入shape要保持一致。
  • quant_scale和quant_offset的输入shape要保持一致。
  • quant_offset不能独立于quant_scale存在。
  • 当x输入类型为BFLOAT16类型时候,bias的输入类型为FLOAT32;当x输入类型为FLOAT16类型时候,bias的输入类型为FLOAT16。
  • 如需传入INT64数据类型的quantscale, 需要提前调用torch_npu.npu_trans_quant_param接口将数据类型为FLOAT32的quant_scale和quant_offset转换为数据类型为INT64的quant_scale输入。

支持的PyTorch版本

  • PyTorch 2.2
  • PyTorch 2.1
  • PyTorch 1.11.0

支持的芯片型号

  • Atlas A2 训练系列产品
  • Atlas 推理系列加速卡产品

调用示例

单算子模式:
import torch
import torch_npu

cpu_x = torch.randn((8192, 320),device='npu',dtype=torch.bfloat16)
cpu_weight = torch.randn((320, 256),device='npu',dtype=torch.int8)
cpu_antiquantscale = torch.randn((1, 256),device='npu',dtype=torch.bfloat16)
cpu_antiquantoffset = torch.randn((1, 256),device='npu',dtype=torch.bfloat16)
npu_out = torch_npu.npu_weight_quant_batchmatmul(cpu_x.npu(), cpu_weight.npu(), cpu_antiquantscale.npu(), cpu_antiquantoffset.npu())

图模式:
import torch
import torch_npu
import  torchair as tng
from torchair.configs.compiler_config import CompilerConfig
config = CompilerConfig()
config.debug.graph_dump.type = "pbtxt"
npu_backend = tng.get_npu_backend(compiler_config=config)

cpu_x = torch.randn((8192, 320),device='npu',dtype=torch.bfloat16)
cpu_weight = torch.randn((320, 256),device='npu',dtype=torch.int8)
cpu_antiquantscale = torch.randn((1, 256),device='npu',dtype=torch.bfloat16)
cpu_antiquantoffset = torch.randn((1, 256),device='npu',dtype=torch.bfloat16)

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, weight, antiquant_scale, antiquant_offset, quant_scale,quant_offset, bias, antiquant_group_size):
        return torch_npu.npu_weight_quant_batchmatmul(x, weight, antiquant_scale, antiquant_offset, quant_scale ,quant_offset, bias, antiquant_group_size)

cpu_model = MyModel()
model = cpu_model.npu()
model = torch.compile(cpu_model, backend=npu_backend, dynamic=True)
npu_out = model(cpu_x.npu(), cpu_weight.npu(), cpu_antiquantscale.npu(), cpu_antiquantoffset.npu(), None, None, None, 0)
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词