文档
注册

单算子模式的量化感知训练

功能介绍

参考量化感知训练章节进行基础量化时,量化的内部处理逻辑需要将原始模型转换成ONNX模型,并在ONNX模型基础上进行图的修改操作,此时若模型中存在Pytorch自定义算子时,可能存在无法导出生成ONNX模型,从而导致量化失败的问题。

单算子模式的量化感知训练功能,提供由Pytorch原生算子转换生成的自定义QAT算子,基于该算子进行量化因子的重训练,量化因子作为算子参数保存在QAT单算子中,无需导出ONNX模型,可以避免上述量化感知训练方案中的算子导出异常问题。训练完成后,通过torch.onnx.export机制,建立QAT算子与ONNX原生算子的映射关系,将Pytorch模型中QAT算子计算获得的参数传递给ONNX原生量化算子,完成模型导出。简易示意图如下所示。

图1 单算子模式量化感知训练

该功能支持Training from scratch和Fine-tune两种使用方法:

  • Training from scratch:使用QAT算子直接构图,从零开始训练。您可以在模型构建脚本中调用单算子模式提供的直接构造接口构造QAT算子,使用该算子进行模型构建。
  • Fine-tune:在已有网络基础上,对待量化算子进行替换,相比于Training from scratch更为常用。如果您已经完成模型构建,您可以调用单算子模式提供的基于原生算子构造接口,进行QAT算子构造;之后可以参考下文样例中的算子替换方案,对网络模型中的待量化算子进行替换。

QAT 算子规格如下表所示,调用示例请参见获取更多样例

表1 QAT算子规格

待量化算子类型

替换后算子类型

限制

备注

torch.nn.Conv2d

Conv2dQAT

padding_mode为zeros

复用层(共用weight和bias参数)不支持量化。

torch.nn.ConvTranspose2d

ConvTranspose2dQAT

padding_mode为zeros

torch.nn.Conv3d

Conv3dQAT

padding_mode为zeros,

dilation_d为1

torch.nn.Linear

LinearQAT

不支持channel wise

样例参考

Training from scratch方式代码样例:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import torch.nn as nn

import amct_pytorch as amct
from amct_onnx.convert_model import convert_qat_model
from amct_pytorch.nn.module.quantization.conv2d import Conv2dQAT
from amct_pytorch.nn.module.quantization.linear import LinearQAT


# QAT算子量化配置项
config = {
    "retrain_data_config": {
        "dst_type": "INT8",
        "batch_num": 10,
        "fixed_min": False,
        "clip_min": None,
        "clip_max": None
    },
    "retrain_weight_config": {
        "dst_type": "INT8",
        "weights_retrain_algo": "arq_retrain",
        "channel_wise": False
    }
}

# 使用QAT单算子构造Lenet网络
net = torch.nn.Sequential(
    Conv2dQAT(1, 6, kernel_size=5, padding=2, config=config), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    Conv2dQAT(6, 16, kernel_size=5, config=config), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    LinearQAT(16 * 5 * 5, 120, config=config), nn.Sigmoid(),
    LinearQAT(120, 84, config=config), nn.Sigmoid(),
    LinearQAT(84, 10, config=config)
)

# 训练
train(net, train_data, test_data)
# 导出中间模型
torch.onnx.export(model, data, 'inter_model.onnx')
# 导出fake quant模型与deploy模型
# 生成lenet_fake_quant_model.onnx,可在 ONNX 执行框架 ONNX Runtime 进行精度仿真的模型。
# 生成lelet_deploy_model.onnx,可在AI 处理器部署的模型文件。
convert_qat_model('inter_model.onnx', './outputs/lenet')
#使用fake quant模型进行精度仿真
validata_onnx('./outputs/lenet_fake_quant_model.onnx', val_data)

Fine-tune方式代码样例:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
from torchvision.models.resnet import resnet101

import amct_pytorch as amct
from amct_onnx.convert_model import convert_qat_model
from amct_pytorch.nn.module.quantization.conv2d import Conv2dQAT

model = resnet101(pretrained=True)
# QAT算子量化配置项
config = {
    "retrain_data_config": {
        "dst_type": "INT8",
        "batch_num": 10,
        "fixed_min": False,
        "clip_min": None,
        "clip_max": None
    },
    "retrain_weight_config": {
        "dst_type": "INT8",
        "weights_retrain_algo": "arq_retrain",
        "channel_wise": True
    }
}

def _set_module(model, submodule_key, module):
    # 将模型中原生算子替换为qat算子
    tokens = submodule_key.split('.')
    sub_tokens = tokens[:-1]
    cur_mod = model
    for s in sub_tokens:
        cur_mod = getattr(cur_mod, s)
    setattr(cur_mod, tokens[-1], module)

for name, module in model.named_modules():
    # 遍历原图中各节点,将类型为Conv2d的torch原生算子替换为自定义QAT单算子
    if isinstance(module, torch.nn.Conv2d):
        qat_module = Conv2dQAT.from_float(
            module, config=config)
        _set_module(model, name, qat_module)

# 训练流程
train_and_val(model, train_data, test_data)
# 导出中间模型
torch.onnx.export(model, data, 'inter_model.onnx')
# 导出fake quant模型与deploy模型
# 生成resnet101_fake_quant_model.onnx,可在 ONNX 执行框架 ONNX Runtime 进行精度仿真的模型。
# 生成resnet101_deploy_model.onnx,可在AI 处理器部署的模型文件。
convert_qat_model('inter_model.onnx', './outputs/resnet101')
#使用fake quant模型进行精度仿真
validata_onnx('./outputs/resnet101_fake_quant_model.onnx', val_data)
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词