原始模型转换为Graph
除了可以将算子原型直接构图外,CANN还提供了框架Parser功能,将主流框架的模型格式转换成CANN模型格式。
功能介绍
目前业界开源的深度学习框架(例如TensorFlow,PyTorch、Caffe等),定义模型的格式各有不同,例如TensorFlow通过自定义pb描述静态图和模型,PyTorch通过ONNX规范描述,因此需要通过统一的框架解析功能隔离上层框架差异,通过Parser模块完成解析并转换成昇腾AI处理器支持的CANN模型格式。
涉及的主要接口为:
- 解析TensorFlow模型:aclgrphParseTensorFlow
- 解析Caffe模型:aclgrphParseCaffe
- 解析ONNX原始模型:aclgrphParseONNX
- 解析加载至内存的ONNX模型:aclgrphParseONNXFromMem
Parser层目前为用户开放了自定义OpParser和自定义TensorFlow Scope融合规则的能力,如果用户在Parser解析时需要对框架进行更灵活的适配,则可以自定义OpParser或自定义开发TensorFlow Scope融合规则。
基于TensorFlow模型解析
包含的头文件:
1
|
#include "tensorflow_parser.h" |
通过aclgrphParseTensorFlow接口将TensorFlow原始模型转换为Graph,此时Graph保存在内存缓冲区中。
1 2 3 |
std::string tfPath = "../data/tf_test.pb"; ge::Graph graph1; auto tfStatus = ge::aclgrphParseTensorFlow(tfPath.c_str(),graph1); |
同时,支持用户指定parser_params:
1 2 3 4 5 6 |
std::string tfPath = "../data/tf_test.pb"; std::map<ge::AscendString, ge::AscendString> parser_params= { {ge::AscendString(ge::ir_option::LOG_LEVEL), ge::AscendString("debug")}, {ge::AscendString(ge::ir_option::INPUT_FORMAT), ge::AscendString("NHWC")}}; ge::Graph graph1; auto tfStatus = ge::aclgrphParseTensorFlow(tfPath.c_str(), parser_params, graph1); |
基于Caffe模型解析
包含的头文件:
1
|
#include "caffe_parser.h" |
通过aclgrphParseCaffe接口将Caffe原始模型转换为Graph,此时Graph保存在内存缓冲区中。
1 2 3 4 |
std::string caffePath = "../data/caffe_test.prototxt"; std::string weight = "../data/caffe_test.caffemodel"; ge::Graph graph1; auto caffeStatus = ge::aclgrphParseCaffe(caffePath.c_str(), weight.c_str(), graph1); |
同时,支持用户指定parser_params:
1 2 3 4 5 6 7 |
std::string caffePath = "../data/caffe_test.prototxt"; std::string weight = "../data/caffe_test.caffemodel"; std::map<ge::AscendString, ge::AscendString> parser_params= { {ge::AscendString(ge::ir_option::LOG_LEVEL), ge::AscendString("debug")}, {ge::AscendString(ge::ir_option::INPUT_FORMAT), ge::AscendString("NHWC")}}; ge::Graph graph1; auto caffeStatus = ge::aclgrphParseCaffe(caffePath.c_str(), weight.c_str(), parser_params, graph1); |
基于ONNX模型解析
包含的头文件:
1
|
#include "onnx_parser.h" |
通过aclgrphParseONNXFromMem接口将加载至内存的ONNX模型转换为Graph,此时Graph保存在内存缓冲区中。同时,支持用户指定parser_params:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
/* read file in binary format */ FILE *pFile = fopen("./onnx/resnet101.onnx", "rb" ); if(pFile==NULL) { fputs("File error",stderr); exit(1); } /* get the size of the file */ fseek(pFile, 0, SEEK_END); long lSize = ftell(pFile); rewind(pFile); /* assign memory buffer for the file*/ char *buffer =(char*) malloc(sizeof(char)*lSize); if(buffer == NULL) { fputs("Memory error", stderr); exit(2); } /* copy the file to buffer */ size_t result = fread(buffer, 1, lSize, pFile); if(result != lSize) { fputs("Reading error", stderr); exit(3); } std::map<ge::AscendString, ge::AscendString> parser_params= { {ge::AscendString(ge::ir_option::LOG_LEVEL), ge::AscendString("debug")}, {ge::AscendString(ge::ir_option::INPUT_FORMAT), ge::AscendString("NCHW")}}; ge::Graph graph1; auto onnxStatus = ge::aclgrphParseONNXFromMem(buffer, result, parser_params, graph1); |
通过aclgrphParseONNX接口将ONNX原始模型转换为Graph,此时Graph保存在内存缓冲区中。同时,支持用户指定parser_params:
1 2 3 4 5 6 |
std::string onnxPath = "../data/onnx_test.onnx"; std::map<ge::AscendString, ge::AscendString> parser_params= { {ge::AscendString(ge::ir_option::LOG_LEVEL), ge::AscendString("debug")}, {ge::AscendString(ge::ir_option::INPUT_FORMAT), ge::AscendString("NCHW")}}; ge::Graph graph1; auto onnxStatus = ge::aclgrphParseONNX(onnxPath .c_str(), parser_params, graph1); |
父主题: 构建Graph