本文档提供的样例是基于PyTorch官网的Imagenet数据集训练模型脚本代码main.py,以PyTorch1.8.1为例。
因为当前适配的昇腾PyTorch版本没有torch.backends.mps这个模块,所以需要将原代码中所有mps模块相关代码注释掉后再进行迁移。具体如下:
if not torch.cuda.is_available(): #and not torch.backends.mps.is_available(): print('using CPU, this will be slow')
#elif torch.backends.mps.is_available(): #device = torch.device("mps") #model = model.to(device)
#elif torch.backends.mps.is_available(): #device = torch.device("mps")
#if torch.backends.mps.is_available(): #images = images.to('mps') #target = target.to('mps')
#elif torch.backends.mps.is_available(): #device = torch.device("mps")