量化感知训练
QAT(Quantization-Aware Training)即量化感知训练,量化感知训练会重新训练量化模型,从而减小模型大小,并且加快推理过程。当前支持对PyTorch框架的CNN类模型进行量化,并将量化后的模型保存为.onnx文件,量化过程中,需要用户自行提供模型与数据集,调用API接口完成模型的量化调优。
目前支持对包括但不限于表1 已验证模型列表中的模型进行模型量化感知训练。
类型 |
名称 |
---|---|
图像分类 |
Densenet121 |
InceptionV3 |
|
MobileNetV2 |
|
Resnet18 |
|
Resnet34 |
|
Resnet50 |
|
ResneXt50 |
|
Vgg16 |
|
Vgg19 |
|
Transformer图像分类 |
Swin-tiny |
Swin-base |
|
Vit-base |
|
NLP模型 |
Albert |
Bert-base |
|
Bert-large |
|
目标检测 |
FCOS |
YoloV3 |
|
YoloV5 |
|
YoloV7 |
|
SSD-VGG |
前提条件
操作步骤
- 用户需自行准备模型、训练脚本和数据集,本样例以PyTorch框架的Resnet50和数据集ImageNet为例。
- 编辑训练脚本pytorch_resnet50_apex.py文件,导入如下接口。
from modelslim.pytorch.quant.qat_tools import qsin_qat, QatConfig, get_logger
- 在优化器初始化之前调用“qsin_qat”函数,将量化后模型替换为“qsin_qat”的输出模型。请参考QatConfig和qsin_qat进行配置。同时在训练代码中,需注意保存伪量化模型权重ckpt文件,在导出量化onnx使用。
quant_config = QatConfig(grad_scale=0.001) quant_logger = get_logger() model = qsin_qat(model, quant_config, quant_logger).to(device_calc) #根据实际情况配置待量化模型实例、量化配置和量化输出日志,注意需把模型按照原训练流程部署在NPU设备
- 调用原训练流程进行单卡训练,执行train_full_1p.sh启动单卡训练任务。
bash ./test/train_full_1p.sh --data_path=/datasets/imagenet #请根据实际情况配置数据集路径
- 导出量化后的ONNX模型。在伪量化模型权重ckpt文件保存后,新建quant_deploy.py文件,添加如下代码,调用“save_qsin_qat_model”函数,请参考save_qsin_qat_model进行配置。
import argparse import os import torch import models.image_classification.resnet as nvmodels # 初始化模型 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') parser.add_argument('-b', '--batch-size', default=1, type=int, metavar='N', help='onnx bs') parser.add_argument('--pretrained', default="./org_model_best.pth.tar", type=str, help='use pre-trained model') parser.add_argument('--quant_ckpt', default="./checkpoint_77.244_asym.pth.tar", type=str, help='use pre-trained model') args = parser.parse_args() model = nvmodels.build_resnet("resnet50", "classic", is_training=False) pretrained_dict = torch.load(args.pretrained, map_location='cpu')["state_dict"] model.load_state_dict(pretrained_dict, strict=False) #保存量化后的onnx模型 from modelslim.pytorch.quant.qat_tools import save_qsin_qat_model #根据实际情况配置导出后模型文件名(文件后缀需为.onnx)、输入的shape、伪量化模型权重和onnx的输入名称 save_onnx_name='./resnet50.onnx' dummy_input = torch.ones([args.batch_size, 3, 224, 224]).type(torch.float32) saved_ckpt = args.quant_ckpt input_names=['input1'] save_qsin_qat_model(model, save_onnx_name, dummy_input, saved_ckpt, input_names)
- 执行量化脚本,获取量化后的onnx模型。
python3 quant_deploy.py
父主题: 模型量化