Shape推导与校验

简介

IR实现的cc文件中主要实现如下两个功能:

“op_proto/算子名称.cc”实现Verify和InferShape方法时不需要声明,直接实现即可。

引入依赖的头文件

#include "算子名称.h"     
#include <vector>
#include <string>
表1 头文件说明

头文件

目录

作用

算子名称.h

算子IR注册中实现的IR头文件

包含该头文件,可以调用此文件中注册的Operator类的对象op或者Operator类派生出来的子类op。

string

C++标准库。

包含该头文件,可使用string类构造对象并调用string相关接口。

vector

C++标准库。

包含该头文件,可使用vector类模板并调用vector相关接口。

InferShape函数实现

算子IR中InferShape的定义可以使用如下接口:

IMPLEMT_COMMON_INFERFUNC(func_name)

此接口自动生成一个类型为Operator类的对象op,开发者可直接调用Operator类接口进行InferShape的实现。其中,func_name:用户自定义。

Verify函数实现

算子Verify函数的实现使用如下接口:

IMPLEMT_VERIFIER (OpType, func_name)

传入的OpType为基于Operator类派生出来的子类,会自动生成一个类型为此子类的对象op,可以使用子类的成员函数获取算子的相关属性,op对象的成员函数可参见Operator类

Verify函数主要校验算子内在关联关系,例如对于多输入算子,多个tensor的dtype需要保持一致,此时需要校验多个输入的dtype,其他情况dtype不需要校验。

实现样例如下所示:

IMPLEMT_VERIFIER(Pow, PowVerify) {
  DataType input_type_x = op.GetInputDescByName("x").GetDataType();
  DataType input_type_y = op.GetInputDescByName("y").GetDataType();
  if (input_type_x != input_type_y) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

Infershape与Verify函数注册

调用InferShape注册宏与Verify注册宏完成InferShape方法与Verify方法的注册,如下所示:

COMMON_INFER_FUNC_REG(OpType, func_name);  
VERIFY_FUNC_REG(OpType, func_name);

func_name即为InferShape函数实现Verify函数实现中的func_name。