量化感知训练支持量化的层以及约束如下,量化过程请参见获取更多样例>ResNet50。
支持的层类型 |
约束 |
---|---|
nn.Dense:全连接层 |
- |
nn.Conv2d:卷积层 |
dilation为1、group为1、filter维度为4,或者dilation为1、group不为1,group=channel、filter维度为4 |
量化感知训练实现原理如图1所示。
如果训练过程中断,则可以基于保存的checkpoint和量化配置文件,重新调用restore_quant_retrain_model接口,输出修改后的retrain network,继续进行量化感知的训练,训练结束后然后调用save_quant_retrain_model接口保存量化后的模型,导出air格式的量化模型文件。
本示例演示如何使用AMCT进行量化感知训练。
import amct_mindspore as amct amct.set_logging_level(print_level='info', save_level='info')
user_do_calibration()
config_file = './tmp/config.json' fake_input_data = np.random.uniform(0.0, 1.0, size=ori_network_input_shape).astype(np.float32) amct.create_quant_retrain_config(config_file=config_file, network=ori_network, input_data=fake_input_data)
retrain_network = amct.create_quant_retrain_model(config_file=config_file, network=ori_network, initializer=initializer, input_data=fake_input_data)
该过程中如果提示数据类型不一致等类似信息,则请参见量化感知训练场景下,出现数据类型不一致的解决方法处理。
optimizer = user_create_optimizer(quant_retrain_model)
quant_pth = './ckpt/user_model' user_train_model(optimizer, quant_retrain_model, train_data)
quant_model_path = './result/user_model' amct.save_quant_retrain_model(config_file=config_file, file_name=quant_model_path, network=retrain_network, input_data=fake_input_data)
如果训练过程中断,需要从ckpt中恢复数据,继续训练,则调用流程为:
import amct_mindspore as amct amct.set_logging_level(print_level='info', save_level='info')
ori_network = user_create_network()
param_dict = load_checkpoint(user_network_checkpoint) load_param_into_net(ori_network, param_dict)
ori_network.set_train(True)
config_file = './tmp/config.json' fake_input_data = np.random.uniform(0.0, 1.0, size=ori_network_input_shape).astype(np.float32) checkpoint_path = './ckpt/user_model-newest.ckpt' retrain_network = amct.restore_quant_retrain_model(config_file=config_file, network=ori_network, checkpoint_path=checkpoint_path, input_data=fake_input_data)
optimizer = user_create_optimizer(quant_retrain_model)
quant_pth = './ckpt/user_model' user_train_model(optimizer, quant_retrain_model, train_data)
quant_model_path = './result/user_model' amct.save_quant_retrain_model(config_file=config_file, file_name=quant_model_path, network=retrain_network, input_data=fake_input_data)