Enabling Distributed Training for Estimator Scripts

If the user's original Estimator script is already a distributed script, the migration tool can migrate it directly into a distributed script. However, if the original script is a single-device (single-P) script, the migrated script cannot run distributed training as-is; the user can enable distributed training with a small number of manual changes to the migrated script.

Single-P script produced by the migration tool:

def cnn_model_fn(features, labels, mode):
  # Build the network
  xxx
  # Compute the loss
  xxx
  # Configure the training op (for TRAIN mode)
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)  # SGD optimizer
    train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())  # minimize the loss
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
...
hook = hk._LoggerHook(FLAGS)
training_hooks = []
training_hooks.append(hook)
...
estimator.train(train_data_fn, max_steps=num_steps, hooks=training_hooks)

Manually modified to support distributed training (method one):

def cnn_model_fn(features, labels, mode):
  # Build the network
  xxx
  # Compute the loss
  xxx
  # Configure the training op (for TRAIN mode)
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
    optimizer = npu_distributed_optimizer_wrapper(optimizer)  # aggregate gradients across devices
    train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
...
hook = hk._LoggerHook(FLAGS)
training_hooks = []
training_hooks.append(hook)
training_hooks.append(NPUBroadcastGlobalVariablesHook(0, int(os.getenv('RANK_ID', '0'))))  # broadcast initial variables from rank 0
...
estimator.train(train_data_fn, max_steps=num_steps // rank_size, hooks=training_hooks)
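In method one, the total number of training steps is split evenly across devices. A minimal pure-Python sketch of that calculation (the helper name `per_device_steps` is hypothetical; it assumes the device count is exposed through a `RANK_SIZE` environment variable, as in the script above for `RANK_ID`):

```python
import os

def per_device_steps(num_steps, default_rank_size=1):
    """Divide the total training steps across devices.

    Reads the device count from the RANK_SIZE environment variable,
    falling back to single-device training when it is unset.
    """
    rank_size = int(os.getenv('RANK_SIZE', str(default_rank_size)))
    return num_steps // rank_size

# Example: 10000 total steps across 8 devices -> 1250 steps per device.
os.environ['RANK_SIZE'] = '8'
print(per_device_steps(10000))  # 1250
```

With data parallelism, each device processes a different shard of each global batch, so dividing `max_steps` by the device count keeps the total amount of data seen by the job roughly constant.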

Manually modified to support distributed training (method two):

def cnn_model_fn(features, labels, mode):
  # Build the network
  xxx
  # Compute the loss
  xxx
  # Configure the training op (for TRAIN mode)
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
    optimizer = npu_distributed_optimizer_wrapper(optimizer)  # aggregate gradients across devices
    train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
...
hook = hk._LoggerHook(FLAGS)
training_hooks = []
training_hooks.append(hook)
...
estimator.train(train_data_fn, max_steps=num_steps // rank_size, hooks=npu_hooks_append(training_hooks))  # npu_hooks_append adds the variable broadcast hook
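The difference between the two methods is who appends the broadcast hook: in method one the user appends `NPUBroadcastGlobalVariablesHook` explicitly, while in method two `npu_hooks_append` does it on the user's behalf. Conceptually, `npu_hooks_append` behaves roughly like the following sketch (the `_BroadcastHookStub` class and `npu_hooks_append_sketch` function are hypothetical stand-ins, not the real NPU API):

```python
import os

class _BroadcastHookStub:
    """Hypothetical stand-in for NPUBroadcastGlobalVariablesHook:
    records which rank is the broadcast source and this worker's rank."""
    def __init__(self, root_rank, rank_id):
        self.root_rank = root_rank
        self.rank_id = rank_id

def npu_hooks_append_sketch(hooks_list=None):
    """Roughly what npu_hooks_append does: return the hook list with a
    variable-broadcast hook appended, using rank 0 as the source so all
    workers start from identical initial weights."""
    hooks = list(hooks_list) if hooks_list else []
    hooks.append(_BroadcastHookStub(0, int(os.getenv('RANK_ID', '0'))))
    return hooks

hooks = npu_hooks_append_sketch([])
print(len(hooks))  # 1
```

Because the broadcast is handled inside the call at `estimator.train`, method two leaves `cnn_model_fn`'s hook setup untouched, which keeps the manual diff smaller.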