下载
中文
注册

distill

功能说明

蒸馏接口,将输入的待蒸馏的图结构按照给定的蒸馏量化配置文件进行蒸馏处理,返回修改后的torch.nn.module蒸馏模型。

函数原型

distill_model = distill(model, compress_model, config_file, train_loader, epochs=1, lr=1e-3, sample_instance=None, loss=None, optimizer=None)

参数说明

参数名

输入/返回值

含义

使用限制

model

输入

待进行蒸馏量化的原始浮点模型,已加载权重。

数据类型:torch.nn.module

compress_model

输入

修改后的可用于蒸馏的torch.nn.module模型。

数据类型:torch.nn.module

使用约束:该接口输入的模型必须是量化后的压缩模型

config_file

输入

用户生成的蒸馏量化配置文件,用于指定模型network中量化层的配置情况和蒸馏结构。

数据类型:string

使用约束:该接口输入的config.json必须和create_distill_config接口输入的config.json一致

train_loader

输入

训练数据集。

数据类型:

torch.utils.data.DataLoader

使用约束:必须与模型输入大小匹配。

epochs

输入

最大迭代次数。

默认值:1

数据类型:int

lr

输入

学习率。

默认值:1e-3

数据类型:float

sample_instance

输入

用户提供的获取模型输入数据方法的实例化对象。

默认值:None

数据类型:DistillSampleBase

使用约束:必须继承自DistillSampleBase类,并且实现get_model_input_data方法。可参考AMCT安装目录/amct_pytorch/distill/distll_sample.py文件。

loss

输入

用于计算损失的实例化对象。

默认值:None

数据类型:

torch.nn.modules.loss._Loss

optimizer

输入

优化器的实例化对象。

默认值:None

数据类型:

torch.optim.Optimizer

distill_model

返回值

修改后的torch.nn.module蒸馏模型。

默认值:None

数据类型:torch.nn.module

返回值说明

蒸馏后的模型。

函数输出

无。

调用示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
import amct_pytorch as amct
# 建立待进行蒸馏量化的网络图结构
model = build_model()
model.load_state_dict(torch.load(state_dict_path))
compress_model = compress(model)
input_data = tuple([torch.randn(input_shape)])
train_loader = torch.utils.data.DataLoader(input_data)
loss = torch.nn.MSELoss()
optimizer = torch.optim.AdamW(compress_model.parameters(), lr=0.1)

# 蒸馏
distill_model = amct.distill(
                model,
                compress_model
                config_json_file,
                train_loader,
                epochs=1,
                lr=1e-3,
                sample_instance=None, 
                loss=loss,
                optimizer=optimizer)