下载
中文
注册

算子原型定义

简介

开发者进行算子原型定义开发时需要实现如下两个文件:
  • 算子名称.h头文件中进行算子原型的注册。
  • 算子名称.cc文件中进行校验函数与shape推导函数的实现。

本章节将以TBE算子开发样例工程中“Add”算子为例讲解算子原型定义的开发,供开发者参考。

进入“op_proto”目录,编写IR实现文件“add.h”和“add.cc”,将算子注册到算子原型库中。网络运行时,GE会调用算子原型库的校验接口进行基本参数的校验,校验通过后,会根据原型库中的推导函数推导每个节点的输出shape与dtype,进行输出tensor的静态内存的分配。

AI CPU算子原型定义的开发可参考算子开发样例工程中“reshape_cust”算子,IR实现文件为“reshape_cust.h”和“reshape_cust.cc”。

add.h实现

MindStudio已在add.h头文件中生成了算子注册的代码模板文件,开发者可根据需要进行修改,Add算子的原型定义如下所示:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
#ifndef GE_OPS_OP_PROTO_ADD_H_       //条件编译
#define GE_OPS_OP_PROTO_ADD_H_       //进行宏定义
#include "graph/operator_reg.h"
namespace ge {
REG_OP(Add)     //算子类型名称
    .INPUT(x1,
        TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
            DT_COMPLEX64, DT_STRING}))
    .INPUT(x2,
        TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
            DT_COMPLEX64, DT_STRING}))
    .OUTPUT(y,
        TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
            DT_COMPLEX64, DT_STRING}))
    .OP_END_FACTORY_REG(Add)
}

#endif //GE_OPS_OP_PROTO_ADD_H
  • REG_OP(Add)中的Add为算子注册到昇腾AI处理器中的Type,第三方框架(TensorFlow/ONNX/Caffe)需要与算子适配插件实现中REGISTER_CUSTOM_OP("Add")中的算子类型保持一致。
  • .INPUT与.OUTPUT分别为算子的输入、输出Tensor的名称与数据类型,输入输出的顺序需要与算子代码实现(TBE DSL)函数形参顺序以及算子信息库定义中参数的顺序保持一致。

add.cc实现

开发者需要在add.cc中实现InferShape与Verify方法。

  • Verify函数,即如下代码示例中的IMPLEMT_VERIFIER(Add, AddVerify)函数,用于校验Add算子的两个输入的DataType是否一致。
  • InferShape函数,即如下代码示例中的IMPLEMT_COMMON_INFERFUNC(AddInferShape)函数,用于推导出算子的输出张量描述,这样在网络运行时就可以为所有的张量静态分配内存,避免动态内存分配带来的开销。
add.cc的实现代码如下所示:
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include "./add.h"         //IR注册头文件
#include <vector>        //可使用vector类模板并调用vector相关接口
#include <string>        //C++标准库,可使用string类构造对象并调用string相关接口

namespace ge {
bool InferShapeAndTypeAdd(Operator& op,
    const string& inputName1,
    const string& inputName2, const string& outputName)
 {
    TensorDesc vOutputDesc = op.GetOutputDescByName(outputName.c_str());

    DataType inputDtype = op.GetInputDescByName(inputName1.c_str()).GetDataType();
    Format inputFormat = op.GetInputDescByName(inputName1.c_str()).GetFormat();
    // 针对shape维度大小进行交换
    ge::Shape shapeX = op.GetInputDescByName(inputName1.c_str()).GetShape();
    ge::Shape shapeY = op.GetInputDescByName(inputName2.c_str()).GetShape();
    std::vector<int64_t> dimsX = shapeX.GetDims();
    std::vector<int64_t> dimsY = shapeY.GetDims();
    if (dimsX.size() < dimsY.size()) {
        std::vector<int64_t> dimsTmp = dimsX;
        dimsX = dimsY;
        dimsY = dimsTmp;
    }

  // 对小的shape进行1补齐
  if (dimsX.size() != dimsY.size()) {
    int dec = dimsX.size() - dimsY.size();
    for (int i = 0; i < dec; i++) {
      dimsY.insert(dimsY.begin(), (int64_t)1);
    }
  }

    // 设置输出的shape维度
    std::vector<int64_t> dimVec;
    for (size_t i = 0; i < dimsX.size(); i++) {
        if ((dimsX[i] != dimsY[i]) && (dimsX[i] != 1) && (dimsY[i] != 1)) {
            return false;
        }

        int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
        dimVec.push_back(dims);
    }
    ge::Shape outputShape = ge::Shape(dimVec);

    vOutputDesc.SetShape(outputShape);
    vOutputDesc.SetDataType(inputDtype);
    vOutputDesc.SetFormat(inputFormat);
    op.UpdateOutputDesc(outputName.c_str(), vOutputDesc);

    return true;
}

//----------------Add-------------------
IMPLEMT_VERIFIER(Add, AddVerify)
{
    if (op.GetInputDescByName("x1").GetDataType() != op.GetInputDescByName("x2").GetDataType()) {
        return GRAPH_FAILED;
    }
    return GRAPH_SUCCESS;
}


// Obtains the processing function of the output tensor description.
IMPLEMT_COMMON_INFERFUNC(AddInferShape)
{
    if(InferShapeAndTypeAdd(op, "x1", "x2", "y")) {
        return GRAPH_SUCCESS;
    }
    return GRAPH_FAILED;
}


//Registered inferfunction,Infershape函数注册
COMMON_INFER_FUNC_REG(Add, AddInferShape);      //第一个参数为算子的OpType


//Registered verify function,Verify函数注册
VERIFY_FUNC_REG(Add, AddVerify);        //第一个参数为算子的OpType
//----------------Add-------------------
}