文档
注册

AntiOutlier

功能说明

构建用于异常值的类,并将模型,异常值抑制config,校准数据等传入。

函数原型

AntiOutlier(model, calib_data, cfg: Config, dag=None, logger=None, model_type=None)

参数说明

参数名

输入/返回值

含义

使用限制

model

输入

用于大模型离群值抑制的模型。

必选。

数据类型:PyTorch模型。

calib_data

输入

用于离群值抑制的校准数据。

必选。

数据类型:object。

默认值为None。

输入模板:[[input1],[input2],[input3]]。

cfg

输入

已配置的AntiOutlierConfig类。

可选。

数据类型:Config。

dag

输入

模型图。

可选。

默认为None,采用默认配置即可。

logger

输入

Logger对象。

可选。

数据类型:object。

默认值为None,采用默认配置即可。

model_type

输入

模型类型。

可选。

数据类型:object。

默认值为None。

  • LLaMA类模型传入'Llama'。
  • 非LLaMa模型传''。

调用示例

from modelslim.pytorch.llm_ptq.anti_outlier import AntiOutlier, AntiOutlierConfig
anti_config = AntiOutlierConfig(anti_method="m2")
anti_outlier = AntiOutlier(model, calib_data=dataset_calib, cfg=anti_config, model_type='Llama')
anti_outlier.process() 
calibrator = Calibrator(model, quant_config, calib_data=dataset_calib, disable_level='L0') 
calibrator.run(int_infer=False) 
calibrator.save(qaunt_weight_save_path)
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词