下载
中文
注册

torch_npu.contrib.module.LinearWeightQuant

功能描述

LinearWeightQuant是对torch_npu接口torch_npu.npu_weight_quant_batchmatmul的封装类,完成矩阵乘计算中的weight输入和输出的量化操作,支持pertensor、perchannel、pergroup多场景量化。

接口原型

torch_npu.contrib.module.LinearWeightQuant(in_features, out_features, bias=True, device=None, dtype=None, antiquant_offset=False, quant_scale=False, quant_offset=False, antiquant_group_size=0)

参数说明

  • in_features:int类型,伪量化matmul计算中的k轴的值。
  • out_features:int类型,伪量化matmul计算中的n轴的值。
  • bias:bool类型,可选参数,默认为True,代表是否需要bias计算参数。如果设置成False,则bias不会加入伪量化matmul的计算。
  • device:string类型,可选参数,用于执行model的device名称,默认为None。
  • dtype:torch支持的dtype类型,可选参数,默认为None,伪量化matmul运算中的输入x的dtype。
  • antiquant_offset:bool类型,可选参数,默认为False,代表是否需要antiquant_offset计算参数。如果设置成False,则weight矩阵反量化时无需设置offset。
  • quant_scale:bool类型,可选参数,默认为False,代表是否需要quant_scale计算参数。如果设置成False,则伪量化输出不会进行量化计算。
  • quant_offset:bool类型,可选参数,默认为False,代表是否需要quant_offset计算参数。如果设置成False,则对伪量化输出进行量化计算时无需设置offset。
  • antiquant_group_size:int类型,可选参数,用于控制pergroup场景下的group大小,当前默认为0。

输入说明

x:Device侧Tensor类型,即矩阵乘中的x。数据格式支持ND,数据类型支持FLOAT16、BFLOAT16,支持输入维度为两维(M,K) 。

变量说明

  • weight:Device侧Tensor类型,即矩阵乘中的weight。数据格式支持ND、FRACTAL_NZ,数据类型支持INT8、INT32(通过INT32承载INT4的输入,可以参考torch_npu.npu_convert_weight_to_int4pack中的调用示例),支持非连续的Tensor,支持输入维度为两维(N, K)。
  • antiquant_scale:Device侧Tensor类型,反量化的scale,用于weight矩阵反量化。数据格式支持ND,数据类型支持FLOAT16、BFLOAT16、INT64。若数据类型为FLOAT16、BFLOAT16, 则需要和x的数据类型一致;若数据类型为INT64, 则x的数据类型必须为FLOAT16,且不带transpose输入,同时weight的数据类型必须为INT8,且带transpose输入,可以参考调用示例。支持非连续的Tensor,支持输入维度为两维(N, 1)或一维(N, )、(1, ),pergroup场景为(N,ceil_div(K,antiquant_group_size),数据类型为INT64时只支持perchannel场景,M范围为[1,96],且K和N要求64对齐。
  • antiquant_offset:Device侧Tensor类型,反量化的offset,用于weight矩阵反量化。数据格式支持ND,数据类型支持FLOAT16、BFLOAT16、INT32。若数据类型为FLOAT16、BFLOAT16,则必须和antiquant_scale的数据类型一致;若数据类型为INT32,则antiquant_scale的数据类型必须为INT64,可以参考调用示例。支持非连续的Tensor,支持输入维度为两维(N, 1)或一维(N, )、(1, ),pergroup场景为(N,ceil_div(K,antiquant_group_size))。
  • quant_scale:Device侧Tensor类型,量化的scale,用于输出矩阵的量化。数据格式支持ND,数据类型支持FLOAT32、INT64,支持输入维度为两维(1, N) 或一维(N, )、(1, )。当antiquant_scale的数据类型为INT64时,此参数必须为空。
  • quant_offset:Device侧Tensor类型,量化的offset,用于输出矩阵的量化。数据格式支持ND,数据类型支持FLOAT32,支持输入维度为两维(1, N) 或一维(N, )、(1, )。当antiquant_scale的数据类型为INT64时,此参数必须为空。
  • bias:Device侧Tensor类型, 即矩阵乘中的bias,数据格式支持ND,数据类型支持FLOAT16、FLOAT32,支持非连续的Tensor,支持输入维度为两维(1, N) 或一维(N, )、(1, )。
  • antiquant_group_size:int类型,用于控制pergroup场景下的group大小,默认为0。

输出说明

输出为Device侧Tensor类型,代表计算结果。当输入存在quant_scale时输出数据类型为INT8,当输入不存在quant_scale时输出数据类型和输入x一致。

约束说明

  • 该融合算子仅在推理场景使用。
  • LinearWeightQuant传入参数支持范围和限制与torch_npu.npu_weight_quant_batchmatmul接口保持一致。
  • Atlas 推理系列产品中weight FRACTAL_NZ格式传入只在图模式有效,需要依赖接口tng.experimental.inference.use_internal_format_weight完成formatND到FRACTAL_NZ的转换,可参考4-2调用示例

支持的型号

  • Atlas A2训练系列产品/Atlas 800I A2推理产品
  • Atlas 推理系列产品

调用示例

  • 单算子调用
    import torch
    import torch_npu
    from torch_npu.contrib.module import LinearWeightQuant
    x = torch.randn((8192, 320),device='npu',dtype=torch.float16)
    weight = torch.randn((320, 256),device='npu',dtype=torch.int8)
    antiquantscale = torch.randn((1, 256),device='npu',dtype=torch.float16)
    antiquantoffset = torch.randn((1, 256),device='npu',dtype=torch.float16)
    quantscale = torch.randn((1, 256),device='npu',dtype=torch.float)
    quantoffset = torch.randn((1, 256),device='npu',dtype=torch.float)
    model = LinearWeightQuant(in_features=320,
                              out_features=256,
                              bias=False,
                              dtype=torch.float16,
                              antiquant_offset=True,
                              quant_scale=True,
                              quant_offset=True,
                              antiquant_group_size=0,
                              device=torch.device(f'npu:0')
                              )
    model.npu()
    model.weight.data = weight.transpose(-1, -2)
    model.antiquant_scale.data = antiquantscale.transpose(-1, -2)
    model.antiquant_offset.data = antiquantoffset.transpose(-1, -2)
    model.quant_scale.data = torch_npu.npu_trans_quant_param(quantscale, quantoffset)
    model.quant_offset.data = quantoffset
    out = model(x)
  • 图模式调用(图模式目前仅支持PyTorch 2.1版本)
    import torch
    import torch_npu
    import torchair as tng
    from torch_npu.contrib.module import LinearWeightQuant
    from torchair.configs.compiler_config import CompilerConfig
    config = CompilerConfig()
    config.debug.graph_dump.type = "pbtxt"
    npu_backend = tng.get_npu_backend(compiler_config=config)
    x = torch.randn((8192, 320),device='npu',dtype=torch.bfloat16)
    weight = torch.randn((320, 256),device='npu',dtype=torch.int8)
    antiquantscale = torch.randn((1, 256),device='npu',dtype=torch.bfloat16)
    antiquantoffset = torch.randn((1, 256),device='npu',dtype=torch.bfloat16)
    quantscale = torch.randn((1, 256),device='npu',dtype=torch.float)
    quantoffset = torch.randn((1, 256),device='npu',dtype=torch.float)
    model = LinearWeightQuant(in_features=320,
                              out_features=256,
                              bias=False,
                              dtype=torch.bfloat16,
                              antiquant_offset=True,
                              quant_scale=True,
                              quant_offset=True,
                              antiquant_group_size=0,
                              device=torch.device(f'npu:0')
                              )
    model.npu()
    model.weight.data = weight.transpose(-1, -2)
    model.antiquant_scale.data = antiquantscale.transpose(-1, -2)
    model.antiquant_offset.data = antiquantoffset.transpose(-1, -2)
    model.quant_scale.data = quantscale
    model.quant_offset.data = quantoffset
    tng.experimental.inference.use_internal_format_weight(model) # 将ND的weight输入转为FRACTAL_NZ格式
    model = torch.compile(model, backend=npu_backend, dynamic=False)
    out = model(x)