下载
中文
注册
仅在某个Device上保存Checkpoint数据

仅在某个Device上保存Checkpoint数据

2022/09/20

81

暂无评分
我要评分

问题信息

问题来源产品大类产品子类关键字
官方模型训练TensorFlow分布式训练、checkpoint数据

问题现象描述

分布式训练场景下,如果用户只希望在某个device上保存checkpoint数据,而不希望在其他device上保存checkpoint数据。

原因分析

修改训练脚本。

解决措施

按照如下方法修改训练脚本:

TensorFlow原始代码:

self._classifier=tf.estimator.Estimator(
  model_fn=cnn_model_fn,
  model_dir=self._model_dir,
  config=tf.estimator.RunConfig(
      save_checkpoints_steps=50 if hvd.rank() == 0 else None,
      keep_checkpoint_max=1))

迁移后的代码:

self._classifier=NPUEstimator(
  model_fn=cnn_model_fn,
  model_dir=self._model_dir,
  config=tf.estimator.NPURunConfig(
      save_checkpoints_steps=50 if get_rank_id() == 0 else 0,
      keep_checkpoint_max=1))

本页内容