下载
中文
注册

auto_channel_prune_search

功能说明

自动通道稀疏搜索接口,根据用户模型来计算各通道的稀疏敏感度(影响精度)以及稀疏收益(影响性能),然后搜索策略依据该输入来搜索最优的逐层通道稀疏率,以平衡精度和性能。最终输出一个配置文件。

约束说明

无。

函数原型

auto_channel_prune_search(graph, output_nodes, config, input_data, output_cfg, sensitivity, search_alg)

参数说明

参数名

输入/返回值

含义

使用限制

graph

输入

待稀疏的、包含自动微分的tf.Graph训练图。

数据类型:tf.Graph

output_nodes

输入

模型输出节点的名称。

数据类型:list,列表中元素类型为string

config

输入

自动通道稀疏搜索配置文件路径。

基于basic_info.proto文件中的AutoChannelPruneConfig生成的简易配置文件,basic_info.proto件所在路径为:AMCT安装目录/amct_tensorflow/proto/basic_info.proto。

basic_info.proto文件参数解释以及生成的自动通道稀疏搜索配置文件样例请参见自动通道稀疏搜索简易配置文件说明

数据类型:string

input_data

输入

用户提供的校准数据。

数据类型:list,内容为对应的feed_dict数据

output_cfg

输入&返回值

输出的最终的通道稀疏配置文件路径。

数据类型:string

sensitivity

输入

敏感度计算方法。

数据类型:string或

SensitivityBase的子类,string为AMCT已有的方法,目前可选为'TaylorLossSensitivity';子类为SensitivityBase的子类的实例化,可由用户来继承定义。默认为'TaylorLossSensitivity'。

search_alg

输入

待稀疏的通道搜索方法。

数据类型:string或

SearchChannelBase的子类,string为AMCT已有的方法,目前可选为'GreedySearch';子类为SearchChannelBase的子类的实例化,可由用户来继承定义。默认为'GreedySearch'。

返回值说明

无。

函数输出

自动通道稀疏配置文件。

该文件需要传给通道稀疏接口完成后续的业务。

调用示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
import amct_tensorflow as amct
#构造feed_dict数据
input_data = np.random.uniform(-10, 10, (2, 14, 14, 64)).astype(np.str_)
feed_dict = [{'input:0': input_data}]

amct.auto_channel_prune_search(
     graph=graph, 
     output_nodes=[operation_name_1, operation_name_2]
     config='./tmp/sample.cfg', 
     input_data=feed_dict,  
     output_cfg='./tmp/output.cfg', 
     sensitivity='TaylorLossSensitivity', 
     search_alg='GreedySearch')