Calibrator
功能说明
量化参数配置类,通过Calibrator类封装量化算法。
函数原型
Calibrator(model, cfg, calib_data=None, fuse_module_call_back=None)
参数说明
参数名 |
输入/返回值 |
含义 |
使用限制 |
---|---|---|---|
model |
输入 |
待量化模型实例。 |
必选。 数据类型:PyTorch模型。 |
cfg |
输入 |
已配置的QuantConfig类。 |
必选。 数据类型:QuantConfig。 |
calib_data |
输入 |
模型训练数据,可输入真实数据用于Label-Free量化,也可输入虚拟数据来实现Label-Free量化。 |
可选。 数据类型:list[list[Torch.Tensor]] 或list[Torch.Tensor]。
|
fuse_module_call_back |
输入 |
BN融合用户自定义函数,在量化前会调用该回调。 |
可选。 数据类型:function。 如果模型结构特殊,不是conv->bn并列结构的,需要用户传入自定义融合函数。 |
调用示例
from msmodelslim.pytorch.quant.ptq_tools import QuantConfig, Calibrator disable_names = [] input_shape = [1, 3, 224, 224] quant_config = QuantConfig(disable_names=disable_names, amp_num=0, input_shape=input_shape) calib_data = [] image = cv2.imdecode(np.fromfile("./random_image.jpg", dtype=np.uint8), 1) image = cv2.resize(image, (224, 224,), interpolation=cv2.INTER_CUBIC) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) image = torch.from_numpy(image).permute(2, 0, 1)/255 image = image.unsqueeze(0) calib_data.append([image]) #传入一张随机图片数据,用于提高精度 calibrator = Calibrator(model, quant_config, calib_data=calib_data)
父主题: 训练后量化(PyTorch)