当PyTorch版本为2.1.0,在NPU上执行多卡训练,通过MMEngine进行断点续训时报如下错误:
RuntimeError: Attempted to set the storage of a tensor on device "npu:X" to a storage on different device "npu:0"
PyTorch2.1.0多卡训练断点续训加载权重时,处理自定义设备会默认将权重都放到0卡上。
checkpoint = self.load_checkpoint(filename, map_location=device)
import os device_id = os.environ['LOCAL_RANK'] device = get_device() checkpoint = self.load_checkpoint(filename, map_location=f"{device}:{device_id}")