LinearWeightQuant是对torch_npu.npu_weight_quant_batchmatmul接口的封装类，完成矩阵乘计算中的weight输入和输出的量化操作，支持pertensor、perchannel、pergroup多场景量化。
当前Atlas 推理系列加速卡产品仅支持perchannel量化。
torch_npu.contrib.module.LinearWeightQuant(in_features, out_features, bias=True, device=None, dtype=None, antiquant_offset=False, quant_scale=False, quant_offset=False, antiquant_group_size=0, inner_precise=0)
输出为Device侧Tensor类型，代表计算结果。当输入存在quant_scale时输出数据类型为INT8；当输入不存在quant_scale时，输出数据类型和输入x一致。
import torch
import torch_npu
from torch_npu.contrib.module import LinearWeightQuant

# Eager-mode example for LinearWeightQuant: fp16 activation x int8 weight,
# per-channel anti-quantization plus output quantization (quant_scale/offset).
x = torch.randn((8192, 320), device='npu', dtype=torch.float16)
# FIX: torch.randn only supports floating-point dtypes; the original
# torch.randn(..., dtype=torch.int8) raises a RuntimeError. Draw random
# int8 weight data with torch.randint instead.
weight = torch.randint(-128, 128, (320, 256), device='npu', dtype=torch.int8)
antiquantscale = torch.randn((1, 256), device='npu', dtype=torch.float16)
antiquantoffset = torch.randn((1, 256), device='npu', dtype=torch.float16)
quantscale = torch.randn((1, 256), device='npu', dtype=torch.float)
quantoffset = torch.randn((1, 256), device='npu', dtype=torch.float)

model = LinearWeightQuant(
    in_features=320,
    out_features=256,
    bias=False,
    dtype=torch.float16,
    antiquant_offset=True,
    quant_scale=True,
    quant_offset=True,
    antiquant_group_size=0,
    device=torch.device('npu:0'),  # plain string: original f-string had no placeholder
)
model.npu()
# The module stores its parameters transposed relative to the (k, n)-shaped
# tensors created above, hence the transpose(-1, -2) on assignment.
model.weight.data = weight.transpose(-1, -2)
model.antiquant_scale.data = antiquantscale.transpose(-1, -2)
model.antiquant_offset.data = antiquantoffset.transpose(-1, -2)
# In eager mode, quant_scale must be preprocessed by npu_trans_quant_param
# (it fuses quantscale and quantoffset into the layout the kernel expects).
model.quant_scale.data = torch_npu.npu_trans_quant_param(quantscale, quantoffset)
model.quant_offset.data = quantoffset
out = model(x)  # INT8 output, since quant_scale is provided
import torch
import torch_npu
import torchair as tng
from torch_npu.contrib.module import LinearWeightQuant
from torchair.configs.compiler_config import CompilerConfig

# Graph-mode (torchair) example for LinearWeightQuant: bf16 activation x
# int8 weight, compiled with the NPU backend.
config = CompilerConfig()
config.debug.graph_dump.type = "pbtxt"  # dump the compiled graph for inspection
npu_backend = tng.get_npu_backend(compiler_config=config)

x = torch.randn((8192, 320), device='npu', dtype=torch.bfloat16)
# FIX: torch.randn only supports floating-point dtypes; the original
# torch.randn(..., dtype=torch.int8) raises a RuntimeError. Draw random
# int8 weight data with torch.randint instead.
weight = torch.randint(-128, 128, (320, 256), device='npu', dtype=torch.int8)
antiquantscale = torch.randn((1, 256), device='npu', dtype=torch.bfloat16)
antiquantoffset = torch.randn((1, 256), device='npu', dtype=torch.bfloat16)
quantscale = torch.randn((1, 256), device='npu', dtype=torch.float)
quantoffset = torch.randn((1, 256), device='npu', dtype=torch.float)

model = LinearWeightQuant(
    in_features=320,
    out_features=256,
    bias=False,
    dtype=torch.bfloat16,
    antiquant_offset=True,
    quant_scale=True,
    quant_offset=True,
    antiquant_group_size=0,
    device=torch.device('npu:0'),  # plain string: original f-string had no placeholder
)
model.npu()
# The module stores its parameters transposed relative to the (k, n)-shaped
# tensors created above, hence the transpose(-1, -2) on assignment.
model.weight.data = weight.transpose(-1, -2)
model.antiquant_scale.data = antiquantscale.transpose(-1, -2)
model.antiquant_offset.data = antiquantoffset.transpose(-1, -2)
# Unlike eager mode, graph mode takes quant_scale directly (no
# npu_trans_quant_param preprocessing) — see the eager example for contrast.
model.quant_scale.data = quantscale
model.quant_offset.data = quantoffset

# Convert the ND weight input to the FRACTAL_NZ internal format.
tng.experimental.inference.use_internal_format_weight(model)
model = torch.compile(model, backend=npu_backend, dynamic=False)
out = model(x)  # INT8 output, since quant_scale is provided