文档
注册

算子适配插件实现(TensorFlow/ONNX)

开发者需要进行算子适配插件的开发,实现将第三方框架(TensorFlow/ONNX)网络中的算子进行解析并映射成昇腾AI处理器中的算子。

TensorFlow

MindStudio在“framework/tf_plugin/tensorflow_reshape_cust_plugin.cc”文件已自动生成了ReshapeCust算子的插件代码。

  • 包含头文件。
    1
    2
    //包含该头文件,可使用算子注册类相关,调用算子注册相关的接口,为Ascend-cann-toolkit开发套件包中compiler/include/register/register.h文件
    #include "register/register.h"
    
  • 进行插件注册。
    1
    2
    3
    4
    5
    6
    7
    8
    9
    using namespace ge;    // 需手动添加
    namespace domi {
    // register op info to GE
    REGISTER_CUSTOM_OP("ReshapeCust")
        .FrameworkType(TENSORFLOW)   // type: CAFFE, TENSORFLOW
        .OriginOpType("ReshapeCust")      // name in tf module
        .ParseParamsByOperatorFn(AutoMappingByOpFn);
        .ImplyType(ImplyType::AI_CPU);    // 需手动添加
    }  // namespace domi
    
    • REGISTER_CUSTOM_OP:算子注册到GE的算子类型,根据算子分析,算子类型为“ReshapeCust”
    • FrameworkType:TENSORFLOW,原始框架类型为TensorFlow。
    • OriginOpType:算子在TensorFlow框架中的类型。
    • ParseParamsByOperatorFn:用来注册解析模型的函数,使用AutoMappingFn函数自动实现解析。
    • ImplyType:指定算子的实现方式,ImplyType::AI_CPU表示该算子是AI CPU算子。需要手动添加。

ONNX

MindStudio在“framework/onnx_plugin/xxx_plugin.cc”文件已自动生成了算子的插件代码。

  • 包含头文件。
    1
    2
    //包含该头文件,可使用算子注册类相关,调用算子注册相关的接口,为Ascend-cann-toolkit安装目录/ascend-toolkit/latest/compiler/include/register/register.h文件
    #include "register/register.h"
    
  • 进行插件注册。
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    using namespace ge;    // 需手动添加
    namespace domi {
    // Onnx ParseParams
    Status ParseParamAdd(const Message* op_src, ge::Operator& op_dest) {
        // To do: Implement the operator plugin by referring to the Onnx Operator Development Guide.
        return SUCCESS;
    }
    
    // register op info to GE
    REGISTER_CUSTOM_OP("Add")
        .FrameworkType(ONNX)   // Operator name with the original framework
        .OriginOpType("")      // Set the original frame type of the operator
        .ParseParamsByOperatorFn(ParseParamAdd)// Registering the callback function for parsing operator parameters 
        .ImplyType(ImplyType::TVM);    // 需手动添加
    }  // namespace domi
    
    • REGISTER_CUSTOM_OP:注册自定义算子,Add作为注册到GE中的算子类型,可以任意命名但不能和已有的算子命名冲突,且需要与原型注册中的OpType保持一致。
    • FrameworkType:ONNX代表原始框架为ONNX
    • OriginOpType:算子在原始框架中的类型,需要用户填写。例如自定义算子Add,对应ONNX算子库版本opset_version=11的原始框架类型为“ai.onnx::11::Add”,此处填写为OriginOpType("ai.onnx::11::Add") 。
    • ParseParamsByOperatorFn(ParseParamAdd):用来注册解析算子属性的函数,需要用户自定义实现回调函数ParseParamAdd。

      回调函数ParseParamAdd的声明如下所示:

      Status ParseParamAdd(const ge::Operator& op_src, ge::Operator& op_dest)
      • ParseParamAdd:函数名称,用户自定义,需要保持唯一。
      • op_src:ONNX框架定义的Operator类对象,包含ONNX模型中自定义的算子属性信息,定义来源ONNX框架的原始模型文件。
      • op_dest:CANN算子数据结构,保存算子信息。
    • ImplyType:指定算子的实现方式,ImplyType::TVM表示该算子是TBE算子,需要手动添加。
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词