接口调用流程
接口调用流程如图1所示,蓝色部分为用户实现,灰色部分为用户调用AMCT提供的API实现。稀疏示例请参见获取更多样例。
用户准备好TensorFlow的训练模型、自动通道稀疏搜索配置文件和校准数据,调用auto_channel_prune_search,根据压缩率、各通道的稀疏敏感度以及稀疏收益,执行自动通道稀疏搜索,得到可用作通道稀疏的简易配置文件。其中,sensitivity模块与search_alg模块用户可以自定义或者使用接口内部默认方法。
- sensitivity模块实现计算各通道的稀疏敏感度。
- search_alg模块实现了基于通道敏感度与通道稀疏收益进行稀疏通道搜索的过程。
图1 调用流程
调用示例
本示例演示了使用AMCT进行自动通道稀疏搜索的流程,该过程需要用户传入tensorflow训练模式的图与校准数据,用户可选择自定义实现sensitivity模块与search_alg模块。
- 如下示例标有“由用户补充处理”的步骤,需要用户根据自己的模型和数据集进行补充处理,示例中仅为示例代码。
- 如下示例调用AMCT的部分,函数入参请根据实际情况进行调整。
- 导入AMCT包,设置日志级别。
|
import amct_tensorflow as amct
amct.set_logging_level(print_level="info", save_level="info")
|
- (可选,由用户补充处理)实现sensitivity模块,获取各个layer各通道的敏感度,为后续的搜索算法提供数据。可参考系统默认sensitivity模块:AMCT安装目录
amct_tensorflow/interface/auto_channel_prune_search.py下的TaylorLossSensitivity方法,简要流程如下:
1
2
3
4
5
6
7
8
9
10
11
12
|
from amct.common.auto_prune.sensitivity_base import SensitivityBase
class Sensitivity(SensitivityBase)
def __init__(self)
pass
def setup_initialization(self, graph_tuple, input_data, test_iteration, output_nodes=None):
# 必要的初始化
# graph_tuple (graph, graph_info)
pass
def get_sensitivity(self, search_records):
# 获取敏感度方法,计算后写到record中
pass
|
- (可选,由用户补充处理)实现搜索算法search_alg模块,需要用户实现channel_prune_search内部回调接口,根据通道敏感度与通道稀疏收益进行稀疏通道搜索。可参考系统默认search_alg模块:AMCT安装目录/amct_tensorflow/common/auto_prune/search_channel_base.py文件中GreedySearch方法。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
|
from amct.common.auto_prune.search_channel_base import SearchChannelBase
class Search(SearchChannelBase)
def __init__(self)
# 初始化
pass
def channel_prune_search(self, graph_info, search_records, prune_config):
"""
输入:
graph_info: dict,包含图中各算子的通道数量与比特复杂度信息,可用于计算压缩率
search_records: protobuf对象,包含待搜索的可稀疏层
prune_config: 三元组-目标压缩率(float)、昇腾亲和优化开关(bool)、单层最大稀疏率(float)
输出:
dict,key为待搜索的可稀疏层层名,value为01组成的list,对应该通道是否应稀疏
"""
pass
|
- (可选,由用户补充处理)创建图并读取训练好的参数,在TensorFlow环境下推理,验证环境、推理脚本是否正常。
推荐执行该步骤,以确保原始模型可以完成推理且精度正常;执行该步骤时,可以使用部分测试集,减少运行时间。
|
user_test_evaluate_model(evaluate_model, test_data)
|
- (由用户补充处理)创建训练图,构造校准数据。
|
train_graph = user_load_train_graph()
input_data = []
for _ in range(test_iteration):
input_data.append(user_load_feed_dict())
|
- 调用AMCT,进行自动通道稀疏搜索。
|
output_prune_cfg = './prune.cfg'
amct.auto_channel_prune_search(
graph=train_graph,
output_nodes=user_model_outputs,
config=cfg_file,
input_data=input_data,
output_cfg=output_prune_cfg,
sensitivity='TaylorLossSensitivity',
search_alg='GreedySearch')
|