下载
中文
注册

逐层蒸馏

本节详细介绍逐层蒸馏支持的层,接口调用流程和示例。

  • 支持做蒸馏和量化的算子:
    • torch.nn.Linear:复用层(共用weight和bias参数)不支持量化。
    • torch.nn.Conv2d:padding_mode为zeros才支持量化,复用层(共用weight和bias参数)不支持量化。
  • 支持做蒸馏的激活算子:
    • torch.nn.Relu
    • torch.nn.LeakyRelu
    • torch.nn.Sigmoid
    • torch.nn.Tanh
    • torch.nn.Softmax
  • 支持做蒸馏的归一化算子

    torch.nn.BatchNorm2d:复用层(共用weight和bias参数)不支持蒸馏。

逐层蒸馏示例请参见获取更多样例

接口调用流程

调用流程如图1所示。

图1 蒸馏接口调用流程
蓝色部分为用户实现,灰色部分为用户调用AMCT提供的API实现,逐层蒸馏特性主要分为4个部分,创建蒸馏配置、创建蒸馏模型、蒸馏量化模型和保存量化模型,详情如下:
  1. 创建蒸馏配置,用户首先构造PyTorch的原始模型,然后使用create_distill_config接口,将用户自定义的蒸馏配置与AMCT算法定义的蒸馏配置相结合,输出网络每一层的蒸馏配置信息。
  2. 创建蒸馏模型,调用create_distill_model接口对原始模型进行修改,根据蒸馏配置信息生成一个待蒸馏的量化模型。
  3. 蒸馏模型,调用distill接口,根据用户配置的模型推理和优化方法,结合蒸馏配置信息,对网络进行分块蒸馏,得到蒸馏后的性能优化的量化模型。
  4. 保存量化模型,最后调用save_distill_model接口保存已蒸馏优化的量化模型,包括可在ONNX执行框架ONNX Runtime环境中进行精度仿真的模型文件和可部署在昇腾AI处理器的模型文件。

调用示例

  • 如下示例标有“由用户补充处理”的步骤,需要用户根据自己的模型和数据集进行补充处理,示例中仅为示例代码。
  • 如下示例调用AMCT的部分,函数入参请根据实际情况进行调整。
  1. 导入AMCT包,并通过安装后处理中的环境变量设置日志级别。
    1
    import amct_pytorch as amct
    
  2. (可选,由用户补充处理)在PyTorch原始环境中验证推理脚本及环境。

    建议使用原始待量化的模型和测试集,在PyTorch环境下推理,验证环境、推理脚本是否正常。

    推荐执行该步骤,请确保原始模型可以完成推理且精度正常;执行该步骤时,可以使用部分测试集,减少运行时间。

    1
    user_do_inference_torch(ori_model, test_data, test_iterations)
    
  3. 调用AMCT,蒸馏模型。
    1. 调用create_distill_config接口生成蒸馏配置。输出的蒸馏配置中包含用户自定义的和自动查找到的蒸馏结构。
      1
      2
      3
      4
      5
      6
      config_file = './tmp/config.json'
      simple_cfg = './distill.cfg'
      amct.create_distill_config(config_file=config_file,
                                   model=ori_model,
                                   input_data=ori_model_input_data,
                                   config_defination=simple_cfg)
      
    2. 调用create_distill_model接口创建蒸馏模型。
      对待蒸馏的浮点模型进行量化,将浮点模型中的待压缩算子替换为CANN量化算子。
      1
      2
      3
      4
      compress_model = amct.create_distill_model(
                       config_file=config_file,
                       model=ori_model,
                       input_data=ori_model_input_data)
      
    3. 逐层蒸馏。

      调用distill接口进行逐层蒸馏。针对配置中的蒸馏结构进行蒸馏。

       1
       2
       3
       4
       5
       6
       7
       8
       9
      10
      distill_model = amct.distill(
                      model=ori_model,
                      compress_model
                      config_file=config_file,
                      train_loader,
                      epochs=1,
                      lr=1e-3,
                      sample_instance=None, 
                      loss=loss,
                      optimizer=optimizer)
      
    4. 保存蒸馏模型。
      调用save_distill_model接口,插入AscendQuant、AscendDequant等算子,保存为蒸馏模型。
      1
      2
      3
      4
      5
      6
      7
      8
      9
      amct.save_distill_model(
                       model, 
                       "./model/distilled"
                       input_data,
                       record_file="./results/records.txt"
                       input_names=['input'],
                       output_names=['output'],
                       dynamic_axes={'input':{0: 'batch_size'},
                                     'output':{0: 'batch_size'}})
      
  4. (可选,由用户补充处理)基于ONNX Runtime的环境,使用蒸馏后模型(quant_model)在测试集(test_data)上做推理,测试蒸馏后仿真模型的精度。

    使用量化后仿真模型精度与2中的原始精度做对比,可以观察量化对精度的影响。

    1
    2
    quant_model = './results/user_model_fake_quant_model.onnx'
    user_do_inference_onnx(quant_model, test_data, test_iterations)