torch_npu.npu_weight_quant_batchmatmul
功能描述
该接口用于实现矩阵乘计算中的weight输入和输出的量化操作,支持pertensor、perchannel、pergroup多场景量化。
当前Atlas 推理系列加速卡产品仅支持perchannel量化。
接口原型
1 | npu_weight_quant_batchmatmul(Tensor x, Tensor weight, Tensor antiquant_scale, Tensor? antiquant_offset=None, Tensor? quant_scale=None, Tensor? quant_offset=None, Tensor? bias=None, int antiquant_group_size=0, int inner_precise=0) -> Tensor |
参数说明
- x : Device侧Tensor类型,即矩阵乘中的x。数据格式支持ND,支持带transpose的非连续的Tensor,支持输入维度为两维(M,K) 。
Atlas A2 训练系列产品/Atlas 800I A2 推理产品 :数据类型支持FLOAT16、BFLOAT16。Atlas A3 训练系列产品 :数据类型支持FLOAT16、BFLOAT16。- Atlas 推理系列加速卡产品:数据类型仅支持FLOAT16。
- weight:Device侧Tensor类型,即矩阵乘中的weight。数据格式支持ND、FRACTAL_NZ,支持带transpose的非连续的Tensor,支持输入维度为两维(K,N),维度需与x保持一致。perchannel场景为提高性能,推荐使用transpose后的weight输入,可以参考调用示例。
Atlas A2 训练系列产品/Atlas 800I A2 推理产品 :数据类型支持INT8、INT32(通过INT32承载INT4的输入,可以参考torch_npu.npu_convert_weight_to_int4pack调用示例)。Atlas A3 训练系列产品 :数据类型支持INT8、INT32(通过INT32承载INT4的输入,可以参考torch_npu.npu_convert_weight_to_int4pack调用示例)。- Atlas 推理系列加速卡产品:数据类型支持INT8。weight FRACTAL_NZ格式传入只在图模式有效,需依赖接口torch_npu.npu_format_cast完成数据格式从ND到FRACTAL_NZ的转换,可参考调用示例。
- antiquant_scale:Device侧Tensor类型,反量化的scale,用于weight矩阵反量化,数据格式支持ND。支持带transpose的非连续的Tensor,支持输入维度为两维(1, N)或 一维(N, )、(1, )。
Atlas A2 训练系列产品/Atlas 800I A2 推理产品 :数据类型支持FLOAT16、BFLOAT16、INT64。pergroup场景shape要求为(ceil_div(K,antiquant_group_size), N)。- 若数据类型为FLOAT16、BFLOAT16, 其数据类型需要和x保持一致。
- 若数据类型为INT64,则x的数据类型必须为FLOAT16且不带transpose输入,同时weight的数据类型必须为INT8、数据格式为ND、带transpose输入,可参考调用示例。此时只支持perchannel场景,M范围为[1,96],且K和N要求64对齐。
Atlas A3 训练系列产品 :数据类型支持FLOAT16、BFLOAT16、INT64。pergroup场景shape要求为(ceil_div(K,antiquant_group_size), N)。- 若数据类型为FLOAT16、BFLOAT16, 其数据类型需要和x保持一致。
- 若数据类型为INT64,则x的数据类型必须为FLOAT16且不带transpose输入,同时weight的数据类型必须为INT8、数据格式为ND、带transpose输入,可参考调用示例。此时只支持perchannel场景,M范围为[1,96],且K和N要求64对齐。
- Atlas 推理系列加速卡产品:数据类型支持FLOAT16,其数据类型需要和x保持一致。
- antiquant_offset:Device侧Tensor类型,反量化的offset,用于weight矩阵反量化,为可选参数,默认值为None,数据格式支持ND,支持带transpose的非连续的Tensor,支持输入维度为两维(1, N)或 一维(N, )、(1, )。
Atlas A2 训练系列产品/Atlas 800I A2 推理产品 :数据类型支持FLOAT16、BFLOAT16、INT32。pergroup场景shape要求为(ceil_div(K,antiquant_group_size), N)。- 若数据类型为FLOAT16、BFLOAT16,其数据类型需要和antiquant_scale保持一致。
- 若数据类型为INT32,则antiquant_scale的数据类型必须为INT64。
Atlas A3 训练系列产品 :数据类型支持FLOAT16、BFLOAT16、INT32。pergroup场景shape要求为(ceil_div(K,antiquant_group_size), N)。- 若数据类型为FLOAT16、BFLOAT16,其数据类型需要和antiquant_scale保持一致。
- 若数据类型为INT32,则antiquant_scale的数据类型必须为INT64。
- Atlas 推理系列加速卡产品:数据类型仅支持FLOAT16,其数据类型需要和antiquant_scale保持一致。
- quant_scale:Device侧Tensor类型,量化的scale,用于输出矩阵的量化,可选参数,默认值为None,仅在weight格式为ND时支持。数据类型支持FLOAT32、INT64,数据格式支持ND,支持输入维度为两维(1, N) 或 一维(N, )、(1, )。当antiquant_scale的数据类型为INT64时,此参数必须为空。
- Atlas 推理系列加速卡产品:暂不支持此参数。
- quant_offset: Device侧Tensor类型,量化的offset,用于输出矩阵的量化,可选参数,默认值为None,仅在weight格式为ND时支持。数据类型支持FLOAT32,数据格式支持ND,支持输入维度为两维(1, N) 或 一维(N, )、(1, )。当antiquant_scale的数据类型为INT64时,此参数必须为空。
- Atlas 推理系列加速卡产品:暂不支持此参数。
- bias:Device侧Tensor类型, 即矩阵乘中的bias,可选参数,默认值为None,数据格式支持ND,数据类型支持FLOAT16、FLOAT32, 不支持非连续的Tensor,支持输入维度为两维(1, N) 或 一维(N, )、(1, )。
- antiquant_group_size:int类型, 用于控制pergroup场景下的group大小,为可选参数,默认值为0。传入值的范围为[32,K-1]且值要求是32的倍数。
- Atlas 推理系列加速卡产品:暂不支持此参数。
- inner_precise: int类型,计算模式选择, 默认为0。0表示高精度模式,1表示高性能模式,可能会影响精度。当weight以INT32类型且以FRACTAL_NZ格式输入,M不大于16的pergroup场景下可以设置为1,提升性能。其他场景不建议使用高性能模式。
输出说明
输出为Tensor类型,代表计算结果。当输入存在quant_scale时输出数据类型为INT8,当输入不存在quant_scale时输出数据类型和输入x一致。
约束说明
- 该接口仅在推理场景下使用。
- 该接口支持图模式(目前仅支持PyTorch 2.1版本)。当输入weight为FRACTAL_NZ格式时暂不支持单算子调用,只支持图模式调用。
- x和weight后两维必须为(M,K)和(K,N)格式,K、N的范围为[1, 65535];在x为非转置时,M的范围为[1, 2^31-1],在x为转置时,M的范围为[1, 65535]。
- 不支持空Tensor输入。
- antiquant_scale和antiquant_offset的输入shape要保持一致。
- quant_scale和quant_offset的输入shape要保持一致,且quant_offset不能独立于quant_scale存在。
- 当x输入类型为BFLOAT16类型时候,bias的输入类型为FLOAT32;当x输入类型为FLOAT16类型时候,bias的输入类型为FLOAT16。
- 如需传入INT64数据类型的quant_scale,需要提前调用torch_npu.npu_trans_quant_param接口将数据类型为FLOAT32的quant_scale和quant_offset转换为数据类型为INT64的quant_scale输入,可参考调用示例。
- 当输入weight为FRACTAL_NZ格式且类型为INT32时,perchannel场景需满足weight为转置输入;pergroup场景需满足x为转置输入,weight为非转置输入,antiquant_group_size为64或128,K为antiquant_group_size对齐,N为64对齐。
支持的型号
Atlas A2 训练系列产品/Atlas 800I A2 推理产品 Atlas A3 训练系列产品 - Atlas 推理系列加速卡产品
调用示例
- 单算子模式调用
- weight非transpose场景
1 2 3 4 5 6 7 8 9 10 11
import torch import torch_npu cpu_x = torch.randn((8192, 320),dtype=torch.float16) cpu_weight = torch.randint(low=-8, high=8, size=(320, 256),dtype=torch.int8) cpu_antiquantscale = torch.randn((1, 256),dtype=torch.float16) cpu_antiquantoffset = torch.randn((1, 256),dtype=torch.float16) cpu_quantscale = torch.randn((1, 256),dtype=torch.float32) cpu_quantoffset = torch.randn((1, 256),dtype=torch.float32) quantscale= torch_npu.npu_trans_quant_param(cpu_quantscale.npu(), cpu_quantoffset.npu()) npu_out = torch_npu.npu_weight_quant_batchmatmul(cpu_x.npu(), cpu_weight.npu(), cpu_antiquantscale.npu(), cpu_antiquantoffset.npu(),quantscale.npu())
- weight transpose场景
1 2 3 4 5 6 7
import torch import torch_npu cpu_x = torch.randn((96, 320),dtype=torch.float16) cpu_weight = torch.randint(low=-8, high=8, size=(256, 320),dtype=torch.int8) cpu_antiquantscale = torch.randn((256,1),dtype=torch.float16) cpu_antiquantoffset = torch.randint(-128, 127, (256,1), dtype=torch.float16) npu_out = torch_npu.npu_weight_quant_batchmatmul(cpu_x.npu(), cpu_weight.npu().transpose(-1, -2), cpu_antiquantscale.npu().transpose(-1, -2), cpu_antiquantoffset.npu().transpose(-1, -2))
- antiquant_scale INT64输入场景
1 2 3 4 5 6 7 8 9
import torch import torch_npu cpu_x = torch.randn((96, 320),dtype=torch.float16) cpu_weight = torch.randint(low=-8, high=8, size=(256, 320),dtype=torch.int8) cpu_antiquantscale = torch.randn((256),dtype=torch.float16) # 构建int64类型的scale参数 antiquant_scale = torch_npu.npu_trans_quant_param(cpu_antiquantscale.to(torch.float32).npu()).reshape(256, 1) cpu_antiquantoffset = torch.randint(-128, 127, (256, 1), dtype=torch.int32) npu_out = torch_npu.npu_weight_quant_batchmatmul(cpu_x.npu(), cpu_weight.transpose(-1,-2).npu(), antiquant_scale.transpose(-1,-2).npu(), cpu_antiquantoffset.transpose(-1,-2).npu())
- weight非transpose场景
- 图模式调用
- weight输入为ND格式
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
# 图模式 import torch import torch_npu import torchair as tng from torchair.configs.compiler_config import CompilerConfig config = CompilerConfig() config.debug.graph_dump.type = "pbtxt" npu_backend = tng.get_npu_backend(compiler_config=config) cpu_x = torch.randn((8192, 320),device='npu',dtype=torch.bfloat16) cpu_weight = torch.randn((320, 256),device='npu',dtype=torch.int8) cpu_antiquantscale = torch.randn((1, 256),device='npu',dtype=torch.bfloat16) cpu_antiquantoffset = torch.randn((1, 256),device='npu',dtype=torch.bfloat16) class MyModel(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, weight, antiquant_scale, antiquant_offset, quant_scale,quant_offset, bias, antiquant_group_size): return torch_npu.npu_weight_quant_batchmatmul(x, weight, antiquant_scale, antiquant_offset, quant_scale ,quant_offset, bias, antiquant_group_size) cpu_model = MyModel() model = cpu_model.npu() model = torch.compile(cpu_model, backend=npu_backend, dynamic=True) npu_out = model(cpu_x.npu(), cpu_weight.npu(), cpu_antiquantscale.npu(), cpu_antiquantoffset.npu(), None, None, None, 0)
- Atlas 推理系列加速卡产品:weight输入为FRACTAL_NZ格式
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
import torch_npu import torch from torchair.configs.compiler_config import CompilerConfig import torchair as tng config = CompilerConfig() config.debug.graph_dump.type = "pbtxt" npu_backend = tng.get_npu_backend(compiler_config=config) class NPUQuantizedLinearA16W8(torch.nn.Module): def __init__(self, weight, antiquant_scale, antiquant_offset, quant_offset=None, quant_scale=None, bias=None, transpose_x=False, transpose_weight=True, w4=False): super().__init__() self.dtype = torch.float16 self.weight = weight.to(torch.int8).npu() self.transpose_weight = transpose_weight if self.transpose_weight: self.weight = torch_npu.npu_format_cast(self.weight.contiguous(), 29) else: self.weight = torch_npu.npu_format_cast(self.weight.transpose(0, 1).contiguous(), 29) # n,k ->nz self.bias = None self.antiquant_scale = antiquant_scale self.antiquant_offset = antiquant_offset self.quant_offset = quant_offset self.quant_scale = quant_scale self.transpose_x = transpose_x def forward(self, x): x = torch_npu.npu_weight_quant_batchmatmul(x.transpose(0, 1) if self.transpose_x else x, self.weight.transpose(0, 1), self.antiquant_scale.transpose(0, 1), self.antiquant_offset.transpose(0, 1), self.quant_scale, self.quant_offset, self.bias) return x m, k, n = 4, 1024, 4096 cpu_x = torch.randn((m, k),dtype=torch.float16) cpu_weight = torch.randint(1, 10, (k, n),dtype=torch.int8) cpu_weight = cpu_weight.transpose(-1, -2) cpu_antiquantscale = torch.randn((1, n),dtype=torch.float16) cpu_antiquantoffset = torch.randn((1, n),dtype=torch.float16) cpu_antiquantscale = cpu_antiquantscale.transpose(-1, -2) cpu_antiquantoffset = cpu_antiquantoffset.transpose(-1, -2) model = NPUQuantizedLinearA16W8(cpu_weight.npu(), cpu_antiquantscale.npu(), cpu_antiquantoffset.npu()) model = torch.compile(model, backend=npu_backend, dynamic=True) out = model(cpu_x.npu())
- weight输入为ND格式
父主题: torch_npu