下载
中文
注册
我要评分
文档获取效率
文档正确性
内容完整性
文档易理解
在线提单
论坛求助
昇腾小AI

概述

该脚本转换工具可将PyTorch脚本和TensorFlow 2.x脚本转换成MindSpore代码。

  • 脚本转换工具根据适配规则,对用户脚本进行转换,大幅度提高了脚本迁移速度,降低了开发者的工作量。该脚本转换工具支持对包括但不限于模型列表列出的模型进行转换,转换成功后可直接运行,部分模型需要根据实际情况进行少量适配。
  • 此脚本转换工具当前仅支持PyTorch和TensorFlow 2.x训练脚本转换。
  • 此脚本转换工具能支持模型列表中的工程转换后训练成功且收敛,对最终精度和性能暂不做保证。
  • 此脚本转换工具转换后的训练工程支持在MindSpore1.6及更新版本上运行。

使用限制

  • MindSpore支持两种运行模式(Graph模式和PyNative模式),由于Graph模式存在Python语法限制,当前仅支持转换到PyNative模式,训练性能较Graph模式有所降低。具体差异详见MindSpore文档
  • 当前为了规避MindSpore中数据处理不支持创建Tensor的限制,在PyTorch转换过程中将运行模式设置成了算子同步下发模式,可能存在训练性能的部分降低;用户可通过将context.set_context中的pynative_synchronize=True去除,使用算子异步下发模式提升性能;此时若报错,可检查数据处理部分代码,去除其中的创建Tensor行为,改为使用numpy的ndarray。
  • 当前TensorFlow 2.x迁移结果不支持多卡运行。

模型列表

注意:脚本转换工具的模型列表仅作参考,可参考表1表2

表1 PyTorch模型列表

序号

模型

原始训练工程代码链接参考

备注

1

BERT

https://github.com/codertimo/BERT-pytorch

  • 在转换前,master分支的代码在Loss计算上存在一定问题,需要参考https://github.com/codertimo/BERT-pytorch/issues/32#issuecomment-432877367做少量修改。
  • 转换完成后,该工程在需要安装才能使用,安装步骤如下:
    • 去除requirements.txt文件中的torch项。
    • 执行python3 setup install。
  • 具体使用方式详见仓库README。

2

BiT-M-R101x1

https://github.com/google-research/big_transfer

3

BiT-M-R101x3

4

BiT-M-R152x2

5

BiT-M-R152x4

6

BiT-M-R50x1

7

BiT-M-R50x3

8

BiT-S-R101x1

9

BiT-S-R101x3

10

BiT-S-R152x2

11

BiT-S-R152x4

12

BiT-S-R50x1

13

BiT-S-R50x3

14

Conformer-tiny

https://github.com/pengzhiliang/Conformer

  • 转换前需要将timm库代码放到原始代码根目录下。
  • timm库版本推荐0.3.2。
  • 由于框架限制,当前不支持--repeated-aug,所以训练时需要使用--no-repeated-aug参数。

15

Conformer-small

16

Conformer-base

17

DeiT-tiny

18

DeiT-small

19

DeiT-base

20

DeepFM

https://github.com/shenweichen/DeepCTR-Torch

由于原始代码库中没有DCN和DeepFM的训练脚本,转换前需要进行以下操作:

  1. 拷贝examples/run_din.py文件并将其命名为run_dcn.py。拷贝examples/run_din.py文件并将其命名为run_deepfm.py。
  2. 将run_dcn.py中的第9行替换为from deepctr_torch.models.dcn import DCN。将run_deepfm.py中的第9行替换为from deepctr_torch.models.deepfm import DeepFM。
  3. 将run_dcn.py中的第47行替换为model = DCN(linear_feature_columns=feature_columns, dnn_feature_columns=feature_columns, device=device)。将run_deepfm.py中的第47行替换为model = DeepFM(linear_feature_columns=feature_columns, dnn_feature_columns=feature_columns, device=device)。

21

DIN

22

DCN

23

EfficientNet-B0

https://github.com/lukemelas/EfficientNet-PyTorch

-

24

EfficientNet-B1

25

EfficientNet-B2

26

EfficientNet-B3

27

EfficientNet-B4

28

EfficientNet-B5

29

EfficientNet-B6

30

EfficientNet-B7

31

EfficientNet-B8

32

SqueezeNet

https://github.com/weiaicunzai/pytorch-cifar100

数据集使用cifar-100-bin,可从https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz获取。

33

InceptionV3

34

InceptionV4

35

InceptionResNetV2

36

Xception

37

Attention56

38

StochasticDepth18

39

StochasticDepth34

40

StochasticDepth50

41

StochasticDepth101

42

VGG11

43

VGG13

44

VGG16

45

DenseNet161

46

DenseNet169

47

DenseNet201

48

PreActResNet34

49

PreActResNet50

50

PreActResNet101

51

PreActResNet152

52

ResNeXt152

53

SEResNet34

54

SEResNet50

55

SEResNet101

56

VGG19

https://github.com/kuangliu/pytorch-cifar

57

PreActResNet18

58

DenseNet121

59

ResNeXt29_2x64d

60

MobileNet

61

MobileNetV2

62

SENet18

63

ShuffleNetG2

64

GoogleNet

65

DPN92

66

Res2Net

https://github.com/Res2Net/Res2Net-ImageNet-Training

暂不支持torchvision.models相关接口的转换,需做以下操作。

修改原始工程:

  1. 创建目录res2net_pami/models。
  2. res2net_pami/main.py中,将import torchvision.models as models改为import models。

67

ResNet18

https://github.com/pytorch/examples/tree/master/imagenet

暂不支持torchvision.models相关接口的转换,需做以下操作。

修改原始工程:

  1. 创建目录imagenet/models。
  2. 从torchvision库中拷贝torchvision/models/resnet.py至imagenet/models下,删除from .utils import load_state_dict_from_url语句。
  3. 创建imagenet/models/__init__.py文件,内容为:
    from .resnet import *
  4. main.py中,将import torchvision.models as models改为import models。

68

ResNet34

69

ResNet50

70

ResNet101

71

ResNet152

72

ResNeXt-50(32x4d)

73

ResNeXt-101(32x8d)

74

Wide ResNet-50-2

75

Wide ResNet-101-2

76

ShuffleNetV2

https://github.com/megvii-model/ShuffleNet-Series

-

77

ShuffleNetV2+

78

Swin-Transformer

https://github.com/microsoft/Swin-Transformer

  • 转换前需要将timm库代码放到原始代码根目录下。
  • timm库版本推荐0.4.12。
  • 当前--cfg参数只支持以下四个配置文件:
    • swin_tiny_patch4_window7_224.yaml
    • swin_tiny_c24_patch4_window8_256.yaml
    • swin_small_patch4_window7_224.yaml
    • swin_base_patch4_window7_224.yaml

79

Transformer

https://github.com/SamLynnEvans/Transformer

需要对于该代码仓中脚本依赖的torchtext库进行转换并有如下注意事项:

  • 拷贝转换转换后的torchtext_x2ms到脚本文件夹。
  • 将torchtext_x2ms重命名未torchtext,以保证用户调用的是转换后的torchtext。
  • torchtext版本建议使用0.6.0。

80

UNet

https://github.com/milesial/Pytorch-Unet

MindSpore暂不支持ReduceLROnPlateau,需要替换成其他支持的scheduler。

81

RCNN-Unet

https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets

转换前需要进行以下操作:

  • 由于MindSpore求导存在语法限制,pytorch_run.py中248和252行的注释需要修改为4空格倍数对齐。
  • 模型要求输入图片大小为16的倍数,因此当数据集图片大小不满足16倍数时,需取消pytorch_run.py中121、122行和505、506行的注释,将图片缩放裁剪为16倍数。
  • 当数据集label图片通道为1时,需要在pytorch_run.py的293行尾加入.convert('RGB')将图片转换为3通道。
  • 由于MindSpore中使用ModuleList会导致子层的权重名称改变,需要将pytorch_run.py第350行的torch.nn.ModuleList改为list,避免checkpoint文件保存后无法重新加载。

82

Attention Unet

83

RCNN-Attention Unet

84

Nested Unet

85

ViT-B_16

https://github.com/jeonsworld/ViT-pytorch

86

ViT-B_32

87

ViT-L_16

88

ViT-L_32

89

ViT-H_14

90

R50-ViT-B_16

表2 TensorFlow 2.x模型列表

序号

模型

原始训练工程代码链接参考

备注

1

DenseNet_121

https://github.com/calmisential/Basic_CNNs_TensorFlow2

请根据需要,在configuration.py中进行epoch\batch_size\数据集路径等配置。

2

DenseNet_169

3

EfficientNet_B0

4

EfficientNet_B1

5

Inception_V4

6

MobileNet_V1

7

MobileNet_V2

8

ResNet_101

9

ResNet_50

10

ResNext_101

11

ResNext_50

12

Shufflenet_V2_x0_5

13

Shufflenet_V2_x1_0

14

AFM

https://github.com/ZiyaoGeng/Recommender-System-with-TF2.0

各个网络文件夹均依赖data_process目录,请直接转换Recommender-System-with-TF2.0目录或将data_process复制至网络文件夹下后再进行转换。

15

DCN

16

Deep_Crossing

17

DeepFM

18

NFM

19

PNN

20

FCN

https://github.com/YunYang1994/TensorFlow2.0-Examples/tree/master/5-Image_Segmentation/FCN

parser_voc.py中使用的scipy.misc.imread方法为scipy 1.2.0以前的旧版本API,mindspore最低兼容scipy 1.5.2,因此请使用scipy的官方弃用警告中推荐的imageio.imread。

搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词