tf.compat.v1.train.Saver.restore
功能描述
TensorFlow用于模型加载的接口。
函数原型
def restore(self, sess, save_path)
参数说明
参数名 |
类型 |
可选/必选 |
说明 |
---|---|---|---|
sess |
Session |
必选 |
需要导入模型TensorFlow的Session。 |
save_path |
str |
必选 |
说明:
当前加载文件单个大小上限为500G,并发读取可能会引发系统OOM。 |
返回值说明
- 成功:None。
- 失败:抛出异常。
使用示例
具体使用方法可参考Rec SDK代码仓中的little demo,以下仅提供一个使用的流程示例。
# 1、导入需要的库 import tensorflow as tf from mx_rec.util.initialize import init, get_rank_id # 2、构建计算图 # ... # 3、创建saver saver = tf.compat.v1.train.Saver() # 4、获取rank_id rank_id = get_rank_id() # 5、设置需要加载的模型保存时的训练步数,比如: latest_step = 200 with tf.compat.v1.Session() as sess: saver.restore(sess, f"./saved-model/model-{rank_id}-{latest_step}")
参考资源
接口调用流程及示例,参见模型迁移与训练。
父主题: TensorFlow相关接口