文档
注册

模型前向传播不是单个tensor时,如何进行特殊配置

剪枝过程需要通过前向传播获取相关信息,默认模型输入是一个tensor。实际模型可能是多个变量,或一个dict,如下代码所示。

# 模型forward有多个输入
def forward(self, x, y, z):
    pass
# 模型forward输入是一个dict
def forward(self, {'x':x, 'y':y, 'z':z}):
    pass

这种情况下,需要通过torch.save(data, save_path)存下[x, y, z]列表或{'x':x, 'y':y, 'z':z}字典(batchsize为2),存为.pkl后缀的文件。该pickle文件地址记为pkl_data_path。

  • search_algorithm.input_pickle_path:配置为pkl_data_path。
  • trainer.input_pickle_path(可选): 当待训练模型需要加载预训练模型权重,且预训练模型结构与待训练模型结构不同时,需要配置为pkl_data_path。
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词