下载
中文
注册

quantize_model

功能说明

训练后量化接口,根据用户设置的量化配置文件对网络结构进行改图处理,插入权重量化和数据量化等相关算子,然后返回修改后的网络。

函数原型

network = quantize_model(config_file, network, *input_data)

参数说明

参数名

输入/返回值

含义

使用限制

config_file

输入

用户生成的量化配置文件,用于指定模型network中量化层的配置情况。

数据类型:string

network

输入

MindSpore的网络模型定义。

数据类型:MindSpore的Cell对象

input_data

输入

用户网络输入数据(仅要求数据的format和shape正确,数据本身可以是随机生成的)。

数据类型:可以转化为MindSpore Tensor的对象,比如numpy.ndarray对象。

该参数为可变参数,支持用户网络有多个输入的情况。

network

返回值

修改后的network。

数据类型:MindSpore的Cell对象

返回值说明

返回修改后的network,返回的network需要进行推理获取量化因子。

调用示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
import amct_mindspore as amct
# 建立待量化的网络结构
network = build_network()
# 生成随机输入数据,数据format和type需要和网络输入要求一致
input_data = np.random.uniform(0.0, 1.0, size=[32, 1, 32, 32]).astype(np.float32)
# 插入量化API
calibration_network = amct.quantize_model(config_file="./configs/config.json", network, input_data)
# 对修改后的模型进行一个batch的推理
model =Model(calibration_network, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
model.eval(dataset, dataset_sink_mode=False)