大规模分布式并行训练中使用半精度浮点数,一般有两种选择:float16或者bfloat16。这两种格式的区别见半精度浮点数。float16相对于bfloat16有更高的精度,但是表示范围更小。混合精度训练首先需要把模型中适合的参数转移到半精度浮点类型。
如果混合精度训练中选择float16,为了避免表示范围小引起的浮点上溢和下溢,混合精度要结合动态Loss缩放机制,以下是采用float16的混合精度训练典型流程概述:
使用BF16格式的半精度数时,因为BF16有更大的表示范围,所以一般无需使用Loss缩放机制。但是BF16数值精度比FP16更差,所以在步骤三的第七点做梯度累积的时候需要使用FP32,否则有可能会因为梯度累积误差导致模型不收敛。另外BF16比FP16多15%的运行时内存,主要原因在于梯度累积时需要转FP32。
PyTorch提供了自动混合精度(AMP)的机制,AMP按需自动调整张量的数据类型(dtype)。例如在AMP autocast上下文时,矩阵乘法matmul的输入张量会被自动转化为半精度浮点类型。AMP也提供了GradScaler,通过自动调整Loss的缩放来防止梯度的下溢和上溢。PyTorch的AMP优化级别使用apex.amp的O1级,这意味着PyTorch AMP使用黑白名单自动决定使用FP16、BF16还是FP32进行计算,但还有一些特定模型相关的精度敏感的运算并不在AMP的自动upcast名单中,需要开发者手动干预。所以使用AMP时,开发者需要对AMP的黑白名单有一定了解。