混合精度简介

大语言模型的成功证明了增大神经网络的参数规模能够提升模型性能,但同时也增大了对加速器内存、算力及通信传输带宽的要求。为了减少内存占用加快收敛速度,大模型训练往往采用16位半精度浮点格式,例如float16或者bfloat16。

大量实验已经证明可以使用半精度浮点数训练大模型,也不会对模型性能带来显著影响,相反低精度计算作为正则化的一部分,反而能够给模型泛化能力带来好处。但目前低精度训练对模型的统计学影响也并不那么清晰,所以整个训练过程单纯使用低精度浮点运算非常具有挑战性。

在此情况下,依赖实践经验的混合精度技术成了一个现实的选择。混合精度训练在训练任务中组合地使用不同的数字格式,降低了对内存,算力和通信带宽的需求,极大提高了训练速度。

本节将介绍大规模并行训练中对浮点精度敏感的计算场景,力争帮助用户将浮点运算精度相关的模型收敛问题消弭在编码阶段。