文档
注册

保存模型

PyTorch在训练过程中,通常使用torch.save()来保存Checkpoint文件,根据模型文件的后续用途会保存为两种格式的模型文件(pth文件和pth.tar文件),以便用于在线推理。

  • 在昇腾PyTorch1.11.0版本中,NPU模型在使用torch.save()进行存储的时候会保存NPU特有的设备信息和数据格式,以便于更好的支持断点训练,这使得保存的pth、pt和pth.tar扩展名文件存在跨平台兼容性问题。为了支持NPU训练出的模型权重或模型可以跨平台使用,需要在模型存储前将模型或tensor放在CPU上进行存储,示例如下:
    # 将模型放置在cpu上 
    model = model.cpu()  
  • PyTorch2.1.0及以后版本已支持跨设备读取权重,不需要模型或tensor放在CPU上进行存储。
  • .pth或.pt扩展名的文件:用于在线推理或导出ONNX格式模型。仅保存模型参数,不保存模型结构,以便压缩文件的体积,可以用Netron等可视化工具打开,样例如图1所示。
    图1 .pth文件

    通过state_dict来保存和加载模型。

    保存模型,示例如下:
    # 创建保存路径
    PATH = "state_dict_model.pt"
    # 保存模型
    torch.save(net.state_dict(), PATH)
    加载模型以用于在线推理,示例如下:
    # 模型文件保存路径
    PATH = "state_dict_model.pt"
    model = TheModelClass(*args, **kwargs)        #根据实际模型定义填写函数和参数
    # 加载模型
    model.load_state_dict(torch.load(PATH))
    model.eval()

    保存为后缀是.pth/.pt的文件时,需要提供模型定义文件,否则后续模型无法部署。

  • .pth.tar扩展名的文件:可用于在线推理或重新加载后继续训练。文件中保存了多个组件,以字典形式保存。常见的组件包括模型和优化器的state_dict、停止时的epoch、最新记录的训练损失以及外部的torch.nn.Embedding层等。如果仅用于部署推理模型,推荐只在.pth.tar扩展名的文件中保存权重信息即模型的state_dict。
    保存模型,示例如下:
    PATH = "checkpoint.pth.tar"
    torch.save({
        'epoch': epoch,
        'loss': loss,
        'state_dict': model.state_dict(),
        'optimizer' : optimizer.state_dict(),
        ...
    }, PATH)
    加载模型用于推理或恢复训练,示例如下:
    model = TheModelClass(*args, **kwargs)            #根据实际模型定义填写函数和参数
    optimizer = TheOptimizerClass(*args, **kwargs)    #根据实际优化器填写函数和参数
    
    checkpoint = torch.load(PATH)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    
    model.eval()
    # - or -
    model.train()

    通常情况下,训练图和推理图中对同一个算子处理方式不同(例如BatchNorm和dropout等算子),在输入格式上也有差别。因此在运行推理或导出ONNX模型之前,必须调用model.eval()来将dropout和batch normalization层设置为推理模式。

搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词