文档
注册

freeze_modules

freeze_modules(model, include, exclude) 

根据指定模块列表冻结网络。需在定义优化器之前调用。

参数

  • model(nn.Cell) - 需要冻结的模型实例。
  • include(Optional[List[str]]) - 需要冻结的模块名列表,默认值为None。
    • 模糊匹配列表中所有模块名,挨个将匹配到的模块的“requires_grad”设置为“False”
    • 列表项支持配置符号*,代表任意字符串,格式如 ['*', '*dense*','*.dense.*', '*.dense.*.bias']。
    • 如果不配置符号*,仅传字符串,表示精确匹配。
  • exclude(Optional[List[str]]) - 不冻结的模块名列表,默认值为None。
    • 模糊匹配列表中所有模块名,挨个将匹配到的模块的“requires_grad”设置为“True”
    • 列表项支持配置符号*,代表任意字符串,格式如 ['*', '*dense*','*.dense.*', '*.dense.*.bias']。
    • 如果不配置符号*,仅传字符串,表示精确匹配。
    • “include”“exclude”列表项冲突时,对该项匹配到的模块不做任何处理。

异常

  • TypeError - model参数类型不是nn.Cell。
  • ValueError - “include”“exclude”参数同时为空。
  • TypeError - “include”“exclude”参数不是非空列表。

样例

from tk.graph.freeze_utils import freeze_modules

# 初始化网络结构
model = Network()

# 根据指定模块列表冻结指定模块
freeze_modules(model,include=['*embedding*', 'transformer*', 'dense.weight'], exclude=['transformer.encoder.blocks.*.layernorm*'])

# 定义优化器
... 
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词