使用MMEngine进行断点续训时报错
报错信息
当PyTorch版本为2.1.0,在NPU上执行多卡训练,通过MMEngine进行断点续训时报如下错误:
RuntimeError: Attempted to set the storage of a tensor on device "npu:X" to a storage on different device "npu:0"
问题分析
PyTorch2.1.0多卡训练断点续训加载权重时,处理自定义设备会默认将权重都放到0卡上。
处理方法
- 根据报错堆栈找到MMEngine中加载预训练权重的代码,如下所示:
checkpoint = self.load_checkpoint(filename, map_location=device)
- 将1中代码修改为如下代码。
import os device_id = os.environ['LOCAL_RANK'] device = get_device() checkpoint = self.load_checkpoint(filename, map_location=f"{device}:{device_id}")
父主题: 其他操作相关问题