文档
注册

TensorFlow网络模型由于昇腾模型压缩工具导致输出节点改变,如何通过修改量化脚本进行后续的量化动作

问题描述

使用昇腾模型压缩工具调用quantize_model接口对用户的原始TensorFlow模型进行图修改时,由于插入了searchN层导致尾层输出节点发生改变。该场景下,需要用户根据提示信息,将推理时的输出节点替换为新的输出节点的名称;昇腾模型压缩工具量化过程中的日志信息给出了网络输出节点变化前后的节点名称,需要用户根据提示信息, 自行修改量化脚本。

进行图修改时,导致尾层输出节点发生改变的场景有如下几种情况:

  • 场景1:网络模型的尾层为ADD/ADDV2,且ADD加数的维度为一维,满足biasadd的功能
    图1 尾层为ADD/ADDV2
    • 尾层为Add时,图修改时的提示信息为:
      2020-09-01 09:31:04,896 - WARNING - [AMCT]:[replace_add_pass]: Replace ADD at the end of the network! You need to replace the old output node by the new output node in inference process! 
      >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'Add:0'
      <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'bias_add/BiasAdd:0'
      2020-09-01 09:31:04,979 - WARNING - [AMCT]:[quantize_model]: Insert searchN operator at the end of the network! You need to replace the old output node by the new output node in inference process! 
      >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'bias_add/BiasAdd:0'   //网络输出节点变化前的节点名称
      <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'search_n_quant/search_n_quant_SEARCHN/Identity:0'  //网络输出节点变化后的节点名称
    • 尾层为AddV2时,图修改时的提示信息为:
      2020-09-01 09:32:42,281 - WARNING - [AMCT]:[replace_add_pass]: Replace ADD at the end of the network! You need to replace the old output node by the new output node in inference process! 
      >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'add:0'
      <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'bias_add/BiasAdd:0'
      2020-09-01 09:32:42,362 - WARNING - [AMCT]:[quantize_model]: Insert searchN operator at the end of the network! You need to replace the old output node by the new output node in inference process! 
      >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'bias_add/BiasAdd:0'   //网络输出节点变化前的节点名称
      <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'search_n_quant/search_n_quant_SEARCHN/Identity:0'   //网络输出节点变化后的节点名称
  • 场景2:网络的尾层为BiasAdd,且其前面为Conv2D、DepthwiseConv2dNative、Conv2DBackpropInput、MatMul中的一种
    图2 网络的尾层为BiasAdd,且其前面为Conv2D

    该场景下,图修改时的提示信息为:

    2020-09-01 09:39:26,130 - WARNING - [AMCT]:[quantize_model]: Insert searchN operator at the end of the network! You need to replace the old output node by the new output node in inference process! 
    >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'BiasAdd:0'   //网络输出节点变化前的节点名称
    <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'search_n_quant/search_n_quant_SEARCHN/Identity:0'   //网络输出节点变化后的节点名称
  • 场景3:网络的尾层为Conv2D、DepthwiseConv2dNative、Conv2DBackpropInput、MatMul、AvgPool中的一种
    图3 网络的尾层为Conv2D

    该场景下,图修改时的提示信息为:

    2020-09-01 09:40:28,717 - WARNING - [AMCT]:[quantize_model]: Insert searchN operator at the end of the network! You need to replace the old output node by the new output node in inference process! 
    >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'Conv2D:0'   //网络输出节点变化前的节点名称
    <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'search_n_quant/search_n_quant_SEARCHN/Identity:0'  //网络输出节点变化后的节点名称
  • 场景4:网络的尾层为FusedBatchNorm\FusedBatchNormV2\FusedBatchNormV3中的一种,且其前面为Conv2D+(BiasAdd)或DepthwiseConv2dNative+(BiasAdd)
    图4 尾层为FusedBatchNormV3

    该场景下,图修改时的提示信息为:

    2020-09-01 09:44:08,637 - WARNING - [AMCT]:[conv_bn_fusion_pass]: Fused BN at the end of the network! You need to replace the old output node by the new output node in inference process! 
    >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'batch_normalization/FusedBatchNormV3:0'
    <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'bias_add:0'
    2020-09-01 09:44:08,717 - WARNING - [AMCT]:[quantize_model]: Insert searchN operator at the end of the network! You need to replace the old output node by the new output node in inference process! 
    >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'bias_add:0'   //网络输出节点变化前的节点名称
    <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'search_n_quant/search_n_quant_SEARCHN/Identity:0'   //网络输出节点变化后的节点名称
  • 场景5:网络的尾层为BN小算子结构,且其输入的为4维数据
    图5 尾层为BN小算子结构

    该场景下,图修改时的提示信息为:

    2020-09-01 09:46:46,426 - WARNING - [AMCT]:[replace_bn_pass]: Replace BN at the end of the network! You need to replace the old output node by the new output node in inference process! 
    >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'batch_normalization/batchnorm/add_1:0'
    <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'batch_normalization/batchnorm/bn_replace/batch_normalization/FusedBatchNormV3:0'
    2020-09-01 09:46:46,439 - WARNING - [AMCT]:[conv_bn_fusion_pass]: Fused BN at the end of the network! You need to replace the old output node by the new output node in inference process! 
    >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'batch_normalization/batchnorm/bn_replace/batch_normalization/FusedBatchNormV3:0'
    <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'bias_add:0'
    2020-09-01 09:46:46,518 - WARNING - [AMCT]:[quantize_model]: Insert searchN operator at the end of the network! You need to replace the old output node by the new output node in inference process! 
    >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'bias_add:0'   //网络输出节点变化前的节点名称
    <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'search_n_quant/search_n_quant_SEARCHN/Identity:0'   //网络输出节点变化后的节点名称

脚本修改

如果调用quantize_model接口对用户的原始TensorFlow模型进行图修改时,由于在网络最后插入了searchN层导致尾层输出节点发生改变,需要用户根据日志信息,修改量化脚本,将网络推理过程中的输出节点替换为新的节点名称,修改方法如下:

修改前的量化脚本(如下脚本只是样例,请以实际量化的模型为准):

import tensorflow as tf
import amct_tensorflow as amct

def load_pb(model_name):
    with tf.gfile.GFile(model_name, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

def main():
    # 网络pb文件的名字
    model_name = './pb_model/case_1_1.pb'
    # 网络量化推理输出节点的名字
    infer_output_name = 'Add:0'
    # 网络保存量化模型的输出节点的名字
    save_output_name = 'Add:0'


    # 载入网络的pb文件
    load_pb(model_name)
    # 获取网络的图结构
    graph = tf.get_default_graph()


    # 生成量化配置文件
    amct.create_quant_config(
        config_file='./configs/config.json',
        graph=graph)
    # 插入量化相关算子
    amct.quantize_model(
        graph=graph,
        config_file='./configs/config.json',
        record_file='./configs/record_scale_offset.txt')


    # 执行网络的推理过程
    with tf.Session() as sess:
        output_tensor = graph.get_tensor_by_name(infer_output_name)
        sess.run(tf.global_variables_initializer())
        sess.run(output_tensor)


    # 保存量化后的pb模型文件
    amct.save_model(
        pb_model=model_name,
        outputs=[save_output_name[:-2]],
        record_file='./configs/record_scale_offset.txt',
        save_path='./pb_model/case_1_1')


if __name__ == '__main__':
    main()

修改后的量化脚本:

import tensorflow as tf
import amct_tensorflow as amct

def load_pb(model_name):
    with tf.gfile.GFile(model_name, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

def main():
    # 网络pb文件的名字
    model_name = './pb_model/case_1_1.pb'
    # 网络量化推理输出节点的名字,需要替换为日志打印的网络输出节点变化后的节点名称
    infer_output_name = 'search_n_quant/search_n_quant_SEARCHN/Identity:0' 
    # 网络保存量化模型的输出节点的名字
    save_output_name = 'Add:0'


    # 载入网络的pb文件
    load_pb(model_name)
    # 获取网络的图结构
    graph = tf.get_default_graph()


    # 生成量化配置文件
    amct.create_quant_config(
        config_file='./configs/config.json',
        graph=graph)
    # 插入量化相关算子
    amct.quantize_model(
        graph=graph,
        config_file='./configs/config.json',
        record_file='./configs/record_scale_offset.txt')


    # 执行网络的推理过程
    with tf.Session() as sess:
        output_tensor = graph.get_tensor_by_name(infer_output_name)
        sess.run(tf.global_variables_initializer())
        sess.run(output_tensor)


    # 保存量化后的pb模型文件
    amct.save_model(
        pb_model=model_name,
        outputs=[save_output_name[:-2]],
        record_file='./configs/record_scale_offset.txt',
        save_path='./pb_model/case_1_1')


if __name__ == '__main__':
    main()
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词