文档
注册

总体思路

PyTorch大模型训练的精度问题的分析、定位可以参考如下思路:

  1. 大模型训练通常使用多机训练,鉴于多机训练复现问题的成本较高,且影响因子较多,建议用户先减少模型层数,使模型能够单机训练,确认单机训练是否也存在精度问题,若存在,则使用下述手段定位精度问题,使得单机精度达标,然后再恢复层数拉起多机训练。
  2. 若单机精度正常但多机精度异常,有可能是多机通信造成的精度问题,此时可以用精度工具的通信精度检测功能进行定位。部分集合通信算子要求通信域内各rank结果一致,如AllReduce、AllGather等,利用这一特性,工具将多机模型训练中产生的通信输出存盘,并传输到同一节点来比较其一致性,从而确定模型中通信算子的精度是否存在问题。若已排除通信算子异常,则可能是由于网络层数增加放大了累积误差,需要使用精度比对等工具进一步分析。
图1 大模型精度问题定位思路流程图
  1. 参见明确精度异常场景,定位精度问题发生的迭代。
  2. 分析精度问题场景。
    • 若loss为inf或nan、混合精度loss持续下降,参见定位精度溢出场景使用工具进行溢出检测定位。
      • 进行溢出检测性质判定,溢出正常则切换数据类型规避溢出。
      • 工具无法解决问题,使用Dump统计量功能使用定位单API实现问题。
    • 若loss曲线异常,参见API精度问题并使用工具进行算子精度数据dump分析算子问题。
      • 使用精度预检工具,使用预检工具定位单API实现问题。
      • 使用dump比对工具,进行统计量dump比对缩小API范围、指定范围dump比对确认问题API。
  3. 重复以上步骤,直到精度问题解决完毕。若无法解决,请联系华为工程师处理。
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词