distill
功能说明
蒸馏接口,将输入的待蒸馏的图结构按照给定的蒸馏量化配置文件进行蒸馏处理,返回修改后的torch.nn.module蒸馏模型。
函数原型
distill_model = distill(model, compress_model, config_file, train_loader, epochs=1, lr=1e-3, sample_instance=None, loss=None, optimizer=None)
参数说明
参数名 |
输入/返回值 |
含义 |
使用限制 |
---|---|---|---|
model |
输入 |
待进行蒸馏量化的原始浮点模型,已加载权重。 |
数据类型:torch.nn.module |
compress_model |
输入 |
修改后的可用于蒸馏的torch.nn.module模型。 |
数据类型:torch.nn.module 使用约束:该接口输入的模型必须是量化后的压缩模型 |
config_file |
输入 |
用户生成的蒸馏量化配置文件,用于指定模型network中量化层的配置情况和蒸馏结构。 |
数据类型:string 使用约束:该接口输入的config.json必须和create_distill_config接口输入的config.json一致 |
train_loader |
输入 |
训练数据集。 |
数据类型: torch.utils.data.DataLoader 使用约束:必须与模型输入大小匹配。 |
epochs |
输入 |
最大迭代次数。 |
默认值:1 数据类型:int |
lr |
输入 |
学习率。 |
默认值:1e-3 数据类型:float |
sample_instance |
输入 |
用户提供的获取模型输入数据方法的实例化对象。 |
默认值:None 数据类型:DistillSampleBase 使用约束:必须继承自DistillSampleBase类,并且实现get_model_input_data方法。可参考AMCT安装目录/amct_pytorch/distill/distll_sample.py文件。 |
loss |
输入 |
用于计算损失的实例化对象。 |
默认值:None 数据类型: torch.nn.modules.loss._Loss |
optimizer |
输入 |
优化器的实例化对象。 |
默认值:None 数据类型: torch.optim.Optimizer |
distill_model |
返回值 |
修改后的torch.nn.module蒸馏模型。 |
默认值:None 数据类型:torch.nn.module |
返回值说明
蒸馏后的模型。
函数输出
无。
调用示例
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
import amct_pytorch as amct # 建立待进行蒸馏量化的网络图结构 model = build_model() model.load_state_dict(torch.load(state_dict_path)) compress_model = compress(model) input_data = tuple([torch.randn(input_shape)]) train_loader = torch.utils.data.DataLoader(input_data) loss = torch.nn.MSELoss() optimizer = torch.optim.AdamW(compress_model.parameters(), lr=0.1) # 蒸馏 distill_model = amct.distill( model, compress_model config_json_file, train_loader, epochs=1, lr=1e-3, sample_instance=None, loss=loss, optimizer=optimizer) |