算子原型定义
简介
开发者进行算子原型定义开发时需要实现如下两个文件:
- 在算子名称.h头文件中进行算子原型的注册。
- 在算子名称.cc文件中进行校验函数与shape推导函数的实现。
本章节将以TBE算子开发样例工程中“Add”算子为例讲解算子原型定义的开发,供开发者参考。
进入“op_proto”目录,编写IR实现文件“add.h”和“add.cc”,将算子注册到算子原型库中。网络运行时,GE会调用算子原型库的校验接口进行基本参数的校验,校验通过后,会根据原型库中的推导函数推导每个节点的输出shape与dtype,进行输出tensor的静态内存的分配。
AI CPU算子原型定义的开发可参考算子开发样例工程中“reshape_cust”算子,IR实现文件为“reshape_cust.h”和“reshape_cust.cc”。
add.h实现
MindStudio已在add.h头文件中生成了算子注册的代码模板文件,开发者可根据需要进行修改,Add算子的原型定义如下所示:
#ifndef GE_OPS_OP_PROTO_ADD_H_ //条件编译 #define GE_OPS_OP_PROTO_ADD_H_ //进行宏定义 #include "graph/operator_reg.h" namespace ge { REG_OP(Add) //算子类型名称 .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, DT_COMPLEX64, DT_STRING})) .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, DT_COMPLEX64, DT_STRING})) .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, DT_COMPLEX64, DT_STRING})) .OP_END_FACTORY_REG(Add) } #endif //GE_OPS_OP_PROTO_ADD_H
- REG_OP(Add)中的Add为算子注册到昇腾AI处理器中的Type,第三方框架(TensorFlow/ONNX/Caffe)需要与算子适配插件实现中REGISTER_CUSTOM_OP("Add")中的算子类型保持一致。
- .INPUT与.OUTPUT分别为算子的输入、输出Tensor的名称与数据类型,输入输出的顺序需要与算子代码实现(TBE DSL)函数形参顺序以及算子信息库定义中参数的顺序保持一致。
add.cc实现
开发者需要在add.cc中实现InferShape与Verify方法。
- Verify函数,即如下代码示例中的IMPLEMT_VERIFIER(Add, AddVerify)函数,用于校验Add算子的两个输入的DataType是否一致。
- InferShape函数,即如下代码示例中的IMPLEMT_COMMON_INFERFUNC(AddInferShape)函数,用于推导出算子的输出张量描述,这样在网络运行时就可以为所有的张量静态分配内存,避免动态内存分配带来的开销。
add.cc的实现代码如下所示:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
#include "./add.h" //IR注册头文件 #include <vector> //可使用vector类模板并调用vector相关接口 #include <string> //C++标准库,可使用string类构造对象并调用string相关接口 namespace ge { bool InferShapeAndTypeAdd(Operator& op, const string& inputName1, const string& inputName2, const string& outputName) { TensorDesc vOutputDesc = op.GetOutputDescByName(outputName.c_str()); DataType inputDtype = op.GetInputDescByName(inputName1.c_str()).GetDataType(); Format inputFormat = op.GetInputDescByName(inputName1.c_str()).GetFormat(); // 针对shape维度大小进行交换 ge::Shape shapeX = op.GetInputDescByName(inputName1.c_str()).GetShape(); ge::Shape shapeY = op.GetInputDescByName(inputName2.c_str()).GetShape(); std::vector<int64_t> dimsX = shapeX.GetDims(); std::vector<int64_t> dimsY = shapeY.GetDims(); if (dimsX.size() < dimsY.size()) { std::vector<int64_t> dimsTmp = dimsX; dimsX = dimsY; dimsY = dimsTmp; } // 对小的shape进行1补齐 if (dimsX.size() != dimsY.size()) { int dec = dimsX.size() - dimsY.size(); for (int i = 0; i < dec; i++) { dimsY.insert(dimsY.begin(), (int64_t)1); } } // 设置输出的shape维度 std::vector<int64_t> dimVec; for (size_t i = 0; i < dimsX.size(); i++) { if ((dimsX[i] != dimsY[i]) && (dimsX[i] != 1) && (dimsY[i] != 1)) { return false; } int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i]; dimVec.push_back(dims); } ge::Shape outputShape = ge::Shape(dimVec); vOutputDesc.SetShape(outputShape); vOutputDesc.SetDataType(inputDtype); vOutputDesc.SetFormat(inputFormat); op.UpdateOutputDesc(outputName.c_str(), vOutputDesc); return true; } //----------------Add------------------- IMPLEMT_VERIFIER(Add, AddVerify) { if (op.GetInputDescByName("x1").GetDataType() != op.GetInputDescByName("x2").GetDataType()) { return GRAPH_FAILED; } return GRAPH_SUCCESS; } // Obtains the processing function of the output tensor description. IMPLEMT_COMMON_INFERFUNC(AddInferShape) { if(InferShapeAndTypeAdd(op, "x1", "x2", "y")) { return GRAPH_SUCCESS; } return GRAPH_FAILED; } //Registered inferfunction,Infershape函数注册 COMMON_INFER_FUNC_REG(Add, AddInferShape); //第一个参数为算子的OpType //Registered verify function,Verify函数注册 VERIFY_FUNC_REG(Add, AddVerify); //第一个参数为算子的OpType //----------------Add------------------- } |
父主题: 算子开发过程