适配插件开发(ONNX框架)

您可以参考本章节进行算子适配插件的开发,将ONNX框架的算子映射成适配昇腾AI处理器的算子(下文简称CANN算子),从而完成从ONNX框架调用Ascend C自定义算子的过程。如下样例展示了一个基础的开发流程:
#include "register/register.h"
#include "graph/operator.h"
#include "json.hpp"
namespace domi {
    Status ParseParamByOpFunc(const ge::Operator& op_src, ge::Operator& op_dest) {
        //...
    }
    REGISTER_CUSTOM_OP("OpType")
        .FrameworkType(ONNX) 
        .OriginOpType("OriginOpType")
        .ParseParamsByOperatorFn(ParseParamByOpFunc)   //用来注册解析算子属性的函数
        .ImplyType(ImplyType::TVM);  // Ascend C算子实现类型设置为TVM
}
  1. 包含所需头文件。

    • register.h,存储在CANN软件安装后文件存储路径的“include/register/”目录下,包含该头文件,可使用算子注册相关类,调用算子注册相关的接口。
    • operator.h(可选),存储在CANN软件安装后文件存储路径的“include/graph/”目录下,包含该头文件,可以使用Operator类相关接口,获取算子输入输出及属性等算子信息。
    • json.hpp:用于进行ONNX数据定义的解析,将String类型的算子参数定义转换为json格式。若样例工程中未提供“json.hpp”文件,用户可以自行下载,并将“json.hpp”放在工程可以找到的任意路径下,然后包含此头文件即可,下载路径可参见json.hpp

  2. 使用REGISTER_CUSTOM_OP宏,完成CANN算子和ONNX框架的算子映射关系注册。使用方法如下:

    • REGISTER_CUSTOM_OP:注册自定义算子,OpType为算子类型名称,需要与算子原型注册中的OpType保持一致。
    • FrameworkType:ONNX代表原始框架为ONNX
    • OriginOpType:算子在原始框架中的类型。例如自定义算子OpTypeA,对应ONNX算子库版本opset_version=11,应传入“ai.onnx::11::OpTypeA”,当前支持的ONNX版本范围为9~15。
    • ParseParamsByOperatorFn(ParseParamByOpFunc):用来注册解析算子参数实现映射关系的回调函数,需要用户自定义实现回调函数ParseParamByOpFunc。具体实现方式参考3
    • ImplyType:指定算子的实现方式。Ascend C算子实现类型设置为TVM。

  3. 实现回调函数ParseParamByOpFunc。其函数声明如下所示:

    Status ParseParamByOpFunc(const ge::Operator& op_src, ge::Operator& op_dest)
    • ParseParamByOpFunc:函数名称,用户自定义。
    • op_src:ONNX框架定义的Operator类对象,包含ONNX模型中自定义的算子属性信息,定义来源于ONNX框架的原始模型文件。
    • op_dest:CANN算子数据结构,保存算子信息,Operator类的详细描述请参见Operator类

    开发者需要在回调函数中实现属性的解析和映射,具体实现方式如下:

    ONNX原始模型中,属性为repeated message类型,对于repeated message类型的参数,可使用GetAttr(const char *name, ge::AscendString &attr_value)接口获取其属性值,然后将AscendString类型的属性值转换为String类型,再将其转换为json格式进行属性字段的解析。

    实现如下所示:

    Status ParseParamLeakyReluAscend(const ge::Operator& op_src, ge::Operator& op_dest) {
        float negative_slope = 0.01f;
        string negative_slope_str;
        AscendString attrs_string;
        // 使用固定属性名称“attribute”获取ONNX算子中的属性,并赋值给AscendString类型对象
        if (ge::GRAPH_SUCCESS == op_src.GetAttr("attribute", attrs_string)) {
          // 转换为json格式
          json attrs = json::parse(attrs_string.GetString());
          for (json attr : attrs["attribute"]) {
            if (attr["name"] == "alpha" && attr["type"] == kTypeFloat) {
              negative_slope_str = attr["f"];  // float type in json has accuracy loss, so we use string type to store it
              negative_slope = atof(negative_slope_str.c_str());
            }
          }
        }
        op_dest.SetAttr("negative_slope", negative_slope);
        return SUCCESS;
    }
    • 当前版本GetAttr与SetAttr接口不支持对原始文件中数据类型为double和uint64的字段进行解析。
    • 使用ATC工具执行模型转换时,对属性的获取情况不会进行强校验。所以进行算子适配插件实现时,若用户调用GetAttr失败,建议根据算子实际情况增加相应的处理逻辑,例如,针对必选属性,可返回失败,针对可选属性,可设置默认值。