freeze_modules
freeze_modules(model, include, exclude)
根据指定模块列表冻结网络。需在定义优化器之前调用。
参数
- model(nn.Cell) - 需要冻结的模型实例。
- include(Optional[List[str]]) - 需要冻结的模块名列表,默认值为None。
- 模糊匹配列表中所有模块名,挨个将匹配到的模块的“requires_grad”设置为“False”。
- 列表项支持配置符号*,代表任意字符串,格式如 ['*', '*dense*','*.dense.*', '*.dense.*.bias']。
- 如果不配置符号*,仅传字符串,表示精确匹配。
- exclude(Optional[List[str]]) - 不冻结的模块名列表,默认值为None。
- 模糊匹配列表中所有模块名,挨个将匹配到的模块的“requires_grad”设置为“True”。
- 列表项支持配置符号*,代表任意字符串,格式如 ['*', '*dense*','*.dense.*', '*.dense.*.bias']。
- 如果不配置符号*,仅传字符串,表示精确匹配。
- 当“include”和“exclude”列表项冲突时,对该项匹配到的模块不做任何处理。
异常
- TypeError - model参数类型不是nn.Cell。
- ValueError - “include”和“exclude”参数同时为空。
- TypeError - “include”或“exclude”参数不是非空列表。
样例
from tk.graph.freeze_utils import freeze_modules # 初始化网络结构 model = Network() # 根据指定模块列表冻结指定模块 freeze_modules(model,include=['*embedding*', 'transformer*', 'dense.weight'], exclude=['transformer.encoder.blocks.*.layernorm*']) # 定义优化器 ...
父主题: API接口