样例代码说明

本文档提供的样例是基于PyTorch官网的Imagenet数据集训练模型脚本代码main.py,以PyTorch1.8.1为例。

因为当前适配的昇腾PyTorch版本没有torch.backends.mps这个模块,所以需要将原代码中所有mps模块相关代码注释掉后再进行迁移。具体如下:

  1. 原代码第147行:
    if not torch.cuda.is_available(): #and not torch.backends.mps.is_available():
        print('using CPU, this will be slow')
  2. 原代码第171行至173行:
    #elif torch.backends.mps.is_available():
        #device = torch.device("mps")
        #model = model.to(device)
  3. 原代码第187至188行:
    #elif torch.backends.mps.is_available():
        #device = torch.device("mps")
  4. 原代码第356行至358行:
    #if torch.backends.mps.is_available():
        #images = images.to('mps')
        #target = target.to('mps')
  5. 原代码第443至444行:
    #elif torch.backends.mps.is_available():
        #device = torch.device("mps")