量化流程
本节详细介绍量化感知训练支持的量化层,接口调用流程和示例。
量化感知训练当前仅支持对FP32数据类型的网络模型进行量化。量化感知训练支持量化的层以及约束如下,量化示例请参见获取更多样例>resnet_v1_50。
支持的层类型 |
约束 |
---|---|
MatMul |
transpose_a=False, transpose_b=False,adjoint_a=False,adjoint_b=False |
Conv2D |
由于硬件约束,原始模型中输入通道数Cin<=16时不建议进行量化感知训练,否则可能会导致量化后的部署模型推理时精度下降 |
DepthwiseConv2dNative |
由于硬件约束,原始模型中输入通道数Cin<=16时不建议进行量化感知训练,否则可能会导致量化后的部署模型推理时精度下降 |
Conv2DBackpropInput |
|
AvgPool |
- |
接口调用流程
量化感知训练接口调用流程如图1所示。
蓝色部分为用户实现,灰色部分为用户调用AMCT提供的API实现,用户在TensorFlow原始网络推理的代码中导入库,并在特定的位置调用相应API,即可实现量化功能。
简要流程如下:
- 用户构造训练模式的图结构,然后调用create_quant_retrain_config接口生成量化配置文件。
- 调用create_quant_retrain_model图修改接口,根据量化配置文件对训练的图进行量化前的图结构修改:插入数据量化,权重量化相关算子等。
- 训练模型,将参数保存为checkpoint文件。
- 调用create_quant_retrain_model接口,对推理模式的图进行修改:插入数据量化,权重量化相关算子等。
- 恢复训练参数,加载ckpt文件,推理量化的输出节点,将量化因子写入record文件,并将推理图固化为pb模型。
- 调用save_quant_retrain_model接口,插入AscendQuant/AscendDequant等量化算子,保存量化模型。
调用示例
- 如下示例标有“由用户补充处理”的步骤,需要用户根据自己的模型和数据集进行补充处理,示例中仅为示例代码。
- 调用AMCT的部分,函数入参可以根据实际情况进行调整。量化感知训练基于用户的训练过程,请确保已经有基于TensorFlow环境进行训练的脚本,并且训练后的精度正常。
- 导入AMCT包,设置日志级别。
1 2
import amct_tensorflow as amct amct.set_logging_level(print_level='info', save_level='info')
- (可选,由用户补充处理)创建图并读取训练好的参数,在TensorFlow环境下推理,验证环境、推理脚本是否正常。
推荐执行该步骤,以确保原始模型可以完成推理且精度正常;执行该步骤时,可以使用部分测试集,减少运行时间。
1
user_test_evaluate_model(evaluate_model, test_data)
- (由用户补充处理)创建训练图。
1
train_graph = user_load_train_graph()
- 调用AMCT,执行带量化参数的训练流程。
- 生成量化配置。
用户基于构造的训练模式的图结构(BN的is_training参数为True),调用create_quant_retrain_config接口生成量化配置文件(对应图1中的序号1)。
1 2 3 4 5
config_file = './tmp/config.json' simple_cfg = './retrain.cfg' amct.create_quant_retrain_config(config_file=config_file, graph=train_graph, config_defination=simple_cfg)
- 修改训练模式的图。
调用量化图修改接口create_quant_retrain_model,根据量化配置文件对训练的图进行量化前的图结构修改:在图中插入数据量化、权重量化等相关算子,用于计算量化相关参数(对应图1中的序号2)。
1 2 3 4
record_file = './tmp/record.txt' retrain_ops = amct.create_quant_retrain_model(graph=train_graph, config_file=config_file, record_file=record_file)
- (由用户补充处理)使用修改后的图,创建反向梯度,在训练集上做训练,训练量化因子。
- 使用修改后的图,调用自适应学习率优化器(RMSPropOptimizer)建立反向梯度图。该步骤需要在4.b后执行。
1 2 3
optimizer = tf.compat.v1.train.RMSPropOptimizer( ARGS.learning_rate, momentum=ARGS.momentum) train_op = optimizer.minimize(loss)
- 创建会话,进行模型的训练,并将训练后的参数保存为checkpoint文件(对应图1中的序号3,4)。
注意:从训练好的checkpoint恢复模型参数后再训练;训练中保存的参数应该包括量化因子。
1 2 3 4 5
with tf.Session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) sess.run(outputs) #将训练后的参数保存为checkpoint文件 saver_save.save(sess, retrain_ckpt, global_step=0)
- 使用修改后的图,调用自适应学习率优化器(RMSPropOptimizer)建立反向梯度图。该步骤需要在4.b后执行。
- 生成量化配置。
- (由用户补充处理)创建推理图。
1
test_graph = user_load_test_graph()
- 调用AMCT,实现量化感知训练。
- 修改推理模式的图。
用户基于构造的推理模式的图结构(BN的is_training参数为False),调用量化图修改接口create_quant_retrain_model,根据量化配置文件对推理的图进行量化前的图结构修改:在图中插入数据量化、权重量化等相关算子(对应图1中的序号5)。
1 2 3 4
record_file = './tmp/record.txt' retrain_ops = amct.create_quant_retrain_model(graph=train_graph, config_file=config_file, record_file=record_file)
- (由用户补充处理)创建会话,恢复训练参数,推理量化的输出节点(retrain_ops[-1]),将量化因子写入record文件,并将推理图固化为pb模型(对应图1中的序号6,7)。
说明:推理和恢复的参数要在同一session中,推理执行的是retrain_ops[-1]的输出tensor;推理图固化为pb模型时,包含训练好的参数。
1 2 3 4 5 6 7 8 9 10 11 12 13
variables_to_restore = tf.compat.v1.global_variables() saver_restore = tf.compat.v1.train.Saver(variables_to_restore) with tf.Session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) #恢复训练参数 saver_restore.restore(sess, retrain_ckpt) #推理量化的输出节点(retrain_ops[-1]),将量化因子写入record文件 sess.run(retrain_ops[-1]) #固化pb模型 constant_graph = tf.compat.v1.graph_util.convert_variables_to_constants( sess, eval_graph.as_graph_def(), [output.name[:-2] for output in outputs]) with tf.io.gfile.GFile(frozen_quant_eval_pb, 'wb') as f: f.write(constant_graph.SerializeToString())
- 保存量化模型。
根据量化因子以及pb文件,调用save_quant_retrain_model接口,插入AscendQuant、AscendDequant等算子,保存为量化模型(对应图1中的序号8)。
1 2 3 4 5
quant_model_path = './result/user_model' amct.save_quant_retrain_model(pb_model=trained_pb, outputs=user_model_outputs, record_file=record_file, save_path=quant_model_path)
- 修改推理模式的图。
- (可选,由用户补充处理)使用量化后模型user_model_quantized.pb和测试集,在TensorFlow环境下推理,测试量化后的仿真模型精度。使用量化后仿真模型精度与2中的原始精度做对比,可以观察量化对精度的影响。
1 2
quant_model = './results/user_model_quantized.pb' user_do_inference(quant_model, test_data)
父主题: 量化感知训练