总体思路
PyTorch训练场景的精度问题分析可以参考以下思路进行精度比对和比对结果分析:
图1 精度调测思路
- 采集数据:
- 整网数据dump比对:通常情况下,用户可以优先选择整网dump数据比对,常见现象为loss收敛不符合预期。具体可参考数据dump比对场景,dump整网数据并进行精度比对。
- 溢出检测:若用户在训练过程中怀疑网络存在溢出问题,常见现象为loss scale不更新,用户可以使用溢出检测。具体可参考溢出检测场景,进行全量溢出检测。若发现API溢出问题,可联系华为工程师求助,可进入昇腾开源社区使用issue进行沟通。排除溢出问题后,如果整网仍存在精度问题,再继续进行整网数据dump比对。
- 定位问题API:检查比对结果,根据余弦相似度和最大绝对误差标准,找出网络执行顺序中第一个不符合精度标准的API。
- 根据问题API堆栈信息回溯代码:针对不符合精度标准的可能问题API,dump其调用栈,定位到具体代码行。
- 分析原因并优化:根据上一步定位到的代码行做单API复现,联系华为工程师提供复现样例求助,可进入昇腾开源社区使用issue进行沟通,或者用没有精度问题的等价API替代来规避有问题的API。
- 重复步骤2到步骤5,直到解决所有精度有问题的API。
父主题: 精度调测