下载
中文
注册

搜索流程

本节介绍自动混合精度搜索场景的接口调用流程和调用示例。

接口调用流程

接口调用流程如下图所示,蓝色部分为用户实现,灰色部分为用户调用AMCT提供的API实现。

用户准备好TensorFlow的原始模型、自动混合精度配置文件和推理评估器(Evaluator),调用auto_mixed_precision_search,根据压缩率、量化位宽、量化敏感度以及计算复杂度信息,执行自动混合精度搜索,得到混合精度配置文件与可用于量化感知训练的简易配置文件。

其中Evaluator模块需要用户自定义,用来执行模型的推理,获取量化因子,dump数据(每一层的输入数据)等信息。

图1 调用流程

调用示例

本示例演示了使用AMCT进行自动混合精度搜索的流程,该过程需要用户实现一个模型推理和校准的评估器。

  • 如下示例标有“由用户补充处理”的步骤,需要用户根据自己的模型和数据集进行补充处理,示例中仅为示例代码。
  • 如下示例调用AMCT的部分,函数入参请根据实际情况进行调整。
  1. 导入AMCT包,设置日志级别。
    1
    2
    3
    import amct_tensorflow as amct
    from amct_tensorflow.common.auto_calibration import AutoCalibrationEvaluatorBase
    amct.set_logging_level(print_level="info", save_level="info")
    
  2. (由用户补充处理)实现一个模型的评估器,使用原始待量化的模型和测试集,实现回调函数calibration()、evaluate()。

    上述回调函数的入参要和基类AutoCalibrationEvaluatorBase保持一致。其中:

    • calibration()完成校准的推理。
    • evaluate()完成模型的前向推理过程。
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    class ModelEvaluator(AutoCalibrationEvaluatorBase):
          # The evaluator for model
        def __init__(self, *args, **kwargs):
            # 做成员变量初始化
            # 设置预期精度损失,此处请替换为具体的数值
            self.diff = expected_acc_loss 
            pass
    
        def calibration(self, graph, outputs, batch_num):
            """ 对一张图做量化校准前向推理
            graph:tensorflow.Graph类型, 对图graph做前向推理
            outputs: 列表类型,图graph的输出,在推理过程需要获取的输出
            batch_num: int类型,前向推理的batch数目,与量化配置的batch_num一致
            """
            pass
    
        def evaluate(self, graph, outputs, iterations): # pylint: disable=R0914
            """ 对一张图做量化校准前向推理
            graph:tensorflow.Graph类型, 对图graph做前向推理
            outputs: 列表类型,图graph的输出,在推理过程需要获取的输出
            iterations: int类型,前向推理的batch数目
            """
            pass
    
  3. 调用AMCT工具,进行自动混合精度搜索。
    1. 初始化evaluator,可以使用2构造的ModelEvaluator,也可以使用工具提供的GraphEvaluator,请参见AMCT安装目录/amct_tensorflow/interface/evaluator.py
      1
      2
      3
      4
      5
      6
      evaluator = ModelEvaluator()
      或者
      evaluator = amct.GraphEvaluator(
          data_dir="./data/input_bin/", 
          input_shape="input:32,3,224,224", 
          data_types="float32")
      
    2. 进行自动混合精度搜索。
       1
       2
       3
       4
       5
       6
       7
       8
       9
      10
      11
      12
      model_file = './model/user_model.pb'
      outputs = ['user_model_outputs0', 'user_model_outputs1']
      cfg_file = './configs/auto_mixed_precision.cfg'
      save_dir = './results/auto_mixed_precision'
      
      amct.auto_mixed_precision_search(
          model_file=model_file, 
          outputs=outputs, 
          amc_config=cfg_file, 
          save_dir=save_dir, 
          evaluator=evaluator, 
          sensitivity='MseSimilarity')