跨平台保存模型
PyTorch在训练过程中,通常使用torch.save()来保存Checkpoint文件,为了支持NPU训练出的模型权重或模型可以跨平台使用,需要在模型存储前将模型或tensor放在CPU上进行存储,示例如下:
# 将模型放置在cpu上 model = model.cpu() # 创建保存路径 PATH = "state_dict_model.pt" # 保存模型 torch.save(model.state_dict(), PATH)
PyTorch2.1.0及以后版本已支持跨设备读取权重,不需要模型或tensor放在CPU上进行存储。
父主题: 模型保存与导出