auto_channel_prune_search
功能说明
自动通道稀疏搜索接口,根据用户模型来计算各通道的稀疏敏感度(影响精度)以及稀疏收益(影响性能),然后搜索策略依据该输入来搜索最优的逐层通道稀疏率,以平衡精度和性能。最终输出一个配置文件。
约束说明
无。
函数原型
auto_channel_prune_search(graph, output_nodes, config, input_data, output_cfg, sensitivity, search_alg)
参数说明
参数名 |
输入/返回值 |
含义 |
使用限制 |
---|---|---|---|
graph |
输入 |
待稀疏的、包含自动微分的tf.Graph训练图。 |
数据类型:tf.Graph |
output_nodes |
输入 |
模型输出节点的名称。 |
数据类型:list,列表中元素类型为string |
config |
输入 |
自动通道稀疏搜索配置文件路径。 基于basic_info.proto文件中的AutoChannelPruneConfig生成的简易配置文件,basic_info.proto件所在路径为:AMCT安装目录/amct_tensorflow/proto/basic_info.proto。 basic_info.proto文件参数解释以及生成的自动通道稀疏搜索配置文件样例请参见自动通道稀疏搜索简易配置文件说明 |
数据类型:string |
input_data |
输入 |
用户提供的校准数据。 |
数据类型:list,内容为对应的feed_dict数据 |
output_cfg |
输入&返回值 |
输出的最终的通道稀疏配置文件路径。 |
数据类型:string |
sensitivity |
输入 |
敏感度计算方法。 |
数据类型:string或 SensitivityBase的子类,string为AMCT已有的方法,目前可选为'TaylorLossSensitivity';子类为SensitivityBase的子类的实例化,可由用户来继承定义。默认为'TaylorLossSensitivity'。 |
search_alg |
输入 |
待稀疏的通道搜索方法。 |
数据类型:string或 SearchChannelBase的子类,string为AMCT已有的方法,目前可选为'GreedySearch';子类为SearchChannelBase的子类的实例化,可由用户来继承定义。默认为'GreedySearch'。 |
返回值说明
无。
函数输出
自动通道稀疏配置文件。
该文件需要传给通道稀疏接口完成后续的业务。
调用示例
1 2 3 4 5 6 7 8 9 10 11 12 13 | import amct_tensorflow as amct #构造feed_dict数据 input_data = np.random.uniform(-10, 10, (2, 14, 14, 64)).astype(np.str_) feed_dict = [{'input:0': input_data}] amct.auto_channel_prune_search( graph=graph, output_nodes=[operation_name_1, operation_name_2], config='./tmp/sample.cfg', input_data=feed_dict, output_cfg='./tmp/output.cfg', sensitivity='TaylorLossSensitivity', search_alg='GreedySearch') |