quantize_model
功能说明
训练后量化接口,根据用户设置的量化配置文件对网络结构进行改图处理,插入权重量化和数据量化等相关算子,然后返回修改后的网络。
函数原型
network = quantize_model(config_file, network, *input_data)
参数说明
参数名 |
输入/返回值 |
含义 |
使用限制 |
---|---|---|---|
config_file |
输入 |
用户生成的量化配置文件,用于指定模型network中量化层的配置情况。 |
数据类型:string |
network |
输入 |
MindSpore的网络模型定义。 |
数据类型:MindSpore的Cell对象 |
input_data |
输入 |
用户网络输入数据(仅要求数据的format和shape正确,数据本身可以是随机生成的)。 |
数据类型:可以转化为MindSpore Tensor的对象,比如numpy.ndarray对象。 该参数为可变参数,支持用户网络有多个输入的情况。 |
network |
返回值 |
修改后的network。 |
数据类型:MindSpore的Cell对象 |
返回值说明
返回修改后的network,返回的network需要进行推理获取量化因子。
调用示例
1 2 3 4 5 6 7 8 9 10 |
import amct_mindspore as amct # 建立待量化的网络结构 network = build_network() # 生成随机输入数据,数据format和type需要和网络输入要求一致 input_data = np.random.uniform(0.0, 1.0, size=[32, 1, 32, 32]).astype(np.float32) # 插入量化API calibration_network = amct.quantize_model(config_file="./configs/config.json", network, input_data) # 对修改后的模型进行一个batch的推理 model =Model(calibration_network, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'}) model.eval(dataset, dataset_sink_mode=False) |
父主题: 训练后量化接口