How Do I Modify the Quantization Script When AMCT Changes the Output Node of a TensorFlow Model?
Problem Description
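When the AMCT quantize_model API modifies the graph of a user's original TensorFlow model, it may insert a searchN layer, which changes the output node of the last layer. In this case, replace the output node used for inference with the new output node. The log printed by AMCT during quantization gives the network output node names before and after the change; use them to modify the quantization script yourself.
The last-layer output node changes during graph modification in the following scenarios: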
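- Scenario 1: The last layer of the network is ADD/ADDV2, and the addend of the ADD is one-dimensional, so the operation is functionally equivalent to BiasAdd.
Figure 1 Last layer is ADD/ADDV2
- When the last layer is Add, graph modification prints: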
```
2020-09-01 09:31:04,896 - WARNING - [AMCT]:[replace_add_pass]: Replace ADD at the end of the network! You need to replace the old output node by the new output node in inference process!
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'Add:0'
<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'bias_add/BiasAdd:0'
2020-09-01 09:31:04,979 - WARNING - [AMCT]:[quantize_model]: Insert searchN operator at the end of the network! You need to replace the old output node by the new output node in inference process!
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'bias_add/BiasAdd:0'   // output node name before the change
<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'search_n_quant/search_n_quant_SEARCHN/Identity:0'   // output node name after the change
```
- When the last layer is AddV2, graph modification prints:
```
2020-09-01 09:32:42,281 - WARNING - [AMCT]:[replace_add_pass]: Replace ADD at the end of the network! You need to replace the old output node by the new output node in inference process!
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'add:0'
<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'bias_add/BiasAdd:0'
2020-09-01 09:32:42,362 - WARNING - [AMCT]:[quantize_model]: Insert searchN operator at the end of the network! You need to replace the old output node by the new output node in inference process!
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'bias_add/BiasAdd:0'   // output node name before the change
<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'search_n_quant/search_n_quant_SEARCHN/Identity:0'   // output node name after the change
```
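- Scenario 2: The last layer of the network is BiasAdd, preceded by one of Conv2D, DepthwiseConv2dNative, Conv2DBackpropInput, or MatMul.
Figure 2 Last layer is BiasAdd, preceded by Conv2D
In this scenario, graph modification prints: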
```
2020-09-01 09:39:26,130 - WARNING - [AMCT]:[quantize_model]: Insert searchN operator at the end of the network! You need to replace the old output node by the new output node in inference process!
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'BiasAdd:0'   // output node name before the change
<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'search_n_quant/search_n_quant_SEARCHN/Identity:0'   // output node name after the change
```
- Scenario 3: The last layer of the network is one of Conv2D, DepthwiseConv2dNative, Conv2DBackpropInput, MatMul, or AvgPool.
Figure 3 Last layer is Conv2D
In this scenario, graph modification prints:
```
2020-09-01 09:40:28,717 - WARNING - [AMCT]:[quantize_model]: Insert searchN operator at the end of the network! You need to replace the old output node by the new output node in inference process!
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'Conv2D:0'   // output node name before the change
<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'search_n_quant/search_n_quant_SEARCHN/Identity:0'   // output node name after the change
```
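- Scenario 4: The last layer of the network is one of FusedBatchNorm, FusedBatchNormV2, or FusedBatchNormV3, preceded by Conv2D or DepthwiseConv2dNative, optionally followed by BiasAdd.
Figure 4 Last layer is FusedBatchNormV3
In this scenario, graph modification prints: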
```
2020-09-01 09:44:08,637 - WARNING - [AMCT]:[conv_bn_fusion_pass]: Fused BN at the end of the network! You need to replace the old output node by the new output node in inference process!
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'batch_normalization/FusedBatchNormV3:0'
<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'bias_add:0'
2020-09-01 09:44:08,717 - WARNING - [AMCT]:[quantize_model]: Insert searchN operator at the end of the network! You need to replace the old output node by the new output node in inference process!
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'bias_add:0'   // output node name before the change
<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'search_n_quant/search_n_quant_SEARCHN/Identity:0'   // output node name after the change
```
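- Scenario 5: The last layer of the network is a BN structure composed of small operators (an unfused BN), and its input is 4-dimensional.
Figure 5 Last layer is a small-operator BN structure
In this scenario, graph modification prints: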
```
2020-09-01 09:46:46,426 - WARNING - [AMCT]:[replace_bn_pass]: Replace BN at the end of the network! You need to replace the old output node by the new output node in inference process!
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'batch_normalization/batchnorm/add_1:0'
<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'batch_normalization/batchnorm/bn_replace/batch_normalization/FusedBatchNormV3:0'
2020-09-01 09:46:46,439 - WARNING - [AMCT]:[conv_bn_fusion_pass]: Fused BN at the end of the network! You need to replace the old output node by the new output node in inference process!
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'batch_normalization/batchnorm/bn_replace/batch_normalization/FusedBatchNormV3:0'
<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'bias_add:0'
2020-09-01 09:46:46,518 - WARNING - [AMCT]:[quantize_model]: Insert searchN operator at the end of the network! You need to replace the old output node by the new output node in inference process!
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'bias_add:0'   // output node name before the change
<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'search_n_quant/search_n_quant_SEARCHN/Identity:0'   // output node name after the change
```
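In scenarios 4 and 5, the BN at the end of the network is fused into the preceding convolution, after which the output node becomes `bias_add:0`. The arithmetic behind this kind of folding can be sketched with NumPy. This is a generic illustration of conv+BN folding, not AMCT's conv_bn_fusion_pass implementation: it uses a dense layer instead of Conv2D (the per-output-channel algebra is the same), and the function name `fold_bn_into_linear` and the `eps` value are assumptions chosen for the example.

```python
import numpy as np


def fold_bn_into_linear(weight, bias, gamma, beta, mean, var, eps=1e-3):
    """Fold inference-mode BN, gamma * (y - mean) / sqrt(var + eps) + beta,
    into the weight and bias of the preceding linear layer.

    weight: [in_features, out_channels]; all BN parameters are per output channel.
    """
    scale = gamma / np.sqrt(var + eps)          # per-channel multiplier
    folded_weight = weight * scale              # scale each output column
    folded_bias = (bias - mean) * scale + beta  # new per-channel bias
    return folded_weight, folded_bias


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3))
w = rng.standard_normal((3, 5))
b = rng.standard_normal(5)
gamma, beta = rng.standard_normal(5), rng.standard_normal(5)
mean, var = rng.standard_normal(5), rng.random(5) + 0.5

# Reference: linear layer followed by inference-mode batch normalization
y = x @ w + b
bn_out = gamma * (y - mean) / np.sqrt(var + 1e-3) + beta

# Folded: a single matmul plus a bias add, i.e. no separate BN node remains
fw, fb = fold_bn_into_linear(w, b, gamma, beta, mean, var)
folded_out = x @ fw + fb

print(np.allclose(bn_out, folded_out))  # True
```

Because the folded layer ends in a plain bias add, the former BN output tensor no longer exists, which is why the logs report a new output node name.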
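Script Modification
If inserting a searchN layer at the end of the network while quantize_model modifies the user's original TensorFlow graph changes the last-layer output node, modify the quantization script according to the log messages so that inference uses the new output node name. Modify the script as follows.
Quantization script before modification (this script is only a sample; adapt it to the model you actually quantize):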
```python
import tensorflow as tf
import amct_tensorflow as amct


def load_pb(model_name):
    """Load a frozen .pb file into the default graph."""
    with tf.gfile.GFile(model_name, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')


def main():
    # Name of the network .pb file
    model_name = './pb_model/case_1_1.pb'
    # Name of the output node used for quantization inference
    infer_output_name = 'Add:0'
    # Name of the output node used when saving the quantized model
    save_output_name = 'Add:0'

    # Load the network .pb file
    load_pb(model_name)
    # Get the network graph
    graph = tf.get_default_graph()

    # Generate the quantization configuration file
    amct.create_quant_config(
        config_file='./configs/config.json',
        graph=graph)
    # Insert quantization-related operators
    amct.quantize_model(
        graph=graph,
        config_file='./configs/config.json',
        record_file='./configs/record_scale_offset.txt')

    # Run network inference
    with tf.Session() as sess:
        output_tensor = graph.get_tensor_by_name(infer_output_name)
        sess.run(tf.global_variables_initializer())
        sess.run(output_tensor)

    # Save the quantized .pb model file; [:-2] strips ':0' to get the node name
    amct.save_model(
        pb_model=model_name,
        outputs=[save_output_name[:-2]],
        record_file='./configs/record_scale_offset.txt',
        save_path='./pb_model/case_1_1')


if __name__ == '__main__':
    main()
```
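In the sample script, `save_output_name[:-2]` turns the tensor name 'Add:0' into the node name 'Add' that is passed to the `outputs` argument of `amct.save_model`. Slicing off two characters only works for a single-digit output index such as ':0'. If an output tensor could have a longer index (for example ':10'), a slightly more general conversion might look like the sketch below; `tensor_to_node_name` is a hypothetical helper, not part of AMCT.

```python
def tensor_to_node_name(tensor_name):
    """Strip the ':<output_index>' suffix from a TensorFlow tensor name."""
    node_name, sep, index = tensor_name.rpartition(':')
    if sep and index.isdigit():
        return node_name
    return tensor_name  # no ':<index>' suffix: already a node name


print(tensor_to_node_name('Add:0'))     # Add
print(tensor_to_node_name('split:10'))  # split
print(tensor_to_node_name('Add'))       # Add
```

In the script, `outputs=[tensor_to_node_name(save_output_name)]` would then replace `outputs=[save_output_name[:-2]]`.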
Quantization script after modification (only infer_output_name changes):
```python
import tensorflow as tf
import amct_tensorflow as amct


def load_pb(model_name):
    """Load a frozen .pb file into the default graph."""
    with tf.gfile.GFile(model_name, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')


def main():
    # Name of the network .pb file
    model_name = './pb_model/case_1_1.pb'
    # Name of the output node used for quantization inference: replace it with
    # the new output node name printed in the log (the name after the change)
    infer_output_name = 'search_n_quant/search_n_quant_SEARCHN/Identity:0'
    # Name of the output node used when saving the quantized model; left
    # unchanged because save_model is given the original .pb file (pb_model)
    save_output_name = 'Add:0'

    # Load the network .pb file
    load_pb(model_name)
    # Get the network graph
    graph = tf.get_default_graph()

    # Generate the quantization configuration file
    amct.create_quant_config(
        config_file='./configs/config.json',
        graph=graph)
    # Insert quantization-related operators
    amct.quantize_model(
        graph=graph,
        config_file='./configs/config.json',
        record_file='./configs/record_scale_offset.txt')

    # Run network inference
    with tf.Session() as sess:
        output_tensor = graph.get_tensor_by_name(infer_output_name)
        sess.run(tf.global_variables_initializer())
        sess.run(output_tensor)

    # Save the quantized .pb model file; [:-2] strips ':0' to get the node name
    amct.save_model(
        pb_model=model_name,
        outputs=[save_output_name[:-2]],
        record_file='./configs/record_scale_offset.txt',
        save_path='./pb_model/case_1_1')


if __name__ == '__main__':
    main()
```
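The logs in scenarios 1, 4, and 5 show that one output can be renamed several times in a row (for example 'Add:0' to 'bias_add/BiasAdd:0' to the searchN output), so the name to use for inference is found by following the old-to-new pairs to the end of the chain. Instead of copying it by hand, it could be resolved with a small helper. This is a minimal sketch, assuming the AMCT log has been captured as a string and its messages keep the format shown in this section; `resolve_output_node` is a hypothetical helper, not an AMCT API.

```python
import re

# Matches each "old output node ... new output node" pair in the AMCT warnings.
_PAIR = re.compile(
    r"The name of the old output node is '([^']+)'.*?"
    r"The name of the new output node is '([^']+)'",
    re.DOTALL,
)


def resolve_output_node(log_text, original_output):
    """Follow the chain of old -> new output-node renames reported in the log."""
    renames = dict(_PAIR.findall(log_text))
    current = original_output
    seen = set()
    while current in renames and current not in seen:  # guard against cycles
        seen.add(current)
        current = renames[current]
    return current  # unchanged if the log never renamed this output


# Abbreviated excerpt of the scenario 1 (Add) log
log = """
[replace_add_pass]: The name of the old output node is 'Add:0'
The name of the new output node is 'bias_add/BiasAdd:0'
[quantize_model]: The name of the old output node is 'bias_add/BiasAdd:0'
The name of the new output node is 'search_n_quant/search_n_quant_SEARCHN/Identity:0'
"""
print(resolve_output_node(log, 'Add:0'))
# search_n_quant/search_n_quant_SEARCHN/Identity:0
```

The result would be assigned to `infer_output_name`; `save_output_name` keeps the original name.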