onnx/pb模型图模式执行
GE提供了针对前端框架导出模型的解析功能,若开发者已经获取到PyTorch或者TensorFlow框架导出的模型(.onnx模型或者.pb模型),可通过GE提供的ATC命令行工具或者C++语言的Parser接口,将这两类模型转化为适配昇腾的模型并基于图模式执行。
MindSpore框架网络导出的.air模型,也可以通过ATC命令行工具转换为适配昇腾的om模型,然后通过AscendCL接口加载执行。
下文仅针对如何将PyTorch或者TensorFlow框架导出的模型通过图模式执行进行详细介绍。
通过ATC命令行工具解析模型并执行
- 使用ATC命令行工具,将PyTorch框架导出的onnx模型或者TensorFlow框架导出的pb模型转换为适配昇腾平台的om离线模式。
使用ATC工具转换onnx模型示例:
atc --model=$HOME/module/resnet50*.onnx --framework=5 --output=$HOME/module/out/onnx_resnet50 --soc_version=<soc_version>
使用ATC工具转换pb模型示例:atc --model=$HOME/module/resnet50_tensorflow*.pb --framework=3 --output=$HOME/module/out/tf_resnet50 --soc_version=<soc_version>
关键参数解释如下,ATC工具支持的详细参数及含义可参见《ATC工具使用指南》。
- --model:原始网络模型文件路径与文件名。
- --framework:原始网络模型框架类型,“3”代表TensorFlow模型,“5”代表ONNX模型。
- --output:存放转换后的离线模型的路径以及文件名,例如,若配置为“$HOME/module/out/tf_resnet50”,则转换后的离线模型存储路径为“$HOME/module/out/”,转换后的离线模型名称为“tf_resnet50.om”。
- 基于AscendCL应用开发接口开发应用程序,加载转换好的om离线模型,执行模型推理。
通过C++语言的Parser接口解析模型并执行
针对不同的框架模型,GE提供了对应的解析接口将其解析并转换为适配昇腾的模型,不同框架模型的关键解析接口及调用示例如下:
- aclgrphParseONNX:解析ONNX模型。
#include "onnx_parser.h" std::string onnxPath = "../data/onnx_test.onnx"; ge::Graph graph1; auto onnxStatus = ge::aclgrphParseONNX(onnxPath .c_str(),graph1);
- aclgrphParseONNXFromMem:解析加载至内存的ONNX模型。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
#include "onnx_parser.h" FILE *pFile = fopen("./onnx/resnet101.onnx", "rb" ); if(pFile==NULL) { fputs("File error",stderr); exit(1); } /* get the size of the file */ fseek(pFile, 0, SEEK_END); long lSize = ftell(pFile); rewind(pFile); /* assign memory buffer for the file*/ char *buffer =(char*) malloc(sizeof(char)*lSize); if(buffer == NULL) { fputs("Memory error", stderr); exit(2); } /* copy the file to buffer */ size_t result = fread(buffer, 1, lSize, pFile); if(result != lSize) { fputs("Reading error", stderr); exit(3); } std::map<ge::AscendString, ge::AscendString> parser_params= { {ge::AscendString(ge::ir_option::LOG_LEVEL), ge::AscendString("debug")}, {ge::AscendString(ge::ir_option::INPUT_FORMAT), ge::AscendString("NCHW")}}; ge::Graph graph1; auto onnxStatus = ge::aclgrphParseONNXFromMem(buffer, result, parser_params, graph1);
- aclgrphParseTensorFlow:解析TensorFlow模型。
开发者在调用解析接口时,还可以通过parser_params参数配置扩展参数,使用方法可参见《Ascend Graph开发指南》中的原始模型转换为Graph。
完成原始框架模型的解析后,会得到GE的Graph(即上述代码示例中的对象graph1),此时Graph保存在内存缓冲区中,开发者后续可以直接编译运行此Graph,图编译运行的流程可参见下一节Ascend Graph构图。