下载
中文
注册

基于精度的自动量化

基于精度的自动量化是为了方便用户在对量化精度有一定要求时所使用的功能。该方法能够在保证用户所需的模型精度前提下,自动搜索模型的量化配置并执行训练后量化的流程,最终生成满足精度要求的量化模型。

基于精度的自动量化基本原理与手工量化相同,但是用户无需手动调整量化配置文件,大大简化了优化流程,提高量化效率。量化示例请参见样例列表

接口调用流程

接口调用流程如图1所示。

图1 接口调用流程

主要流程如下:

  1. 调用create_quant_config生成量化配置文件,然后调用accuracy_based_auto_calibration进行基于精度的自动量化。
  2. 调用accuracy_based_auto_calibration中由用户传入的evaluator实例进行精度测试,得到原始模型精度。

    该过程还会调用accuracy_based_auto_calibration中的量化策略strategy模块,输出初始化的quant config量化配置文件,该文件记录所有层都可以进行量化。

  3. 使用用户传入的初始量化配置文件(1中调用create_quant_config生成的)对模型进行训练后量化,得到量化后fake quant模型的精度。
  4. 原始模型精度与量化后fake quant模型精度进行比较,如果精度达标,则输出量化后的部署模型和fake quant模型,如果不达标,则进行基于精度的自动量化流程:
    1. 进行原始PyTorch网络的推理, dump出每一层的输入activation数据,缓存起来;
    2. 利用训练后量化的量化因子构造量化层的单算子网络,利用缓存的activation数据计算量化后fake quant单算子网络的输出数据和原始PyTorch单算子网络输出的余弦相似度
    3. 将余弦相似度的列表传给accuracy_based_auto_calibration中的量化策略strategy模块,strategy模块基于2中生成的初始化的量化配置文件,输出回退某些层后的新的quant config量化配置文件
    4. 根据quant config量化配置文件重新进行训练后量化,得到回退后的fake quant模型
    5. 调用accuracy_based_auto_calibration中的evaluator模块进行回退后的fake quant模型精度测试,查看精度是否达标:
      • 如果达标,则输出回退后的fake quant模型以及部署模型。
      • 如果不达标,则将余弦相似度排序最差的层回退,再次进行4.c,输出新的量化配置。
      • 如果回退所有层后精度仍不达标,则不生成量化模型。

accuracy_based_auto_calibration接口内部基于精度的自动量化流程如图2所示。

图2 自动量化流程

调用示例

本示例演示使用AMCT进行基于精度的自动量化流程。该过程需要用户实现一个模型推理得到精度的回调函数,由于AMCT需要基于回调函数返回的精度数据进行量化层的筛选,因此回调函数的返回数值应尽可能反映模型的精度。

  • 如下示例标有“由用户补充处理”的步骤,需要用户根据自己的模型和数据集进行补充处理,示例中仅为示例代码。
  • 如下示例调用AMCT的部分,函数入参请根据实际情况进行调整。
  1. 导入AMCT包。
    1
    2
    3
    import os
    import amct_pytorch as amct
    from amct_pytorch.common.auto_calibration import AutoCalibrationEvaluatorBase
    
  2. (由用户补充处理)使用原始待量化的模型和测试集,实现回调函数calibration()evaluate()metric_eval()

    上述回调函数的入参要和基类AutoCalibrationEvaluatorBase保持一致。其中:

    • calibration()完成校准的推理。
    • evaluate()完成模型的精度测试过程。
    • metric_eval()完成原始模型和量化fake quant模型的精度损失评估,当精度损失小于预期值时返回True,否则返回False。
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    class ModelEvaluator(AutoCalibrationEvaluatorBase):
         # The evaluator for model
        def __init__(self, *args, **kwargs):
            # 成员变量初始化
            # 设置预期精度损失,此处请替换为具体的数值
            self.diff = expected_acc_loss
            pass
    
        def calibration(self, model_file, weights_file):
            # 进行模型的校准推理,推理的batch数要和量化配置的batch_num一致
            pass
    
        def evaluate(self, model_file, weights_file):
            # evaluate the input models, get the eval metric of model
            pass
    
        def metric_eval(self, original_metric, new_metric):
            # 评估原始模型精度和量化模型精度的精度损失是否满足预期,满足返回True, 精度损失数据;否则返回False, 精度损失数据
            loss = original_metric - new_metric
            if loss < self.diff:
                return True, loss
            return False, loss
    
  3. (由用户补充处理)实例化pytorch模型,得到模型的对象。
    1
    model = MyNet()
    
  4. 调用AMCT,进行基于精度的自动量化。
    1. 生成量化配置。
      1
      2
      3
      4
      5
      6
      7
      8
      config_json_file = './config.json'
      skip_layers = []
      batch_num = 1
      amct.create_quant_config(config_json_file, model, input_data,
                                skip_layers, batch_num)
      
      scale_offset_record_file = os.path.join(TMP, 'scale_offset_record.txt')
      result_path = os.path.join(RESULT, 'model')
      
    2. 初始化Evaluator。
      1
      evaluator = AutoCalibrationEvaluator()
      
    3. 进行基于精度的量化配置自动搜索。
       1
       2
       3
       4
       5
       6
       7
       8
       9
      10
      11
      12
      13
      14
      15
      16
      amct.accuracy_based_auto_calibration(
              model=model,
              model_evaluator=evaluator,
              config_file=config_json_file,
              record_file=record_file,
              save_dir=result_path,
              input_data=input_data,
              input_names=['input'],
              output_names=['output'],
              dynamic_axes={
                  'input': {0: 'batch_size'},
                  'output': {0: 'batch_size'}
              },
              strategy='BinarySearch',
              sensitivity='CosineSimilarity'
          )