下载
中文
注册

Calibrator

功能说明

量化参数配置类,通过Calibrator类封装量化算法。

函数原型

Calibrator(model, cfg, calib_data=None, fuse_module_call_back=None)

参数说明

参数名

输入/返回值

含义

使用限制

model

输入

待量化模型实例。

必选。

数据类型:PyTorch模型。

cfg

输入

已配置的QuantConfig类。

必选。

数据类型:QuantConfig。

calib_data

输入

模型训练数据,可输入真实数据用于Label-Free量化,也可输入虚拟数据来实现Data-Free量化。

可选。

数据类型:list[list[Torch.Tensor]] 或list[Torch.Tensor]。

  • 如果不输入数据,在模型支持单个float格式输入且指定了input_shape时,会自动调用Data-Free量化流程。
  • 针对多个输入或者需要自定义输入格式的模型,用户可随机构造输入数据来实现Data-Free量化。

fuse_module_call_back

输入

BN融合用户自定义函数,在量化前会调用该回调。

可选。

数据类型:function。

如果模型结构特殊,不是conv->bn并列结构的,需要用户传入自定义融合函数。

调用示例

from modelslim.pytorch.quant.ptq_tools import QuantConfig, Calibrator
disable_names = []
input_shape = [1, 3, 224, 224]
quant_config = QuantConfig(disable_names=disable_names, amp_num=0, input_shape=input_shape)
calib_data = []
image = cv2.imdecode(np.fromfile("./random_image.jpg", dtype=np.uint8), 1)
image = cv2.resize(image, (224, 224,), interpolation=cv2.INTER_CUBIC)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 
image = torch.from_numpy(image).permute(2, 0, 1)/255
image = image.unsqueeze(0)
calib_data.append([image])     #传入一张随机图片数据,用于提高精度
calibrator = Calibrator(model, quant_config, calib_data=calib_data)