下载
中文
注册

分布式训练脚本适配(兼容单卡)

NPU上的分布式部署形态如下图所示,每个Tensorflow进程只管理独享的一张NPU训练卡,多个Tensorflow进程间,通过CANN提供的集合通信接口进行集群同步。单独观察某个worker,可以发现其与NPU上的单卡训练,除额外进行了集群内的集合通信外完全一致。

TF Adapter适配时,将单卡NPU视作集群worker数量为1的分布式部署形态,因而NPU的单卡脚本和分布式脚本最终是一致的。

NPU上执行分布式,相较于单卡NPU训练,主要有三部分的额外适配工作:

  1. worker间变量初值同步

    TF2 Eager模式下,变量在模型生成后即完成初始化,此时需要进行变量初值同步操作,使各个worker上的变量初值一致。

    在模型构建完成后,您应当调用npu.distribute.broadcast接口完成变量初值同步,该接口要求传入需要进行worker间值同步的变量,通常,您可以通过model.trainable_variables来获取全部需要同步的变量。

  2. worker间梯度聚合

    执行训练时,不同worker上产生不同的梯度信息grads,通过对多个worker上的梯度进行聚合计算,可以更准确地评估当前训练的误差情况。

    • 当原始脚本中分步骤计算并更新梯度(例如tf.gradient和opt.apply_gradient)时,则需要调用npu.distribute.all_reduce接口完成梯度聚合运算,该接口要求您传入需要进行worker间聚合计算的梯度以及聚合运算的类型(通常是求平均值)。
    • 当原始脚本中计算和更新梯度操作被集成到同一接口中(例如minimize/model.fit)时,则需要调用npu.distribute.npu_distributed_keras_optimizer_wrapper完成梯度聚合运算。
  3. 不同worker上的数据集分片

    分布式训练时,应当保证每个worker上评估的样本不同,这样才能使得训练结果更符合样本集真实分布,比如您在一个8卡NPU的集群中执行训练,此时一个典型的策略就是第一张NPU卡上训练0-1/8的数据,第二张NPU卡训练1/8-2/8的数据,最后一张卡上训练7/8-8/8的数据。

    • 当数据集为tf.data.Dataset格式时,TF Adapter提供了npu.distribute.shard_and_rebatch_dataset接口帮您实现上述切分动作,该接口要求您传入需要进行集群切分的Dataset(Dataset介绍参考链接)以及集群训练时的全局batch大小,例如:
       1
       2
       3
       4
       5
       6
       7
       8
       9
      10
      11
       # 由于需要使用npu.distribue.shard_and_rebatch接口,在脚本开头import npu
       import npu_device as npu
      
        if input_context:
          logging.info(
              'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d',
              input_context.input_pipeline_id, input_context.num_input_pipelines)
          # 原始的shard逻辑,因为以单机CPU方式启动,所以不会进行实际的shard
          dataset = dataset.shard(input_context.num_input_pipelines, input_context.input_pipeline_id) 
        # NPU添加的shard逻辑,会根据集群数量,对数据集和全局batch进行切分
        dataset, batch_size = npu.distribute.shard_and_rebatch_dataset(dataset, batch_size) 
      
    • 当数据集为Numpy数组时,需要调用numpy方法手工对数据集和全局batch进行切分,例如:
       1
       2
       3
       4
       5
       6
       7
       8
       9
      10
      11
      12
      13
      (x_train, _), (x_test, _) = keras.datasets.mnist.load_data(os.path.join(args.data_path, 'mnist.npz'))
      
      # 根据设备数量均分数据集
      x_trains = np.split(x_train, args.rank_size)
      # 按设备编号取对应的数据集分片
      x_train = x_trains[args.device_id]
      x_tests = np.split(x_test, args.rank_size)
      x_test = x_tests[args.device_id]
      # 对全局batch进行切片
      batch_size = args.batch_size // args.rank_size
      
      mnist_digits = np.concatenate([x_train, x_test], axis=0)
      mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255