大规模并行训练的混合精度选择

大规模分布式并行训练中使用半精度浮点数,一般有两种选择:float16或者bfloat16。这两种格式的区别见半精度浮点数。float16相对于bfloat16有更高的精度,但是表示范围更小。混合精度训练首先需要把模型中适合的参数转移到半精度浮点类型。

如果混合精度训练中选择float16,为了避免表示范围小引起的浮点上溢和下溢,混合精度要结合动态Loss缩放机制,以下是采用float16的混合精度训练典型流程概述:

  1. 保留一份FP32格式的权重主备份,同时优化器状态也应以FP32格式存储。
  2. 将Loss缩放因子S初始化为一个较大的值。
  3. 对每一个训练step:

    1. 将权重复制一份到FP16格式。
    2. 使FP16格式的权重和激活值进行前向传播。
    3. 将最终的Loss乘以缩放因子S。
    4. 使用FP16格式进行后向传播,包括权重、激活值及其对应的梯度。
    5. 若检测到权重梯度中出现Inf或NaN:
      • 减小S值。
      • 跳过当前权重更新步骤,重新开始下一个训练步骤。
    6. 将权重梯度乘以1/S。
    7. 梯度累积或者梯度累积足够步后使用FP32更新主权重。
    8. 如果之前N步都没有看到Inf或者NaN,增加S值。
    9. 在上述过程中,有几处计算必须要以FP32完成。比如主权重的更新,因为累加能够导致精度误差积累,所以必须要以FP32计算。缩放因子必须是FP32类型,甚至1/S的计算要将S转成双精度数求倒数再转回到FP32。

使用BF16格式的半精度数时,因为BF16有更大的表示范围,所以一般无需使用Loss缩放机制。但是BF16数值精度比FP16更差,所以在步骤三的第七点做梯度累积的时候需要使用FP32,否则有可能会因为梯度累积误差导致模型不收敛。另外BF16比FP16多15%的运行时内存,主要原因在于梯度累积时需要转FP32。

PyTorch提供了自动混合精度(AMP)的机制,AMP按需自动调整张量的数据类型(dtype)。例如在AMP autocast上下文时,矩阵乘法matmul的输入张量会被自动转化为半精度浮点类型。AMP也提供了GradScaler,通过自动调整Loss的缩放来防止梯度的下溢和上溢。PyTorch的AMP优化级别使用apex.amp的O1级,这意味着PyTorch AMP使用黑白名单自动决定使用FP16、BF16还是FP32进行计算,但还有一些特定模型相关的精度敏感的运算并不在AMP的自动upcast名单中,需要开发者手动干预。所以使用AMP时,开发者需要对AMP的黑白名单有一定了解。