下载
中文
注册
我要评分
文档获取效率
文档正确性
内容完整性
文档易理解
在线提单
论坛求助
昇腾小AI

TensorFlow网络模型由于AMCT导致输出节点改变,如何通过修改量化脚本进行后续的量化动作

问题描述

使用AMCT调用quantize_model接口对用户的原始TensorFlow模型进行图修改时,由于插入了searchN层导致尾层输出节点发生改变。该场景下,需要用户根据提示信息,将推理时的输出节点替换为新的输出节点的名称;AMCT量化过程中的日志信息给出了网络输出节点变化前后的节点名称,需要用户根据提示信息, 自行修改量化脚本。

进行图修改时,导致尾层输出节点发生改变的场景有如下几种情况:

  • 场景1:网络模型的尾层为ADD/ADDV2,且ADD加数的维度为一维,满足biasadd的功能
    图1 尾层为ADD/ADDV2
    • 尾层为Add时,图修改时的提示信息为:
      2020-09-01 09:31:04,896 - WARNING - [AMCT]:[replace_add_pass]: Replace ADD at the end of the network! You need to replace the old output node by the new output node in inference process! 
      >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'Add:0'
      <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'bias_add/BiasAdd:0'
      2020-09-01 09:31:04,979 - WARNING - [AMCT]:[quantize_model]: Insert searchN operator at the end of the network! You need to replace the old output node by the new output node in inference process! 
      >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'bias_add/BiasAdd:0'   //网络输出节点变化前的节点名称
      <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'search_n_quant/search_n_quant_SEARCHN/Identity:0'  //网络输出节点变化后的节点名称
    • 尾层为AddV2时,图修改时的提示信息为:
      2020-09-01 09:32:42,281 - WARNING - [AMCT]:[replace_add_pass]: Replace ADD at the end of the network! You need to replace the old output node by the new output node in inference process! 
      >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'add:0'
      <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'bias_add/BiasAdd:0'
      2020-09-01 09:32:42,362 - WARNING - [AMCT]:[quantize_model]: Insert searchN operator at the end of the network! You need to replace the old output node by the new output node in inference process! 
      >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'bias_add/BiasAdd:0'   //网络输出节点变化前的节点名称
      <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'search_n_quant/search_n_quant_SEARCHN/Identity:0'   //网络输出节点变化后的节点名称
  • 场景2:网络的尾层为BiasAdd,且其前面为Conv2D、DepthwiseConv2dNative、Conv2DBackpropInput、MatMul中的一种
    图2 网络的尾层为BiasAdd,且其前面为Conv2D

    该场景下,图修改时的提示信息为:

    2020-09-01 09:39:26,130 - WARNING - [AMCT]:[quantize_model]: Insert searchN operator at the end of the network! You need to replace the old output node by the new output node in inference process! 
    >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'BiasAdd:0'   //网络输出节点变化前的节点名称
    <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'search_n_quant/search_n_quant_SEARCHN/Identity:0'   //网络输出节点变化后的节点名称
  • 场景3:网络的尾层为Conv2D、DepthwiseConv2dNative、Conv2DBackpropInput、MatMul、AvgPool中的一种
    图3 网络的尾层为Conv2D

    该场景下,图修改时的提示信息为:

    2020-09-01 09:40:28,717 - WARNING - [AMCT]:[quantize_model]: Insert searchN operator at the end of the network! You need to replace the old output node by the new output node in inference process! 
    >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'Conv2D:0'   //网络输出节点变化前的节点名称
    <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'search_n_quant/search_n_quant_SEARCHN/Identity:0'  //网络输出节点变化后的节点名称
  • 场景4:网络的尾层为FusedBatchNorm\FusedBatchNormV2\FusedBatchNormV3中的一种,且其前面为Conv2D+(BiasAdd)或DepthwiseConv2dNative+(BiasAdd)
    图4 尾层为FusedBatchNormV3

    该场景下,图修改时的提示信息为:

    2020-09-01 09:44:08,637 - WARNING - [AMCT]:[conv_bn_fusion_pass]: Fused BN at the end of the network! You need to replace the old output node by the new output node in inference process! 
    >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'batch_normalization/FusedBatchNormV3:0'
    <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'bias_add:0'
    2020-09-01 09:44:08,717 - WARNING - [AMCT]:[quantize_model]: Insert searchN operator at the end of the network! You need to replace the old output node by the new output node in inference process! 
    >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'bias_add:0'   //网络输出节点变化前的节点名称
    <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'search_n_quant/search_n_quant_SEARCHN/Identity:0'   //网络输出节点变化后的节点名称
  • 场景5:网络的尾层为BN小算子结构,且其输入的为4维数据
    图5 尾层为BN小算子结构

    该场景下,图修改时的提示信息为:

    2020-09-01 09:46:46,426 - WARNING - [AMCT]:[replace_bn_pass]: Replace BN at the end of the network! You need to replace the old output node by the new output node in inference process! 
    >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'batch_normalization/batchnorm/add_1:0'
    <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'batch_normalization/batchnorm/bn_replace/batch_normalization/FusedBatchNormV3:0'
    2020-09-01 09:46:46,439 - WARNING - [AMCT]:[conv_bn_fusion_pass]: Fused BN at the end of the network! You need to replace the old output node by the new output node in inference process! 
    >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'batch_normalization/batchnorm/bn_replace/batch_normalization/FusedBatchNormV3:0'
    <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'bias_add:0'
    2020-09-01 09:46:46,518 - WARNING - [AMCT]:[quantize_model]: Insert searchN operator at the end of the network! You need to replace the old output node by the new output node in inference process! 
    >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>The name of the old output node is 'bias_add:0'   //网络输出节点变化前的节点名称
    <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<The name of the new output node is 'search_n_quant/search_n_quant_SEARCHN/Identity:0'   //网络输出节点变化后的节点名称

脚本修改

如果调用quantize_model接口对用户的原始TensorFlow模型进行图修改时,由于在网络最后插入了searchN层导致尾层输出节点发生改变,需要用户根据日志信息,修改量化脚本,将网络推理过程中的输出节点替换为新的节点名称,修改方法如下:

修改前的量化脚本(如下脚本只是样例,请以实际量化的模型为准):

import tensorflow as tf
import amct_tensorflow as amct

def load_pb(model_name):
    with tf.gfile.GFile(model_name, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

def main():
    # 网络pb文件的名字
    model_name = './pb_model/case_1_1.pb'
    # 网络量化推理输出节点的名字
    infer_output_name = 'Add:0'
    # 网络保存量化模型的输出节点的名字
    save_output_name = 'Add:0'


    # 载入网络的pb文件
    load_pb(model_name)
    # 获取网络的图结构
    graph = tf.get_default_graph()


    # 生成量化配置文件
    amct.create_quant_config(
        config_file='./configs/config.json',
        graph=graph)
    # 插入量化相关算子
    amct.quantize_model(
        graph=graph,
        config_file='./configs/config.json',
        record_file='./configs/record_scale_offset.txt')


    # 执行网络的推理过程
    with tf.Session() as sess:
        output_tensor = graph.get_tensor_by_name(infer_output_name)
        sess.run(tf.global_variables_initializer())
        sess.run(output_tensor)


    # 保存量化后的pb模型文件
    amct.save_model(
        pb_model=model_name,
        outputs=[save_output_name[:-2]],
        record_file='./configs/record_scale_offset.txt',
        save_path='./pb_model/case_1_1')


if __name__ == '__main__':
    main()

修改后的量化脚本:

import tensorflow as tf
import amct_tensorflow as amct

def load_pb(model_name):
    with tf.gfile.GFile(model_name, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

def main():
    # 网络pb文件的名字
    model_name = './pb_model/case_1_1.pb'
    # 网络量化推理输出节点的名字,需要替换为日志打印的网络输出节点变化后的节点名称
    infer_output_name = 'search_n_quant/search_n_quant_SEARCHN/Identity:0' 
    # 网络保存量化模型的输出节点的名字
    save_output_name = 'Add:0'


    # 载入网络的pb文件
    load_pb(model_name)
    # 获取网络的图结构
    graph = tf.get_default_graph()


    # 生成量化配置文件
    amct.create_quant_config(
        config_file='./configs/config.json',
        graph=graph)
    # 插入量化相关算子
    amct.quantize_model(
        graph=graph,
        config_file='./configs/config.json',
        record_file='./configs/record_scale_offset.txt')


    # 执行网络的推理过程
    with tf.Session() as sess:
        output_tensor = graph.get_tensor_by_name(infer_output_name)
        sess.run(tf.global_variables_initializer())
        sess.run(output_tensor)


    # 保存量化后的pb模型文件
    amct.save_model(
        pb_model=model_name,
        outputs=[save_output_name[:-2]],
        record_file='./configs/record_scale_offset.txt',
        save_path='./pb_model/case_1_1')


if __name__ == '__main__':
    main()
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词