下载
中文
注册

auto_mixed_precision_search

功能说明

根据原始模型和自动混合精度搜索简易配置文件,自动搜索模型的最优混合量化位宽配置,生成量化感知训练简易配置文件。

函数原型

auto_mixed_precision_search(model, input_data, config, save_dir, evaluator, sensitivity='MseSimilarity')

参数说明

参数名

输入/返回值

含义

使用限制

model

输入

PyTorch模型。

数据类型:torch.nn.module

input_data

输入

模型的输入数据。一个torch.tensor会被等价为tuple([torch.tensor])

数据类型:tuple

config

输入

自动混合精度搜索简易配置文件。

基于basic_info.proto文件中的AutoMixedPrecisionConfig生成的简易配置文件,basic_info.proto件所在路径为:AMCT安装目录/amct_pytorch/proto/basic_info.proto。

basic_info.proto文件参数解释以及生成的自动混合精度搜索配置文件样例请参见自动(混合精度或通道稀疏)搜索简易配置文件说明

数据类型:string

save_dir

输入

用于保存生成的QAT简易配置文件的路径。

该路径需要包含模型名前缀。例如:/your/save/path/model_name

数据类型:string

evaluator

输入

进行校准和评估精度的python实例。

数据类型:AutoCalibrationEvaluatorBase派生类。

使用约束:需实现evaluate方法和calibration方法。

sensitivity

输入

评价每一层量化层对于量化敏感度的指标,默认是MSE(Mean Square Error,均方误差)。

数据类型:string或python instance

默认值:MseSimilarity

返回值说明

无。

函数输出

  • *_qat_mixed_precision.json:混合精度配置文件,描述了模型的每一层应该使用哪一种精度。*代表模型名前缀。
  • *_qat_mixed_precision.cfg:用于量化感知训练的混合精度配置文件。
  • 中间件(可选):包括层量化敏感度、计算量文件、量化因子记录文件、量化配置文件以及量化层的输入数据文件。

    上述文件在搜索完成后会被删除,只有设置日志级别环境变量AMCT_LOG_FILE_LEVEL设置为debug模式才会被保留(详情请参见AMCT(PyTorch))。

调用示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
import amct_pytorch as amct
# 建立待自动混合精度搜索的网络图结构
model = build_model()
model.load_state_dict(torch.load(state_dict_path))
input_data = tuple([torch.randn(input_shape)])
# 自动混合精度搜索简易配置文件
config = './amc.cfg'
# 保存结果路径
save_dir = './results/'
# 建立evaluator,有两种方式
# 方式一,自定义evaluator,需继承自amct_pytorch.common.auto_calibration AutoCalibrationEvaluatorBase
evaluator = CustomEvaluator(AutoCalibrationEvaluatorBase) # 实现evaluate和calibration方法
# 方式二,根据data_dir, input_shape, data_types创建evaluator。
evaluator = amct.ModelEvaluator(data_dir, input_shape, data_types)
# sensitivity有两种方式
# 方式一,字符串'MseSimilarity',由AMCT内部实现。
# 方式二,需继承自amct_pytorch.common.auto_calibration SensitivityBase
sensitivity = CustomSensitivity(SensitivityBase) # 实现compare方法
amct.auto_mixed_precision_search(model=model, input_data=input_data, config=config, save_dir=save_dir, evaluator=evaluator, sensitivity='MseSimilarity')