add_output_soft_label
功能说明
KnowledgeDistillConfig类方法,配置蒸馏的soft label,即student模型和teacher模型的soft label的映射关系,专用于模型的最后一层,非必须调用的方法。
函数原型
add_output_soft_label(config)
参数说明
参数名 |
输入/返回值 |
配置项 |
含义 |
使用限制 |
---|---|---|---|---|
config |
输入 |
t_output_idx |
用于配置t_module输出的index。 若t_module存在多个输出,需要使用该参数指定用于计算loss的输出。若只有一个输出,使用0即可。 |
必选。 数据类型:int。 可以为None。 |
s_output_idx |
用于配置s_module输出的index。 若s_module存在多个输出,需要使用该参数指定用于计算loss的输出。若只有一个输出,使用0即可。 |
必选。 数据类型:int。 可以为None。 |
||
loss_func |
用于指定t_module 与s_module 的loss function,每一个loss function作为一个字典存入该list中,字典内部包含如下字段:
|
必选。 数据类型:list。 字典内参数:
|
调用示例
from msmodelslim.common.knowledge_distill.knowledge_distill import KnowledgeDistillConfig distill_config = KnowledgeDistillConfig() distill_config.set_hard_label (0.5, 0) \ .add_output_soft_label({ 't_output_idx': 0, 's_output_idx': 0, "loss_func": [{"func_name": "KDCrossEntropy", "func_weight": 1}] })