并行策略介绍

在大模型训练中,由于数据量和模型复杂度的增加,单个计算节点的计算能力难以满足训练的需求。为了提高训练效率和加速训练过程,通常采用并行策略来将计算任务分配给多个计算节点进行计算。并行策略通常分为数据并行、张量并行、流水并行、序列并行等多种并行模式。在实际应用中,通常会采用同时使用多种并行的混合并行策略,以及多种优化手段,例如使用 ZeRO优化器、重计算等方式,以减少模型对内存的使用,提高训练的效率和加速训练的过程。因此并行策略设计与模型的效率息息相关,在模型调优之前先确定一组或多组较优的并行策略,是至关重要的。

数据并行

数据并行将训练数据划分为多个批次,并将每个批次分配给不同的设备进行并行处理,每张计算卡都并行处理不同批次的数据,然后将结果合并。在数据并行中,每张计算卡都拥有自己的内存和计算资源,可以独立地处理数据,但由于每张计算卡上存在一份完整的模型权重副本,所以对显存会有更多的要求。换而言之,在每个worker之上复制一份模型,这样每个worker都有一个完整模型的副本。输入数据集是分片的,一个训练的小批量数据将在多个worker之间分割;worker定期汇总它们的梯度,以确保所有worker看到一个一致的权重版本。对于无法放进单个worker的大型模型,人们可以在模型之中较小的分片上使用数据并行。

图1 数据并向

张量并行

张量并行是一种基于张量(Tensor)概念的数据处理和分析方法。张量是一种多维数组,通过对其进行分解和并行计算,实现高效的数据处理。其核心思想是将张量分解为多个小张量,然后对这些小张量进行并行计算。这种方法可以显著减少计算时间和内存占用,提高计算效率。在计算过程中,张量被列切分,然后在不同设备上分配不同的列张量,最后在计算完成后对结果进行拼接。

图2 Attention模块的张量并行
图3 MLP模块的张量并行

序列并行

序列并行是一种针对输入序列进行列切分的并行计算方式,它可以在张量并行的基础上进一步提高计算效率。在序列并行中,计算过程中的权重会进行行切分,然后同张量并行一起放置在同一台计算设备上进行计算。完成计算后,会进行加操作,从而得到最终的结果。与其他并行计算方式相比,序列并行并不会增加额外的通信量,因此在开启张量并行的同时建议也同步开启序列并行。此外,序列并行还可以与现有的数据并行、流水线并行一起使用,从而实现更高效的4D并行计算。

图4 layernorm和dropout序列并行

流水线并行

流水线并行是将同一个任务分成多个阶段,每个阶段由不同的处理器处理,然后将结果传递给下一个阶段,以实现并行计算。此种并行方式,特别适用于深层模型,可以充分利用多个设备的计算能力,大幅减小模型对内存的依赖,且计算的通信可以重叠,所以对通信需求较低。但流水线之间存在空闲 bubble,对训练效率有较大影响。存在以下多种模式:

ZeRO优化

模型的内存占用主要由优化器状态(optimizer states)、梯度(gradient)和权重(weight)组成。在传统数据并行下,每个设备都会复制一份模型数据,占用了大量显存,而ZeRO 优化的主要原理就是将这些数据进行切分,分别存在各个设备上,在需要用到的时候通过集合通信进行同步,不需要了就释放掉相关显存,来减少显存的使用峰值,是一种典型的时间换空间的办法。总共有 ZeRO 1、ZeRO 2、ZeRO 3 三个阶段:

  1. ZeRO 1:模型的权重更新,是需要优化器状态参与计算,得出新的权重。但是在正向和反向传播中,优化器状态并不参与到其中的计算,所以我们可以把优化器状态进行切分,每个设备都只保存一部分。在正向反向计算完成后,每个设备都只负责更新那一部分的权重,最后再进行集合通信对新的模型权重进行同步。
  2. ZeRO 2:是在 ZeRO 1 的基础上,对梯度信息(gradient)也进行切分,既然我们只有一部分的优化器状态,那么其实我们也只是需要那一部分的梯度信息来更新模型权重。在一个 Layer 的梯度被计算出来后会通过集合通信来聚合,聚合后的梯度信息只会被需要的设备保存,其他不需要的设备就会释放掉,以节省显存使用。注意,ZeRO 2 开始便与流水并行不兼容。
  3. ZeRO 3:则是在 ZeRO 1 和 ZeRO 2 的基础上,对模型权重也进行切分,分配的不同的设备上。只在需要的时候通过集合通讯进行同步,在计算完成后便立即释放。相较于 ZeRO 1 和 ZeRO 2,此优化更加复杂,虽然节省了更显存,但设备间的通信数据也大幅增加,对模型训练的效率影响也较大。
图8 ZeRO切分示意图

重计算

重计算是一种以时间换空间的策略。重计算的思路是在正向计算中不再保存(反向计算依赖的)中间结果,而是在反向计算时进行重新计算,从而减少模型训练中的显存需求。重计算技术允许更大规模的模型和更长的序列进行计算,特别是对于内存受限的环境。然而,重计算也可能增加计算开销,因为需要重新计算一些中间结果。因此,在实际应用中,需要综合考虑显存和计算的平衡,选择适当的重计算策略。

以Megatron为例:

集群机器排布

16卡集群下,tp=2,pp=4,dp=2的集群排布示意图如下:

图9 集群机器排布

在进行集群机器排布时,优先考虑通信量最大的并行方式,将相关机器放在同一节点/子集群中,以减少跨机器之间的通信开销。这样可以提高通信效率,加快训练速度。以下是对每层Transformer的通讯量分析。

在本节中使用以下符号:b表示micro batch size, s表示sequence length , h表示hidden dimension

综上,每层Transformer的TP、DP、PP通讯量对比如下表

表1 每层Transformer的TP、DP、PP通讯量对比

张量并行(TP)

数据并行(DP)

流水线并行(PP)

8*b*s*h

16h*h

2*b*s*h

上表简化后,每层Transformer的TP、DP、PP通讯量对比(简化)如下表:

表2 每层Transformer的TP、DP、PP通讯量对比(简化)

张量并行(TP)

数据并行(DP)

流水线并行(PP)

4*b*s

8h

b*s

需注意,最佳的机器排布策略仍然需要根据具体的训练任务、模型特性和集群配置进行实验和测试,以找到最适合任务需求和硬件环境的机器排布方案。比如若集群中的机器具有不同的带宽和延迟特性,可以根据通信开销将通信量大的机器放在同一子集群中,若流水线中的某些阶段的通信量较大,可以考虑将涉及这些阶段的机器放在同一子集群中,以减少跨机器的通信开销。