文档
注册

跨平台保存模型

PyTorch在训练过程中,通常使用torch.save()来保存Checkpoint文件,为了支持NPU训练出的模型权重或模型可以跨平台使用,需要在模型存储前将模型或tensor放在CPU上进行存储,示例如下:

# 将模型放置在cpu上 
model = model.cpu()      
# 创建保存路径      
PATH = "state_dict_model.pt"      
# 保存模型      
torch.save(model.state_dict(), PATH)

PyTorch2.1.0及以后版本已支持跨设备读取权重,不需要模型或tensor放在CPU上进行存储。

搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词