get_distill_model
功能说明
模型蒸馏接口,将用户提供教师模型、学生模型根据蒸馏配置进行组合,返回一个DistillDualModels实例,用户对DistillDualModels 实例进行训练。
由于PyTorch、MindSpore下蒸馏实现存在差异,对DistillDualModels实例的使用也存在如下区别。
- PyTorch下,DistillDualModels实例前向传播后返回三个数据,分别为soft label计算得到的loss、student模型的原始输出、teacher模型的原始输出。若需要获取hard lable的loss,需用户自行根据student模型的原始输出计算,并调用DistillDualModels实例的get_total_loss()方法,获取soft label和hard label的综合loss。
- MindSpore下会自动计算所有loss,无需手动计算hard label。
函数原型
get_distill_model(teacher, student, config)
参数说明
参数名 |
输入/返回值 |
含义 |
使用限制 |
---|---|---|---|
teacher |
输入 |
教师模型。 |
必选。 数据类型:MindSpore模型或PyTorch模型。 |
student |
输入 |
学生模型。 |
必选。 数据类型:MindSpore模型或PyTorch模型。 |
config |
输入 |
蒸馏的配置。 |
必选。 数据类型:KnowledgeDistillConfig对象。 |
调用示例
from msmodelslim.common.knowledge_distill.knowledge_distill import KnowledgeDistillConfig, get_distill_model #定义配置 distill_config = KnowledgeDistillConfig() distill_config. set_hard_label (0.5, 0) \ .add_inter_soft_label({ 't_module': 'uniter.encoder.encoder.blocks.11.output', 's_module': 'uniter.encoder.encoder.blocks.5.output', 't_output_idx': 0, 's_output_idx': 0, "loss_func": [{"func_name": "KDCrossEntropy", "func_weight": 1}], 'shape': [2048] }) #传入参数,返回蒸馏模型 distill_model = get_distill_model(teacher_model, student_model, distill_config)
父主题: 蒸馏接口