下载
中文
注册

量化流程

本节详细介绍量化感知训练支持的量化层,接口调用流程和示例。

INT4量化只有Atlas 推理系列产品支持。

量化感知训练根据量化后低比特位宽大小分为INT8量化和INT4量化,具体使用哪种量化方式,由量化感知训练简易配置文件说明中的dst_type参数控制。量化示例请参见样例列表

量化感知训练当前仅支持对FP32数据类型的网络模型进行量化。用户可以根据实际情况选择进行INT8量化还是INT4量化。量化感知训练支持量化的层以及约束如下:

表1 量化感知训练支持的层以及约束

支持的层类型

约束

备注

torch.nn.Linear

-

  • 复用层(共用weight和bias参数)不支持量化。
  • 所列层支持INT8量化和INT4量化。

torch.nn.Conv2d

  • padding_mode为zeros
  • 由于硬件约束,原始模型中输入通道数Cin<=16时不建议进行量化感知训练,否则可能会导致量化后的部署模型推理时精度下降
  • 只支持input data的shape为(N, Cin, Hin, Win)

torch.nn.ConvTranspose2d

  • padding_mode为zeros
  • 由于硬件约束,原始模型中输入通道数Cin<=16时不建议进行量化感知训练,否则可能会导致量化后的部署模型推理时精度下降
  • 只支持input data的shape为(N, Cin, Hin, Win)

接口调用流程

量化感知训练接口调用流程如图1所示。

图1 接口调用流程
蓝色部分为用户实现,灰色部分为用户调用AMCT提供的API实现:
  1. 用户首先构造PyTorch的原始模型,调用create_quant_retrain_config接口生成量化配置文件。
  2. 调用create_quant_retrain_model接口对原始模型进行修改,修改后的模型中插入了数据量化、权重量化等相关算子,用于计算量化相关参数。
  3. 对修改后的模型进行训练,如果训练未中断,将训练后的模型进行推理,进行推理的过程中,会将量化因子写出到record文件中。

    如果训练过程中断,则可基于保存的pth模型参数和量化配置文件,重新调用restore_quant_retrain_model接口,输出修改后的retrain network,继续进行量化感知的训练,然后进行推理。

  4. 调用save_quant_retrain_model接口,插入AscendQuant、AscendDequant等量化算子,保存量化模型。

调用示例

  1. 基于PyTorch环境进行训练,当前仅支持distribution模式(即DistributedDataParallel模式)的多卡训练,不支持DataParallel模式的多卡训练,使用DP模式训练会报错。
  2. 调用AMCT的部分,函数入参可以根据实际情况进行调整。量化感知训练基于用户的训练过程,请确保已经有基于PyTorch环境进行训练的脚本,并且训练后的精度正常。
  3. 使用AMCT的量化感知训练特性时,如果训练过程卡死,请检查当前服务器是否有其他ONNX Runtime程序在运行(可以用top命令查看服务器所有进程),如果有,请先暂定其他ONNX Runtime程序,重新执行量化感知训练。
  4. 参考本章节进行量化,模型中存在Pytorch自定义算子时,可能存在无法导出生成ONNX模型,从而导致量化失败的问题。具体报错信息如下:'Model cannot be quantized for it cannot be export to onnx!' 。此时,您可以参考单算子模式的量化感知训练章节,进行单算子模式的量化。
  5. 如下示例标有“由用户补充处理”的步骤,需要用户根据自己的模型和数据集进行补充处理,示例中仅为示例代码。
  1. 导入AMCT包,并通过安装后处理>AMCT(PyTorch)中的环境变量设置日志级别。
    1
    import amct_pytorch as amct
    
  2. (可选,由用户补充处理)建议使用原始待量化的模型和测试集,在PyTorch环境下推理,验证环境、推理脚本是否正常。

    推荐执行该步骤,请确保原始模型可以完成推理且精度正常;执行该步骤时,可以使用部分测试集,减少运行时间。

    1
    2
    3
    ori_model.load()
    # 测试模型
    user_test_model(ori_model, test_data, test_iterations)
    
  3. 调用AMCT,执行量化流程。
    1. 生成量化配置。

      实现该步骤前,应先恢复训练好的参数,如2中的ori_model.load()。

      1
      2
      3
      4
      5
      6
      config_file = './tmp/config.json'
      simple_cfg = './retrain.cfg'
      amct.create_quant_retrain_config(config_file=config_file,
                                       model=ori_model,
                                       input_data=ori_model_input_data,
                                       config_defination=simple_cfg)
      
    2. 修改模型。
      在模型ori_model插入数据量化、权重量化等相关算子,用于计算量化相关参数,然后保存为新的训练模型retrain_model。
      1
      2
      3
      4
      5
      record_file = './tmp/record.txt'
      quant_retrain_model = amct.create_quant_retrain_model(config_file=config_file,
      						      model=ori_model,
      						      record_file=record_file,
      						      input_data=ori_model_input_data)
      
    3. (由用户补充处理)使用修改后的图,创建反向梯度,在训练集上做训练,训练量化因子。
      1. 使用修改后的图,创建反向梯度。该步骤需要在3.b后执行。
        1
        optimizer = user_create_optimizer(quant_retrain_model)
        
      2. 从训练好的checkpoint恢复模型,并训练模型。

        注意:从训练好的checkpoint恢复模型参数后再训练;训练中保存的参数应该包括量化因子。

        1
        2
        quant_pth = './ckpt/user_model'
        user_train_model(optimizer, quant_retrain_model, train_data)
        
      3. 训练完成后,执行推理,计算量化因子并保存。
        1
        user_infer_graph(quant_retrain_model)
        
    4. 保存量化模型。
      根据量化因子以及用户重训练好的模型,调用save_quant_retrain_model接口,插入AscendQuant、AscendDequant等算子,保存为量化模型。
      1
      2
      3
      4
      5
      6
      quant_model_path = './result/user_model'
      amct.save_quant_retrain_model(config_file=config_file,
                                    model=quant_retrain_model,
                                    record_file=record_file,
                                    save_path=quant_model_path,
                                    input_data=ori_model_input_data)
      
  4. (可选,由用户补充处理)基于ONNX Runtime的环境,使用量化后模型(quant_model)在测试集(test_data)上做推理,测试量化后仿真模型的精度。

    使用量化后仿真模型精度与2中的原始精度做对比,可以观察量化对精度的影响。

    1
    2
    quant_model = './results/user_model_fake_quant_model.onnx'
    user_do_inference_onnx(quant_model, test_data, test_iterations)
    

如果训练过程中断,需要从ckpt中恢复数据,继续训练,则调用流程为:

  1. 导入AMCT包,并通过安装后处理>AMCT(PyTorch)中的环境变量设置日志级别。
    1
    import amct_pytorch as amct
    
  1. 准备原始模型。
    1
    ori_model= user_create_model()
    
  2. 调用AMCT,恢复量化训练流程。
    1. 修改模型,在模型ori_model插入量化相关的算子,保存为新的训练模型retrain_model。
      1
      2
      3
      4
      5
      6
      7
      8
      9
      config_file = './tmp/config.json'
      simple_cfg = './retrain.cfg'
      record_file = './tmp/record.txt'
      quant_pth_file = './ckpt/user_model_newest.ckpt'
      quant_retrain_model = amct.restore_quant_retrain_model(config_file=config_file,
      						       model=ori_model,
      						       record_file=record_file,
      	                                               input_data=ori_model_input_data,
      	                                               pth_file=quant_pth_file)
      
    2. (由用户补充处理)使用修改后的图,创建反向梯度,在训练集上做训练,训练量化因子。
      1. 使用修改后的图,创建反向梯度。该步骤需要在3.a后执行。
        1
        optimizer = user_create_optimizer(retrain_model)
        
      2. 从训练好的checkpoint恢复模型,并训练模型。

        注意:从训练好的checkpoint恢复模型参数后再训练;训练中保存的参数应该包括量化因子。

        1
        user_train_model(optimizer, retrain_model, train_data)
        
      3. 训练完成后,执行推理,计算量化因子并保存。
        1
        user_infer_graph(train_graph, retrain_ops[-1].output_tensor)
        
    3. 保存模型。
      1
      2
      3
      4
      5
      6
      quant_model_path = './result/user_model'
      amct.save_quant_retrain_model(config_file=config_file,
                                    model=ori_model,
                                    record_file=record_file,
                                    save_path=quant_model_path,
                                    input_data=ori_model_input_data)
      
  3. (可选,由用户补充处理)基于ONNX Runtime的环境,使用量化后模型(quant_model)在测试集(test_data)上做推理,测试量化后仿真模型的精度。

    使用量化后仿真模型精度与2中的原始精度做对比,可以观察量化对精度的影响。

    1
    2
    quant_model = './results/user_model_fake_quant_model.onnx'
    user_do_inference_onnx(quant_model, test_data, test_iterations)