文档
注册
评分
提单
论坛
小AI

torch_npu.npu_weight_quant_batchmatmul

功能描述

该接口用于实现矩阵乘计算中的weight输入和输出的量化操作,支持pertensor、perchannel、pergroup多场景量化(Atlas 推理系列加速卡产品当前仅支持perchannel)。

接口原型

npu_weight_quant_batchmatmul(Tensor x, Tensor weight, Tensor antiquant_scale, Tensor? antiquant_offset=None, Tensor? quant_scale=None, Tensor? quant_offset=None, Tensor? bias=None, int antiquant_group_size=0) -> Tensor

参数说明

  • x : Device侧Tensor类型,即矩阵乘中的x。数据格式支持ND,数据类型支持FLOAT16、BFLOAT16, 支持带transpose的非连续的Tensor,支持输入维度为两维(M,K) ;Atlas 推理系列加速卡产品上数据类型仅支持FLOAT16,支持输入维度为2-6维,支持batch轴但不支持broadcast。
  • weight:Device侧Tensor类型,即矩阵乘中的weight。数据格式支持ND,FRACTAL_NZ(仅Atlas 推理系列加速卡产品支持),数据类型支持INT8、INT32(通过INT32承载INT4的输入,可以参考torch_npu.npu_convert_weight_to_int4pack接口示例), 支持带transpose的非连续的Tensor,支持输入维度为两维(K,N);Atlas 推理系列加速卡产品上数据类型仅支持FLOAT16,支持输入维度为2-6维,支持batch轴但不支持broadcast,维度需与x保持一致。perchannel场景为提高性能,推荐使用transpose后的weight输入。
  • antiquant_scale:Device侧Tensor类型,反量化的scale,用于weight矩阵反量化 。数据格式支持ND,数据类型支持FLOAT16、BFLOAT16,支持带transpose的非连续的Tensor,支持输入维度为两维(1, N)或 一维(N, )、(1, ),pergroup场景为(ceil_div(K,antiquant_group_size), N);Atlas 推理系列加速卡产品上数据类型仅支持FLOAT16。
  • antiquant_offset:Device侧Tensor类型,反量化的offset,用于weight矩阵反量化 。数据格式支持ND,数据类型支持FLOAT16、BFLOAT16,支持带transpose的非连续的Tensor,支持输入维度为两维(1, N)或 一维(N, )、(1, ),pergroup场景为(ceil_div(K,antiquant_group_size), N);Atlas 推理系列加速卡产品上数据类型仅支持FLOAT16。
  • quant_scale:Device侧Tensor类型,量化的scale,用于输出矩阵的量化。数据格式支持ND,数据类型支持FLOAT32、INT64,支持输入维度为两维(1, N) 或 一维(N, )、(1, );Atlas 推理系列加速卡产品暂不支持此参数。
  • quant_offset: Device侧Tensor类型,量化的offset,用于输出矩阵的量化。数据格式支持ND,数据类型支持FLOAT32,支持输入维度为两维(1, N) 或 一维(N, )、(1, );Atlas 推理系列加速卡产品暂不支持此参数。
  • bias:Device侧Tensor类型, 即矩阵乘中的bias,数据格式支持ND,数据类型支持FLOAT16、FLOAT32, 不支持非连续的Tensor,支持输入维度为两维(1, N) 或 一维(N, )、(1, )。
  • antiquant_group_size:int类型, 用于控制pergroup场景下的group大小,当前默认为0。

输出说明

输出为Tensor类型,代表计算结果。当输入存在quant_scale时输出数据类型为INT8,当输入不存在quant_scale时输出数据类型和输入x一致。

约束说明

  • x和weight后两维必须为(M,K)和(K,N)格式,M、K、N的范围为[1, 65535];在x为非转置时,M不受65535的限制。
  • 不支持空Tensor输入。
  • antiquant_scale和antiquant_offset的输入shape要保持一致。
  • quant_scale和quant_offset的输入shape要保持一致,且quant_offset不能独立于quant_scale存在。
  • 当x输入类型为BFLOAT16类型时候,bias的输入类型为FLOAT32;当x输入类型为FLOAT16类型时候,bias的输入类型为FLOAT16。
  • 如需传入INT64数据类型的quant_scale, 需要提前调用torch_npu.npu_trans_quant_param接口将数据类型为FLOAT32的quant_scale和quant_offset转换为数据类型为INT64的quant_scale输入,可参考调用示例
  • weight FRACTAL_NZ格式传入需要依赖接口torch_npu.npu_format_cast,可参考调用示例

支持的PyTorch版本

  • PyTorch 2.3
  • PyTorch 2.2
  • PyTorch 2.1
  • PyTorch 1.11.0

支持的芯片型号

  • Atlas A2 训练系列产品
  • Atlas 推理系列加速卡产品

调用示例

  • 单算子调用
    # 单算子模式:
    import torch
    import torch_npu
    
    cpu_x = torch.randn((8192, 320),dtype=torch.float16)
    cpu_weight = torch.randint(low=-8, high=8, size=(320, 256),dtype=torch.int8)
    cpu_antiquantscale = torch.randn((1, 256),dtype=torch.float16)
    cpu_antiquantoffset = torch.randn((1, 256),dtype=torch.float16)
    cpu_quantscale = torch.randn((1, 256),dtype=torch.float32)
    cpu_quantoffset = torch.randn((1, 256),dtype=torch.float32)
    quantscale= torch_npu.npu_trans_quant_param(cpu_quantscale.npu(), cpu_quantoffset.npu())
    npu_out = torch_npu.npu_weight_quant_batchmatmul(cpu_x.npu(), cpu_weight.npu(), cpu_antiquantscale.npu(), cpu_antiquantoffset.npu(),quantscale.npu())
  • 图模式调用(图模式目前仅支持PyTorch 2.1版本)
    # 图模式
    import torch
    import torch_npu
    import  torchair as tng
    from torchair.configs.compiler_config import CompilerConfig
    config = CompilerConfig()
    config.debug.graph_dump.type = "pbtxt"
    npu_backend = tng.get_npu_backend(compiler_config=config)
    
    cpu_x = torch.randn((8192, 320),device='npu',dtype=torch.bfloat16)
    cpu_weight = torch.randn((320, 256),device='npu',dtype=torch.int8)
    cpu_antiquantscale = torch.randn((1, 256),device='npu',dtype=torch.bfloat16)
    cpu_antiquantoffset = torch.randn((1, 256),device='npu',dtype=torch.bfloat16)
    
    class MyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
    
        def forward(self, x, weight, antiquant_scale, antiquant_offset, quant_scale,quant_offset, bias, antiquant_group_size):
            return torch_npu.npu_weight_quant_batchmatmul(x, weight, antiquant_scale, antiquant_offset, quant_scale ,quant_offset, bias, antiquant_group_size)
    
    cpu_model = MyModel()
    model = cpu_model.npu()
    model = torch.compile(cpu_model, backend=npu_backend, dynamic=True)
    npu_out = model(cpu_x.npu(), cpu_weight.npu(), cpu_antiquantscale.npu(), cpu_antiquantoffset.npu(), None, None, None, 0)
    
    # Atlas 推理系列加速卡产品图模式,weight输入为FRACTAL_NZ格式:
    import torch_npu
    import torch
    from torchair.configs.compiler_config import CompilerConfig
    import torchair as tng
    config = CompilerConfig()
    config.debug.graph_dump.type = "pbtxt"
    npu_backend = tng.get_npu_backend(compiler_config=config)
    class NPUQuantizedLinearA16W8(torch.nn.Module):
        def __init__(self,
                     weight,
                     antiquant_scale,
                     antiquant_offset,
                     quant_offset=None,
                     quant_scale=None,
                     bias=None,
                     transpose_x=False,
                     transpose_weight=True,
                     w4=False):
            super().__init__()
    
            self.dtype = torch.float16
            self.weight = weight.to(torch.int8).npu()
            self.transpose_weight = transpose_weight
    
            if self.transpose_weight:
                self.weight = torch_npu.npu_format_cast(self.weight.contiguous(), 29)
            else:
                self.weight = torch_npu.npu_format_cast(self.weight.transpose(0, 1).contiguous(), 29) # n,k ->nz
    
            self.bias = None
            self.antiquant_scale = antiquant_scale
            self.antiquant_offset = antiquant_offset
            self.quant_offset = quant_offset
            self.quant_scale = quant_scale
            self.transpose_x = transpose_x
    
        def forward(self, x):
            x = torch_npu.npu_weight_quant_batchmatmul(x.transpose(0, 1) if self.transpose_x else x,
                                                       self.weight.transpose(0, 1),
                                                       self.antiquant_scale.transpose(0, 1),
                                                       self.antiquant_offset.transpose(0, 1),
                                                       self.quant_scale,
                                                       self.quant_offset,
                                                       self.bias)
            return x
    
    
    m, k, n = 4, 1024, 4096
    cpu_x = torch.randn((m, k),dtype=torch.float16)
    cpu_weight = torch.randint(1, 10, (k, n),dtype=torch.int8)
    cpu_weight = cpu_weight.transpose(-1, -2)
    
    cpu_antiquantscale = torch.randn((1, n),dtype=torch.float16)
    cpu_antiquantoffset = torch.randn((1, n),dtype=torch.float16)
    cpu_antiquantscale = cpu_antiquantscale.transpose(-1, -2)
    cpu_antiquantoffset = cpu_antiquantoffset.transpose(-1, -2)
    model = NPUQuantizedLinearA16W8(cpu_weight.npu(), cpu_antiquantscale.npu(), cpu_antiquantoffset.npu())
    model = torch.compile(model, backend=npu_backend, dynamic=True)
    out = model(cpu_x.npu())
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词