文档
注册

save

功能描述

mxRec训练模型保存接口。

函数原型

def save(self, sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix="meta", write_meta_graph=True, write_state=True, strip_default_attrs=False, save_debug_info=False)

参数说明

参数名

类型

可选/必选

说明

sess

Session

必选

需要保存模型的Session。

save_path

str

必选

模型保存路径。

global_step

int

可选

在checkpoint文件名补充训练步数,默认值为“None”

latest_filename

str

可选

protocol buffer文件的可选名称,该文件将包含最新checkpoint列表。

meta_graph_suffix

str

可选

MetaGraphDef文件的后缀,默认为“meta”

write_meta_graph

bool

可选

是否写入MetaGraph文件,默认为“True”

write_state

bool

可选

是否写入CheckpointStateProto文件,默认为“True”

strip_default_attrs

bool

可选

保存模型文件时,是否删除NodeDefs中的默认值属性,默认为“False”

  • 参数值为“True”,则默认值属性将在接口调用时从NodeDefs中删除。
  • 参数值为“False”,则不进行删除操作。

save_debug_info

bool

可选

是否保存Debug信息,默认为“False”

  • 参数值为“True”,则将图形调试信息保存到一个单独的文件中,该文件位于“save_path”对应的目录中,并在生成文件的扩展名之前添加“_debug”。仅当“write_meta_graph”为`“True”`时,此功能才会生效。
  • 参数值为“False”,则不保存Debug信息。

返回值说明

  • 成功:返回“model_checkpoint_path”,即模型保存路径。
  • 失败:抛出异常。

使用示例

import tensorflow as tf
saver = tf.compat.v1.train.Saver()
with tf.compat.v1.Session() as sess:
    saver.save(sess, f"./saved-model/model-{rank_id}", global_step=0)

参考资源

接口调用流程及示例,参见模型训练

搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词