下载
中文
注册

tf.compat.v1.train.Saver.restore

功能描述

TensorFlow用于模型加载的接口。

函数原型

def restore(self, sess, save_path)

参数说明

参数名

类型

可选/必选

说明

sess

Session

必选

需要导入模型TensorFlow的Session。

save_path

str

必选

  • 模型checkpoint文件的保存路径。
  • 支持本地文件系统和HDFS文件系统,长度范围为[1,150]。
  • 在使用多卡训练加载模型时,多卡save_path可以输入同一加载路径(该路径下保存了多卡训练的结果),各卡会自动加载属于本卡的参数。
说明:

当前加载文件单个大小上限为500G,并发读取可能会引发系统OOM。

返回值说明

  • 成功:None。
  • 失败:抛出异常。

使用示例

具体使用方法可参考Rec SDK代码仓中的little demo,以下仅提供一个使用的流程示例。

# 1、导入需要的库
import tensorflow as tf
from mx_rec.util.initialize import init, get_rank_id
# 2、构建计算图
# ...
# 3、创建saver
saver = tf.compat.v1.train.Saver() 
# 4、获取rank_id
rank_id = get_rank_id()
# 5、设置需要加载的模型保存时的训练步数,比如:
latest_step = 200
with tf.compat.v1.Session() as sess:
    saver.restore(sess, f"./saved-model/model-{rank_id}-{latest_step}")

参考资源

接口调用流程及示例,参见模型迁移与训练