文档
注册

使用MMEngine进行断点续训时报错

报错信息

当PyTorch版本为2.1.0,在NPU上执行多卡训练,通过MMEngine进行断点续训时报如下错误:

RuntimeError: Attempted to set the storage of a tensor on device "npu:X" to a storage on different device "npu:0"

问题分析

PyTorch2.1.0多卡训练断点续训加载权重时,处理自定义设备会默认将权重都放到0卡上。

处理方法

  1. 根据报错堆栈找到MMEngine中加载预训练权重的代码,如下所示:
    checkpoint = self.load_checkpoint(filename, map_location=device)
  2. 1中代码修改为如下代码。
    import os
    device_id = os.environ['LOCAL_RANK']
    device = get_device()
    checkpoint = self.load_checkpoint(filename, map_location=f"{device}:{device_id}")
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词