加载权重时遇到报错“load state_dict error.”
问题描述
问题分析
模型训练后保存的state_dict的key值与加载时state_dict的key值不一致,保存时会在每个key的最前面增加一个module前缀。
处理方法
加载权重时先遍历state_dict字典,修改key值,并使用新建的字典。具体用例参考以下内容:
ckpt = torch.load("checkpoint.pth", map_location=loc) # model.load_state_dict(ckpt['state_dict']) state_dict_old = ckpt['state_dict'] state_dict = {} for key, value in state_dict_old.items(): key = key[7:] state_dict[key] = value model.load_state_dict(state_dict)
父主题: 其他操作相关问题