文档
注册

原始模型转换为Graph

除了可以将算子原型直接构图外,CANN还提供了框架Parser功能,将主流框架的模型格式转换成CANN模型格式。

功能介绍

目前业界开源的深度学习框架(例如TensorFlow,PyTorch、Caffe等),定义模型的格式各有不同,例如TensorFlow通过自定义pb描述静态图和模型,PyTorch通过ONNX规范描述,因此需要通过统一的框架解析功能隔离上层框架差异,通过Parser模块完成解析并转换成昇腾AI处理器支持的CANN模型格式。

涉及的主要接口为:

Parser层目前为用户开放了自定义OpParser和自定义TensorFlow Scope融合规则的能力,如果用户在Parser解析时需要对框架进行更灵活的适配,则可以自定义OpParser或自定义开发TensorFlow Scope融合规则。

  • 自定义OpParser:如果用户需要将原始框架中算子直接映射到CANN中已实现的TBE算子,可直接进行第三方框架的适配,具体请参考算子适配章节。
  • 自定义TensorFlow Scope融合规则:基于TensorFlow构建的神经网络计算图通常由大量的小算子组成,为了实现高性能的计算,往往需要对子图中的小算子进行融合,使得融合后的大算子可以充分利用硬件加速资源。具体请参考TensorFlow Parser Scope融合规则开发指南

基于TensorFlow模型解析

包含的头文件:

1
#include "tensorflow_parser.h"
通过aclgrphParseTensorFlow接口将TensorFlow原始模型转换为Graph,此时Graph保存在内存缓冲区中。
1
2
3
std::string tfPath = "../data/tf_test.pb";
ge::Graph graph1;
auto tfStatus = ge::aclgrphParseTensorFlow(tfPath.c_str(),graph1);

同时,支持用户指定parser_params

1
2
3
4
5
6
std::string tfPath = "../data/tf_test.pb";
std::map<ge::AscendString, ge::AscendString> parser_params= {
            {ge::AscendString(ge::ir_option::LOG_LEVEL), ge::AscendString("debug")},
            {ge::AscendString(ge::ir_option::INPUT_FORMAT), ge::AscendString("NHWC")}};
ge::Graph graph1;
auto tfStatus = ge::aclgrphParseTensorFlow(tfPath.c_str(), parser_params, graph1);

基于Caffe模型解析

包含的头文件:

1
#include "caffe_parser.h"
通过aclgrphParseCaffe接口将Caffe原始模型转换为Graph,此时Graph保存在内存缓冲区中。
1
2
3
4
std::string caffePath = "../data/caffe_test.prototxt";
std::string weight = "../data/caffe_test.caffemodel";
ge::Graph graph1;
auto caffeStatus = ge::aclgrphParseCaffe(caffePath.c_str(), weight.c_str(), graph1);

同时,支持用户指定parser_params:

1
2
3
4
5
6
7
std::string caffePath = "../data/caffe_test.prototxt";
std::string weight = "../data/caffe_test.caffemodel";
std::map<ge::AscendString, ge::AscendString> parser_params= {
            {ge::AscendString(ge::ir_option::LOG_LEVEL), ge::AscendString("debug")},
            {ge::AscendString(ge::ir_option::INPUT_FORMAT), ge::AscendString("NHWC")}};
ge::Graph graph1;
auto caffeStatus = ge::aclgrphParseCaffe(caffePath.c_str(), weight.c_str(), parser_params, graph1);

基于ONNX模型解析

包含的头文件:

1
#include "onnx_parser.h"
通过aclgrphParseONNXFromMem接口将加载至内存的ONNX模型转换为Graph,此时Graph保存在内存缓冲区中。同时,支持用户指定parser_params:
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
/* read file in binary format */
FILE *pFile = fopen("./onnx/resnet101.onnx", "rb" );
if(pFile==NULL)
{
    fputs("File error",stderr);
    exit(1);
}

/* get the size of the file */
fseek(pFile, 0, SEEK_END);
long lSize = ftell(pFile);
rewind(pFile);

/* assign memory buffer for the file*/
char *buffer =(char*) malloc(sizeof(char)*lSize);
if(buffer == NULL)
{
    fputs("Memory error", stderr); 
    exit(2);
}
/* copy the file to buffer */
size_t result = fread(buffer, 1, lSize, pFile);
if(result != lSize)
{
    fputs("Reading error", stderr);
    exit(3);
}
std::map<ge::AscendString, ge::AscendString> parser_params= {
            {ge::AscendString(ge::ir_option::LOG_LEVEL), ge::AscendString("debug")},
            {ge::AscendString(ge::ir_option::INPUT_FORMAT), ge::AscendString("NCHW")}};
ge::Graph graph1;
auto onnxStatus = ge::aclgrphParseONNXFromMem(buffer, result, parser_params, graph1);

通过aclgrphParseONNX接口将ONNX原始模型转换为Graph,此时Graph保存在内存缓冲区中。同时,支持用户指定parser_params:

1
2
3
4
5
6
std::string onnxPath = "../data/onnx_test.onnx";
std::map<ge::AscendString, ge::AscendString> parser_params= {
            {ge::AscendString(ge::ir_option::LOG_LEVEL), ge::AscendString("debug")},
            {ge::AscendString(ge::ir_option::INPUT_FORMAT), ge::AscendString("NCHW")}};
ge::Graph graph1;
auto onnxStatus = ge::aclgrphParseONNX(onnxPath .c_str(), parser_params, graph1);
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词