网络冻结样例代码
如果用户需要通过模型配置文件使用网络冻结能力,请按以下步骤适配:
- 在模型启动脚本中定义参数--advanced_config,用于接收用户传入微调工具包的模型配置文件本地路径。
import argparse parser = argparse.ArgumentParser() parser.add_argument('-ac', '--advanced_config', type=str, required=True) # 参数--advanced_config会自动接收到模型配置文件的本地路径 args = parser.parse_args()
- 参考freeze_from_config小节中的freeze_from_config接口描述,定义yaml模型配置文件,并下游任务训练脚本中调用freeze_from_config接口。
注意接口中config_path参数需配置--advanced_config参数值。
from tk.graph.freeze_utils import freeze_from_config # 初始化网络结构 model = Network() # 冻结指定模块 freeze_from_config(model, config_path=args.advanced_config) # 定义优化器 ...