训练状态监控

在大模型训练过程中,为确保训练的稳定性与有效性,需密切关注多项关键指标以评估训练状态,其中包括但不限于 perplexity (PPL)、gradient norm (GNorm)、activation norm、内存占用情况以及Loss scale等参数。推荐采用TensorBoard工具进行数据可视化。
图1 TensorBoard上的数据可视化

另外,在模型的训练中,我们可以通过PyTorch中的hook机制对容易出现问题的某些层配置hook,监控这些层的梯度信息,及时处理出现异常的step以减少对模型训练效果的影响。具体操作如下:

  1. 在train方法中获取模型结构后,将模型传入指定的collector中。

  2. 对指定层的tensor注册tensor hook,该类型hook只返回对应tensor的梯度信息,该hook会在每份micro batch数据完成反向传播后调用,即对于每份micro batch数据均有梯度值的记录。

对该梯度信息进行监控和分析,可以检查训练中的一些异常状态,收集梯度值时,建议只采集梯度的最大值、最小值、平均值等统计信息作为训练状态的监控指标。

除了梯度爆炸问题的监控,我们也要关注梯度下溢问题。特别是对于FP16,如果梯度出现大量0,往往意味着训练出现不正常,这时候需要停止训练,找出问题(例如伪FP32,实际是混合精度)。如果训练器使用了Loss scale,则监控是Loss scale是必须的。