PyTorch断点续训精度溢出异常
2024/02/21
175
问题信息
问题来源 | 产品大类 | 产品子类 | 关键字 |
---|---|---|---|
官方 | 模型训练 | PyTorch | -- |
问题现象描述
MT线上集群多机训练时,偶现训练进程挂掉(各种原因),此时需要根据最近保存的ckpt继续训练,拉起后发现训练精度溢出,无法正常训练。
现象截图:
原因分析
NPU为了提高性能,内部包含许多种数据格式。训练时模型CKPT只会保留权重信息、格式,不会保存优化器(optimizer)的信息和格式。此时断点续训并载入优化器,会使用默认的数据格式,导致NPU无法识别,使得后续计算全部溢出。
解决措施
修改保存、载入ckpt的逻辑如下,即在保存ckpt时,额外保存一份data_format.pt用来存储数据格式。载入ckpt时,读取该.pt文件。
代码详情如下所示:
if 'optimizer_format' in checkpoint: logger.info("Loading optimizer weights format from checkpoint...") checkpoint['optimizer']['state'] = checkpoint_npu_format_cast( checkpoint['optimizer']['state'], checkpoint['optimizer_format']) elif os.path.exists(os.path.join(config.OUTPUT, 'optim_format.pt')): logger.info( "Loading optimizer weights format from optim_format.pt...") ckpt_optimize_format = torch.load( os.path.join(config.OUTPUT, 'optim_format.pt')) checkpoint['optimizer']['state'] = checkpoint_npu_format_cast( checkpoint['optimizer']['state'], ckpt_optimize_format)
本页内容