torch_npu.contrib.module.QuantConv2d

Function Description

QuantConv2d is a wrapper class for the torch_npu.npu_quant_conv2d interface. It provides quantization support for the Conv2d operator and implements the following computation:
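As a rough illustration of what a quantized convolution computes, the sketch below assumes the common pattern of integer multiply-accumulate, an integer bias add, and a final multiplication by a scale factor. It is not the operator's documented formula: the function name `quant_conv2d_reference`, the single-channel layout, and the plain float `scale` are all hypothetical simplifications, and the sketch says nothing about how the NPU operator encodes its scale.

```python
def quant_conv2d_reference(fmap, weight, bias, scale):
    """Single-channel, stride-1, no-padding reference in pure Python.

    fmap:   H x W nested list of int8 values
    weight: kH x kW nested list of int8 values
    bias:   int added to the integer accumulator
    scale:  float factor applied last (assumed representation)
    """
    h, w = len(fmap), len(fmap[0])
    kh, kw = len(weight), len(weight[0])
    out = []
    for i in range(h - kh + 1):
        row = []
        for j in range(w - kw + 1):
            acc = 0  # integer accumulator
            for di in range(kh):
                for dj in range(kw):
                    acc += fmap[i + di][j + dj] * weight[di][dj]
            row.append((acc + bias) * scale)
        out.append(row)
    return out


fmap = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
weight = [[1, 2], [0, -1]]
result = quant_conv2d_reference(fmap, weight, bias=2, scale=0.5)
print(result)  # [[1.0, 2.0], [4.0, 5.0]]
```

Keeping the accumulation in integers and applying the scale only at the end is what lets the heavy arithmetic run on INT8 data.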

Interface Prototype

torch_npu.contrib.module.QuantConv2d(in_channels, out_channels, kernel_size, output_dtype, stride=1, padding=0, dilation=1, groups=1, bias=True, offset=False, offset_x=0, round_mode="rint", device=None, dtype=None)

Parameter Description

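Input Description

quant_conv2d_input (computation input): a Tensor on the Device side. The supported data type is INT8, the supported data format is NCHW, and the shape must be 4-dimensional.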
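Because the input must already be INT8, floating-point activations have to be quantized before they reach the module. A minimal sketch, assuming symmetric per-tensor quantization with a caller-chosen scale; the helper name `quantize_to_int8` is hypothetical and is not part of torch_npu:

```python
def quantize_to_int8(values, scale):
    """Divide by scale, round half to even, then clamp to the int8 range."""
    return [max(-128, min(127, round(v / scale))) for v in values]


# With scale=0.5: 0.75 -> 1.5 and 1.25 -> 2.5 both round to the even value 2,
# while +/-100.0 saturate at the int8 limits.
quantized = quantize_to_int8([1.0, 0.75, 1.25, 100.0, -100.0], scale=0.5)
print(quantized)  # [2, 2, 2, 127, -128]
```

The clamp matters as much as the rounding: values outside the representable range saturate rather than wrap around.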

Variable Description

Output Description

A single Tensor output that holds the result of the QuantConv2d computation:

Constraints

Supported Models

Example

Graph mode invocation:

import torch
import torch_npu
import torchair as tng

from torchair.configs.compiler_config import CompilerConfig
from torch_npu.contrib.module import QuantConv2d


# Select NPU device 0 and build the torchair backend used by torch.compile.
torch_npu.npu.set_device(0)
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)

# Random test tensors. The upper bound of torch.randint is exclusive,
# so every element is either -1 or 0.
fmap = torch.randint(-1, 1, (1, 1, 64, 64), dtype=torch.int8).npu()
weight = torch.randint(-1, 1, (1, 1, 3, 3), dtype=torch.int8).npu()
scale = torch.randint(-1, 1, (1,), dtype=torch.int64).npu()
bias = torch.randint(-1, 1, (1,), dtype=torch.int32).npu()

# Build the module and overwrite its parameters with the tensors above.
model = QuantConv2d(in_channels=1, out_channels=1, kernel_size=(3, 3), output_dtype=torch.float16).npu()
model.weight.data = weight
model.scale.data = scale
model.bias.data = bias

# Compile as a static graph (dynamic=False) and run inference without gradients.
with torch.no_grad():
    static_graph_model = torch.compile(model, backend=npu_backend, dynamic=False)
    output = static_graph_model(fmap)
print("static graph result: ", output)
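
The shape of `output` in the example above can be estimated without an NPU. The sketch below assumes QuantConv2d follows the same output-size arithmetic as a standard Conv2d layer, which this page does not state explicitly; `conv2d_output_size` is a hypothetical helper, not a torch_npu function.

```python
def conv2d_output_size(size, kernel, stride=1, padding=0, dilation=1):
    """Standard Conv2d output length along one spatial dimension."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# Example above: 64 x 64 input, 3 x 3 kernel, default stride/padding/dilation.
out_h = conv2d_output_size(64, 3)
out_w = conv2d_output_size(64, 3)
output_shape = (1, 1, out_h, out_w)  # (N, out_channels, H_out, W_out)
print(output_shape)  # (1, 1, 62, 62)
```

Under that assumption, passing padding=1 to the constructor would keep the spatial size at 64 x 64.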