文档
注册

算子原型定义

TBE算子开发流程将以算子开发样例工程中“Add”算子为例,供开发者参考。

进入“op_proto/”目录,编写IR实现文件“add.h”和“add.cc”,将算子注册到算子原型库中。网络运行时,GE会调用算子原型库的校验接口进行基本参数的校验,校验通过后,会根据原型库中的推导函数推导每个节点的输出shape与dtype,进行输出tensor的静态内存的分配。

add.h实现

MindStudio已在add.h文件中生成了算子注册的代码模板文件,开发者可根据需要进行修改,Add算子的原型定义如下所示:

#ifndef GE_OPS_OP_PROTO_ADD_H_
#define GE_OPS_OP_PROTO_ADD_H_
#include "graph/operator_reg.h"
namespace ge {
REG_OP(Add)
    .INPUT(x1,
        TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
            DT_COMPLEX64, DT_STRING}))
    .INPUT(x2,
        TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
            DT_COMPLEX64, DT_STRING}))
    .OUTPUT(y,
        TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
            DT_COMPLEX64, DT_STRING}))
    .OP_END_FACTORY_REG(Add)
}

#endif //GE_OPS_OP_PROTO_ADD_H

add.cc实现

开发者需要在add.cc中实现InferShape与Verify方法。

  • Verify函数,即如下代码示例中的IMPLEMT_VERIFIER(Add, AddVerify)函数,用于校验Add算子的两个输入的DataType是否一致。
  • InferShape函数,即如下代码示例中的IMPLEMT_COMMON_INFERFUNC(AddInferShape)函数,用于推导出算子的输出张量描述,这样在网络运行时就可以为所有的张量静态分配内存,避免动态内存分配带来的开销。
add.cc的实现代码如下所示:
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include "./add.h"         //IR注册头文件
#include <vector>        //可使用vector类模板并调用vector相关接口
#include <string>        //C++标准库,可使用string类构造对象并调用string相关接口

namespace ge {
bool InferShapeAndTypeAdd(Operator& op,
    const string& inputName1,
    const string& inputName2, const string& outputName)
 {
    TensorDesc vOutputDesc = op.GetOutputDescByName(outputName.c_str());

    DataType inputDtype = op.GetInputDescByName(inputName1.c_str()).GetDataType();
    Format inputFormat = op.GetInputDescByName(inputName1.c_str()).GetFormat();
    // 针对shape维度大小进行交换
    ge::Shape shapeX = op.GetInputDescByName(inputName1.c_str()).GetShape();
    ge::Shape shapeY = op.GetInputDescByName(inputName2.c_str()).GetShape();
    std::vector<int64_t> dimsX = shapeX.GetDims();
    std::vector<int64_t> dimsY = shapeY.GetDims();
    if (dimsX.size() < dimsY.size()) {
        std::vector<int64_t> dimsTmp = dimsX;
        dimsX = dimsY;
        dimsY = dimsTmp;
    }

  // 对小的shape进行1补齐
  if (dimsX.size() != dimsY.size()) {
    int dec = dimsX.size() - dimsY.size();
    for (int i = 0; i < dec; i++) {
      dimsY.insert(dimsY.begin(), (int64_t)1);
    }
  }

    // 设置输出的shape维度
    std::vector<int64_t> dimVec;
    for (size_t i = 0; i < dimsX.size(); i++) {
        if ((dimsX[i] != dimsY[i]) && (dimsX[i] != 1) && (dimsY[i] != 1)) {
            return false;
        }

        int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
        dimVec.push_back(dims);
    }
    ge::Shape outputShape = ge::Shape(dimVec);

    vOutputDesc.SetShape(outputShape);
    vOutputDesc.SetDataType(inputDtype);
    vOutputDesc.SetFormat(inputFormat);
    op.UpdateOutputDesc(outputName.c_str(), vOutputDesc);

    return true;
}

//----------------Add-------------------
IMPLEMT_VERIFIER(Add, AddVerify)
{
    if (op.GetInputDescByName("x1").GetDataType() != op.GetInputDescByName("x2").GetDataType()) {
        return GRAPH_FAILED;
    }
    return GRAPH_SUCCESS;
}


// Obtains the processing function of the output tensor description.
IMPLEMT_COMMON_INFERFUNC(AddInferShape)
{
    if(InferShapeAndTypeAdd(op, "x1", "x2", "y")) {
        return GRAPH_SUCCESS;
    }
    return GRAPH_FAILED;
}


//Registered inferfunction,第一个参数为算子的OpType
COMMON_INFER_FUNC_REG(Add, AddInferShape);      


//Registered verify function,第一个参数为算子的OpType
VERIFY_FUNC_REG(Add, AddVerify);
//----------------Add-------------------
}
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词