在模型运行时遇到报错“terminate called after throwing an instance of 'c10::Error' what(): 0 INTERNAL ASSERT”

报错信息

问题分析

执行代码后出现报错。

import torch
import torch_npu

def test_cpu():
    input = torch.randn(2000, 1000).detach().requires_grad_()
    output = torch.sum(input)
    output.backward(torch.ones_like(output))

def test_npu():
    input = torch.randn(2000, 1000).detach().requires_grad_().npu()
    output = torch.sum(input)
    output.backward(torch.ones_like(output))

if __name__ == "__main__":
    test_cpu()
    torch_npu.npu.set_device("npu:0")
    test_npu()

在运行backward运算时,若没有设置device,程序会自动默认初始化device为0,相当于执行了set_device("npu:0")。由于目前不支持切换device进行计算,若再通过set_decice()方法手动设置device设备,则可能出现该错误。

处理方法

在运行backward运算前,通过set_decice()方法手动设置device。

原代码如下:

if __name__ == "__main__":
    test_cpu()
    torch_npu.npu.set_device("npu:0")
    test_npu()

修改后代码如下:

if __name__ == "__main__":
    torch_npu.npu.set_device("npu:0")
    test_cpu()
    test_npu()