下载
中文
注册

训练后量化(PyTorch)

概述

训练后量化工具需要用户提供PyTorch训练脚本或者pth文件,工具可自动对模型中的卷积和线性层(torch.nn.Linear和torch.nn.Conv2d)进行识别并量化,最终导出量化后的onnx模型,量化后的模型可以在推理服务器上运行,达到提升推理性能的目的。量化过程中用户需自行提供模型与数据集,调用API接口完成模型的量化调优。

自动混合精度量化算法

为了提升量化精度,训练后量化(PyTorch)算法内置了自动混合精度的模块,自动识别并回退量化敏感层为浮点计算,避免量化敏感层对精度造成较大损失。算法核心是:计算每个量化层量化前后输出的MSE,根据MSE的排序来衡量每一个量化层的量化敏感性,自动回退MSE最大的部分敏感层,从而提升量化的精度。

精度保持策略

为了进一步降低量化精度损失,训练后量化(PyTorch)工具内集成了多种精度保持策略,对权重的量化参数和取证方式进行优化。

  • Easy Quant权重优化方法:利用输出相似性优化量化参数,减少输入输出张量的量化误差,推荐在Data-Free模式下使用,通常能够起到较好的改善效果。
  • ADMM权重优化方法:使用交替优化的方法,对权重的量化参数进行迭代更新优化,推荐在Label-Free模式下使用,适当改善量化效果。
  • Rounding取整优化:在量化中普通取整不是最优解,使用自适应取整的方式优化权重的取整能提高量化精度,推荐在Label-Free模式下使用,适当改善量化效果。

前提条件

  • 已参考环境准备,完成CANN开发环境的部署及Python环境变量配置。
  • 训练后量化前须执行命令安装依赖。

    如下命令如果使用非root用户安装,需要在安装命令后加上--user,例如:pip3 install onnx --user

    pip3 install numpy              #需大于等于1.21.0版本
    pip3 install onnx               #需大于等于1.11.0版本
    pip3 install torch==1.11.0      #支持1.8.1和1.11.0,须为CPU版本的torch
    pip3 install onnx-simplifier    #需大于等于0.3.10版本

操作步骤

  1. 本样例以ResNet50为例。新建量化脚本样例resnet_quant.py,并进入{CANN包安装路径}/ascend-toolkit/latest/tools/modelslim/pytorch/quant/ptq_tools/目录,获取readme.md中的调用示例,拷贝至量化脚本样例resnet_quant.py中。

    量化脚本中需用户关注以下步骤,可以根据实际情况配置。

  2. 用户需自行准备在imagenet数据集上训练后保存的PyTorch训练脚本或pth模型,本样例使用ResNet50可以通过以下方式导入,用户也可以自定义导入。
    import torchvision
    model = torchvision.models.resnet50(pretrained=True)
    model.eval()
  3. 导入训练后量化接口。
    from modelslim.pytorch.quant.ptq_tools import QuantConfig, Calibrator
  4. (可选)调整日志输出等级,启动调优任务后,将打屏显示量化调优的日志信息。
    from modelslim import set_logger_level
    set_logger_level("info")        #根据实际情况配置
  5. 使用QuantConfig接口,配置量化参数,生成量化配置实例,请参考QuantConfig进行配置。
    disable_names = []
    input_shape = [1, 3, 224, 224] 
    keep_acc={'admm': [False, 1000], 'easy_quant': [False, 1000], 'round_opt': False}
    quant_config = QuantConfig(
        disable_names=disable_names,  # 手动回退的量化层名称,如精度太差,推荐回退量化敏感层
        amp_num=0,                    # 混合精度量化回退层数
        input_shape=input_shape,      # 模型输入的shape,用于Data-Free量化构造虚拟数据
        keep_acc=keep_acc,            # 精度保持策略
        sigma=25,                     # 大于0使用sigma统计方法;传入0值使用min-max统计方法
        )
  6. 调用Calibrator,通过Calibrator类封装量化算法,请参考Calibrator进行配置。
    calibrator = Calibrator(model, quant_config)
  7. 使用run接口,执行量化过程。
    calibrator.run()
  8. 使用export_quant_onnx接口,导出量化后onnx模型,请参考export_quant_onnx进行配置。
    calibrator.export_quant_onnx("resnet50", "./output", ["input.1"])
  9. 启动模型量化调优任务,并在步骤8指定的输出目录获取一个量化完成的模型。
    python3 resnet_quant.py