基础稀疏
AMCT支持基于重训练的4选2结构化稀疏特性。该特性支持的层以及约束如下:
优化方式 |
支持的层类型 |
约束 |
---|---|---|
4选2结构化稀疏 |
torch.nn.Linear:全连接层 |
复用层(共用weight和bias参数)不支持稀疏。 |
torch.nn.Conv2d:卷积层 |
复用层(共用weight和bias参数)不支持稀疏。 depthwise只能被动稀疏(groups=in_channels),不能主动稀疏。 |
|
torch:ConvTranspose2d:反卷积层 |
- |
接口调用流程
4选2结构化稀疏功能接口调用流程如图1所示。
图1 4选2结构化稀疏接口调用流程
蓝色部分为用户实现,灰色部分为用户调用AMCT提供的API实现:
- 用户首先构造PyTorch的原始模型,调用create_prune_retrain_model接口对原始模型进行修改,把待稀疏的算子替换成插入了4选2结构化稀疏的算子。
- 对修改后的模型进行训练,直至精度满足要求;如果训练过程中断,则可基于原始模型和记录稀疏信息的文件,重新调用restore_prune_retrain_model接口稀疏模型,继续进行量化感知的训练,直至精度满足要求。
- 根据用户最终的重训练好的4选2结构化稀疏模型,调用save_prune_retrain_model接口,还原替换的算子并对weight进行稀疏,生成最终ONNX仿真模型以及部署模型。
调用示例
- 如下示例标有“由用户补充处理”的步骤,需要用户根据自己的模型和数据集进行补充处理,示例中仅为示例代码。
- 调用AMCT的部分,函数入参可以根据实际情况进行调整。稀疏基于用户的训练过程,请确保已经有基于PyTorch环境进行训练的脚本,并且训练后的精度正常。
- (可选,由用户补充处理)建议使用原始待量化的模型和测试集,在PyTorch环境下推理,验证环境、推理脚本是否正常。
推荐执行该步骤,请确保原始模型可以完成推理且精度正常;执行该步骤时,可以使用部分测试集,减少运行时间。
1 2 3
ori_model.load() # 测试模型 user_test_model(ori_model, test_data, test_iterations)
- 调用AMCT,执行4选2结构化稀疏流程。
- 对原始模型进行修改,把待稀疏的算子替换成插入了4选2结构化稀疏的算子。
实现该步骤前,应先恢复训练好的参数,如1中的ori_model.load()。
1 2 3 4 5 6
simple_cfg = './retrain.cfg' record_file = './tmp/record.txt' prune_retrain_model = amct.create_prune_retrain_model(model=ori_model, input_data=ori_model_input_data, config_defination=simple_cfg, record_file=record_file)
- (由用户补充处理)使用修改后的图,创建反向梯度,在训练集上做训练。
- 调用save_prune_retrain_model接口,保存模型,还原替换的算子并对weight进行结构化稀疏,生成最终ONNX仿真模型以及部署模型。
1 2 3 4
prune_retrain_model = amct.save_prune_retrain_model( model=pruned_retrain_model, save_path=save_path, input_data=input_data)
- 对原始模型进行修改,把待稀疏的算子替换成插入了4选2结构化稀疏的算子。
- (可选,由用户补充处理)基于ONNX Runtime的环境,使用稀疏后模型(prune_retrain_model)在测试集(test_data)上做推理,测试量化后仿真模型的精度。
使用稀疏后仿真模型精度与1中的原始精度做对比,可以观察4选2结构化稀疏对精度的影响。
1 2
prune_retrain_model = './results/user_model_fake_prune_model.onnx' user_do_inference_onnx(prune_retrain_model, test_data, test_iterations)
如果训练过程中断,需要从ckpt中恢复数据,继续训练,则调用流程为:
- 准备原始模型。
1
ori_model= user_create_model()
- 调用AMCT,恢复量化训练流程。
- 修改模型,把待稀疏的算子替换成插入了4选2结构化稀疏的算子,保存为新的prune_model。
1 2 3 4 5 6 7 8 9 10 11 12
model = ori_model input_data = ori_model_input_data record_file = './tmp/record.txt' config_defination = './prune_cfg.cfg' save_pth_path = /your/path/to/save/tmp.pth model.load_state_dict(torch.load(state_dict_path)) prune_retrain_model = amct.restore_prune_retrain_model(model=ori_model, input_data=ori_model_input_data, record_file=record_file, config_defination='./prune_cfg.cfg', save_pth_path=/your/path/to/save/tmp.pth, 'state_dict')
- (由用户补充处理)使用修改后的模型,恢复断点,创建反向梯度,在训练集上做训练。
- 调用save_prune_retrain_model接口,保存模型,还原替换的算子并对weight进行结构化稀疏,生成最终ONNX仿真模型以及部署模型。
1 2 3 4
prune_retrain_model = amct.save_prune_retrain_model( model=pruned_retrain_model, save_path=save_path, input_data=input_data)
- 修改模型,把待稀疏的算子替换成插入了4选2结构化稀疏的算子,保存为新的prune_model。
- (可选,由用户补充处理)基于ONNX Runtime的环境,使用稀疏后模型(prune_retrain_model)在测试集(test_data)上做推理,测试量化后仿真模型的精度。
使用稀疏后仿真模型精度与1中的原始精度做对比,可以观察4选2结构化稀疏对精度的影响。
1 2
prune_retrain_model = './results/user_model_fake_prune_model.onnx' user_do_inference_onnx(prune_retrain_model, test_data, test_iterations)
父主题: 4选2结构化稀疏