文档
注册

add_blocks_params

功能说明

PruneConfig类方法,根据自定义参数配置模型剪枝的block,若set_steps选择的步骤包含“prune_blocks”,则需要调用该方法。

函数原型

add_blocks_params(pattern, layer_id_map)

参数说明

参数名

输入/返回值

含义

使用限制

pattern

输入

待剪枝网络layer名称的正则表达式。

必选。

数据类型:string。

例如取值为bert.encoder.layer.(d+)时,表示选取网络中以bert.encoder.layer开头,且后续为数字的网络layer。

layer_id_map

输入

待剪枝网络layer的前后id匹配关系。

必选。

数据类型:dict,key和value的数据类型均为int。

例如,取值为{0: 0, 1: 2, 2: 4}时表示将bert.encoder.layer.0的权重保留至bert.encoder.layer.0,bert.encoder.layer.2的权重保留至bert.encoder.layer.1,bert.encoder.layer.4的权重保留至bert.encoder.layer.2,即预训练权重中bert.encoder.layer.x共有5层,而输入的模型中bert.encoder.layer.x只有3层,通过layer_id_map在剪枝时将权重保留到指定的位置。

调用示例

from msmodelslim.common.prune.transformer_prune.prune_model import PruneConfig
prune_config = PruneConfig()
prune_config.set_steps(['prune_blocks']). \
  add_blocks_params('uniter\.encoder\.encoder\.blocks\.(\d+)\.', {0: 1, 1: 3, 2: 5, 3: 7, 4: 9, 5: 11})
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词