save
功能描述
mxRec训练模型保存接口。
函数原型
def save(self, sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix="meta", write_meta_graph=True, write_state=True, strip_default_attrs=False, save_debug_info=False)
参数说明
参数名 |
类型 |
可选/必选 |
说明 |
---|---|---|---|
sess |
Session |
必选 |
需要保存模型的Session。 |
save_path |
str |
必选 |
模型保存路径。 |
global_step |
int |
可选 |
在checkpoint文件名补充训练步数,默认值为“None”。 |
latest_filename |
str |
可选 |
protocol buffer文件的可选名称,该文件将包含最新checkpoint列表。 |
meta_graph_suffix |
str |
可选 |
MetaGraphDef文件的后缀,默认为“meta”。 |
write_meta_graph |
bool |
可选 |
是否写入MetaGraph文件,默认为“True”。 |
write_state |
bool |
可选 |
是否写入CheckpointStateProto文件,默认为“True”。 |
strip_default_attrs |
bool |
可选 |
保存模型文件时,是否删除NodeDefs中的默认值属性,默认为“False”。
|
save_debug_info |
bool |
可选 |
是否保存Debug信息,默认为“False”。
|
返回值说明
- 成功:返回“model_checkpoint_path”,即模型保存路径。
- 失败:抛出异常。
使用示例
import tensorflow as tf saver = tf.compat.v1.train.Saver() with tf.compat.v1.Session() as sess: saver.save(sess, f"./saved-model/model-{rank_id}", global_step=0)
参考资源
接口调用流程及示例,参见模型训练。
父主题: 模型接口