逐层蒸馏
本节详细介绍逐层蒸馏支持的层,接口调用流程和示例。
- 支持做蒸馏和量化的算子:
- torch.nn.Linear:复用层(共用weight和bias参数)不支持量化。
- torch.nn.Conv2d:padding_mode为zeros才支持量化,复用层(共用weight和bias参数)不支持量化。
- 支持做蒸馏的激活算子:
- torch.nn.Relu
- torch.nn.LeakyRelu
- torch.nn.Sigmoid
- torch.nn.Tanh
- torch.nn.Softmax
- 支持做蒸馏的归一化算子
逐层蒸馏示例请参见获取更多样例。
接口调用流程
调用流程如图1所示。
蓝色部分为用户实现,灰色部分为用户调用AMCT提供的API实现,逐层蒸馏特性主要分为4个部分,创建蒸馏配置、创建蒸馏模型、蒸馏量化模型和保存量化模型,详情如下:
- 创建蒸馏配置,用户首先构造PyTorch的原始模型,然后使用create_distill_config接口,将用户自定义的蒸馏配置与AMCT算法定义的蒸馏配置相结合,输出网络每一层的蒸馏配置信息。
- 创建蒸馏模型,调用create_distill_model接口对原始模型进行修改,根据蒸馏配置信息生成一个待蒸馏的量化模型。
- 蒸馏模型,调用distill接口,根据用户配置的模型推理和优化方法,结合蒸馏配置信息,对网络进行分块蒸馏,得到蒸馏后的性能优化的量化模型。
- 保存量化模型,最后调用save_distill_model接口保存已蒸馏优化的量化模型,包括可在ONNX执行框架ONNX Runtime环境中进行精度仿真的模型文件和可部署在昇腾AI处理器的模型文件。
调用示例
- 如下示例标有“由用户补充处理”的步骤,需要用户根据自己的模型和数据集进行补充处理,示例中仅为示例代码。
- 如下示例调用AMCT的部分,函数入参请根据实际情况进行调整。
- 导入AMCT包,并通过安装后处理中的环境变量设置日志级别。
1
import amct_pytorch as amct
- (可选,由用户补充处理)在PyTorch原始环境中验证推理脚本及环境。
建议使用原始待量化的模型和测试集,在PyTorch环境下推理,验证环境、推理脚本是否正常。
推荐执行该步骤,请确保原始模型可以完成推理且精度正常;执行该步骤时,可以使用部分测试集,减少运行时间。
1
user_do_inference_torch(ori_model, test_data, test_iterations)
- 调用AMCT,蒸馏模型。
- 调用create_distill_config接口生成蒸馏配置。输出的蒸馏配置中包含用户自定义的和自动查找到的蒸馏结构。
1 2 3 4 5 6
config_file = './tmp/config.json' simple_cfg = './distill.cfg' amct.create_distill_config(config_file=config_file, model=ori_model, input_data=ori_model_input_data, config_defination=simple_cfg)
- 调用create_distill_model接口创建蒸馏模型。
对待蒸馏的浮点模型进行量化,将浮点模型中的待压缩算子替换为CANN量化算子。
1 2 3 4
compress_model = amct.create_distill_model( config_file=config_file, model=ori_model, input_data=ori_model_input_data)
- 逐层蒸馏。
调用distill接口进行逐层蒸馏。针对配置中的蒸馏结构进行蒸馏。
1 2 3 4 5 6 7 8 9 10
distill_model = amct.distill( model=ori_model, compress_model config_file=config_file, train_loader, epochs=1, lr=1e-3, sample_instance=None, loss=loss, optimizer=optimizer)
- 保存蒸馏模型。
调用save_distill_model接口,插入AscendQuant、AscendDequant等算子,保存为蒸馏模型。
1 2 3 4 5 6 7 8 9
amct.save_distill_model( model, "./model/distilled" input_data, record_file="./results/records.txt" input_names=['input'], output_names=['output'], dynamic_axes={'input':{0: 'batch_size'}, 'output':{0: 'batch_size'}})
- 调用create_distill_config接口生成蒸馏配置。输出的蒸馏配置中包含用户自定义的和自动查找到的蒸馏结构。
- (可选,由用户补充处理)基于ONNX Runtime的环境,使用蒸馏后模型(quant_model)在测试集(test_data)上做推理,测试蒸馏后仿真模型的精度。
使用量化后仿真模型精度与2中的原始精度做对比,可以观察量化对精度的影响。
1 2
quant_model = './results/user_model_fake_quant_model.onnx' user_do_inference_onnx(quant_model, test_data, test_iterations)
父主题: 扩展更多特性