torch_npu.npu_weight_quant_batchmatmul
功能描述
该接口用于实现矩阵乘计算中的weight输入和输出的量化操作,支持pertensor,perchannel,pergroup多场景量化(Atlas 推理系列加速卡产品当前仅支持perchannel)。
接口原型
npu_weight_quant_batchmatmul(Tensor x, Tensor weight, Tensor antiquant_scale, Tensor? antiquant_offset=None, Tensor? quant_scale=None, Tensor? quant_offset=None, Tensor? bias=None, int antiquant_group_size=0) -> Tensor
参数说明
- x : Device侧Tensor类型,即矩阵乘中的x。数据格式支持ND,数据类型支持FLOAT16/BFLOAT16, 支持带transpose的非连续的Tensor,支持输入维度为两维(M,K) ;Atlas 推理系列加速卡产品上数据类型仅支持FLOAT16,支持输入维度为2-6维,支持batch轴但不支持broadcast。
- weight:Device侧Tensor类型,即矩阵乘中的weight。数据格式支持ND,数据类型支持INT8, 支持带transpose的非连续的Tensor,支持输入维度为两维(K,N);Atlas 推理系列加速卡产品上数据类型仅支持FLOAT16,支持输入维度为2-6维,支持batch轴但不支持broadcast,维度需与x保持一致。
- antiquant_scale:Device侧Tensor类型,反量化的scale,用于weight矩阵反量化 。数据格式支持ND,数据类型支持FLOAT16/BFLOAT16,支持带transpose的非连续的Tensor,支持输入维度为两维(1, N)或 一维(N, )、(1, );Atlas 推理系列加速卡产品上数据类型仅支持FLOAT16。
- antiquant_offset:Device侧Tensor类型,反量化的offset,用于weight矩阵反量化 。数据格式支持ND,数据类型支持FLOAT16/BFLOAT16,支持带transpose的非连续的Tensor,支持输入维度为两维(1, N)或 一维(N, )、(1, );Atlas 推理系列加速卡产品上数据类型仅支持FLOAT16。
- quant_scale:Device侧Tensor类型,量化的scale,用于输出矩阵的量化 。数据格式支持ND,数据类型支持FLOAT32/INT64,支持输入维度为两维(1, N) 或 一维(N, )、(1, );Atlas 推理系列加速卡产品暂未使用此参数。
- quant_offset:Device侧Tensor类型,量化的offset,用于输出矩阵的量化 。数据格式支持ND,数据类型支持FLOAT32,支持输入维度为两维(1, N) 或 一维(N, )、(1, );Atlas 推理系列加速卡产品暂未使用此参数。
- bias:Device侧Tensor类型, 即矩阵乘中的bias,数据格式支持ND,数据类型支持FLOAT16/FLOAT32, 不支持非连续的Tensor,支持输入维度为两维(1, N) 或 一维(N, )、(1, )。
- antiquant_group_size:int类型, 用于控制pergroup场景下的group大小,当前默认为0,预留参数,暂未使用。
输出说明
输出为Tensor类型,代表计算结果。当输入存在quant_scale时输出数据类型为INT8,当输入不存在quant_scale时输出数据类型和输入x一致。
约束说明
- x和weight后两维必须为(M,K)和(K,N)格式,M、K、N的范围为[1, 65535];在x为非转置时,M不受65535的限制。
- 不支持空Tensor输入。
- antiquant_scale和antiquant_offset的输入shape要保持一致。
- quant_scale和quant_offset的输入shape要保持一致。
- quant_offset不能独立于quant_scale存在。
- 当x输入类型为BFLOAT16类型时候,bias的输入类型为FLOAT32;当x输入类型为FLOAT16类型时候,bias的输入类型为FLOAT16。
- 如需传入INT64数据类型的quantscale, 需要提前调用torch_npu.npu_trans_quant_param接口将数据类型为FLOAT32的quant_scale和quant_offset转换为数据类型为INT64的quant_scale输入。
支持的PyTorch版本
- PyTorch 2.2
- PyTorch 2.1
- PyTorch 1.11.0
支持的芯片型号
- Atlas A2 训练系列产品
- Atlas 推理系列加速卡产品
调用示例
单算子模式: import torch import torch_npu cpu_x = torch.randn((8192, 320),device='npu',dtype=torch.bfloat16) cpu_weight = torch.randn((320, 256),device='npu',dtype=torch.int8) cpu_antiquantscale = torch.randn((1, 256),device='npu',dtype=torch.bfloat16) cpu_antiquantoffset = torch.randn((1, 256),device='npu',dtype=torch.bfloat16) npu_out = torch_npu.npu_weight_quant_batchmatmul(cpu_x.npu(), cpu_weight.npu(), cpu_antiquantscale.npu(), cpu_antiquantoffset.npu()) 图模式: import torch import torch_npu import torchair as tng from torchair.configs.compiler_config import CompilerConfig config = CompilerConfig() config.debug.graph_dump.type = "pbtxt" npu_backend = tng.get_npu_backend(compiler_config=config) cpu_x = torch.randn((8192, 320),device='npu',dtype=torch.bfloat16) cpu_weight = torch.randn((320, 256),device='npu',dtype=torch.int8) cpu_antiquantscale = torch.randn((1, 256),device='npu',dtype=torch.bfloat16) cpu_antiquantoffset = torch.randn((1, 256),device='npu',dtype=torch.bfloat16) class MyModel(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, weight, antiquant_scale, antiquant_offset, quant_scale,quant_offset, bias, antiquant_group_size): return torch_npu.npu_weight_quant_batchmatmul(x, weight, antiquant_scale, antiquant_offset, quant_scale ,quant_offset, bias, antiquant_group_size) cpu_model = MyModel() model = cpu_model.npu() model = torch.compile(cpu_model, backend=npu_backend, dynamic=True) npu_out = model(cpu_x.npu(), cpu_weight.npu(), cpu_antiquantscale.npu(), cpu_antiquantoffset.npu(), None, None, None, 0)
父主题: torch_npu