定位精度溢出场景
问题现象
当精度数据出现以下现象时模型中可能存在溢出导致的精度问题:
- loss为inf或nan。
- 混合精度动态loss scale模式下loss scale一直降低,降到1以下或者比GPU降低的要更小,导致参数不更新或者loss收敛不好,如下图所示。
图1 溢出导致精度问题
对于上述溢出导致的精度问题可进行算子bug排查和切换数据类型来定位。
算子bug排查
大模型训练中出现溢出时,需要将这些溢出中属于算子bug的过滤出来。
- 使用精度比对工具中PrecisionDebugger接口来进行过滤。
- 参见溢出检测自动性质判定功能使用,根据输入输出中的极值以及单API的溢出现象复现结果来确定溢出是否属于算子bug,饱和模式重点关注过程溢出。
- 参见使用Dump统计量功能使用,保存API部分特定信息进行详细分析,本功能可展示整网所有API输入输出的最大最小值和均值等统计量,用户可以通过这些统计量识别到网络中数值膨胀的初始位置或范围,也可以将NPU和标杆的统计量用beyondcompare等软件比对,从而迅速定界有精度问题的API。
切换数据类型
如果排查算子bug的步骤中没能找到有精度问题的API或者解决了算子bug后仍存在溢出现象,这说明溢出不是算子bug导致的,那么相关API必须切换到数值表示范围更大的数据类型。
- 根据算子bug可以帮助用户分析出需要切换数据类型的API的最小集合。
整理溢出检测工具检测到的非算子bug类溢出,分析其中需要切换数据类型的API的最小集合。通常这个最小集合中包含模型中最先出现溢出的API。另外,一些常见需要切换到高精度的API也可以考虑加入到上述集合中:他们通常是会造成数值分布较大变化的API,如batchnorm、layernorm、softmax等。
- 将上述集合中的API的数据类型从fp16切换为bf16或fp32(可能会导致性能下降),进一步检查精度是否正常。
父主题: 精度异常定位