下载
中文
注册

auto_channel_prune_search

功能说明

自动通道稀疏接口,根据用户模型来计算各通道的稀疏敏感度(影响精度)以及稀疏收益(影响性能),然后搜索策略依据该输入来搜索最优的逐层通道稀疏率,以平衡精度和性能。最终输出一个配置文件。

约束说明

函数原型

auto_channel_prune_search(model, config, input_data, output_cfg, sensitivity, search_alg)

参数说明

参数名

输入/返回值

含义

使用限制

model

输入

待稀疏的PyTorch模型。

数据类型:torch.nn.Module

config

输入

自动通道稀疏配置文件路径。

基于basic_info.proto文件中的AutoChannelPruneConfig生成的简易配置文件,basic_info.proto件所在路径为:AMCT安装目录/amct_pytorch/proto/basic_info.proto。

basic_info.proto文件参数解释以及生成的自动通道稀疏搜索配置文件样例请参见自动(混合精度或通道稀疏)搜索简易配置文件说明

数据类型:string

input_data

输入

用户提供获取输入数据(含label)

数据类型:list[data,label]

列表元素数据类型为torch.tensor

output_cfg

输入&返回值

输出的最终的通道稀疏配置文件路径。

数据类型:string

sensitivity

输入

敏感度计算方法。

数据类型:string或

SensitivityBase的子类,string为amct已有的方法,目前可选为'TaylorLossSensitivity';SensitivityBase的子类实例化,可由用户来继承定义

search_alg

输入

待稀疏的通道搜索方法。

数据类型:string或

SearchChannelBase的子类,string为amct已有的方法,目前可选为'GreedySearch';SearchChannelBase的子类实例化,可由用户来继承定义

返回值说明

无。

函数输出

自动通道稀疏配置文件。

该文件需要传给通道稀疏接口完成后续的业务。

调用示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
import amct_pytorch as amct
#构造输入数据input_data
input_data = torch.randn(input_shape)         
model.eval()        
output = model.forward(input_data)        
labels = torch.randn(output.size())        
data = [input_data,labels]

amct.auto_channel_prune_search(
     model=model,
     config='./tmp/sample.cfg',
     input_data=data,  
     output_cfg='./tmp/output.cfg', 
     sensitivity='TaylorLossSensitivity', 
     search_alg='GreedySearch')