下载
中文
注册

Shape推导与校验

简介

IR实现的cc文件中主要实现如下两个功能:

  • 算子参数的校验,实现程序健壮性并提高定位效率。
  • 根据算子的输入张量描述、算子逻辑及算子属性,推理出算子的输出张量描述,包括张量的形状、数据类型及数据排布格式等信息。这样算子构图准备阶段就可以为所有的张量静态分配内存,避免动态内存分配带来的开销。

“op_proto/算子名称.cc”实现Verify和InferShape方法时不需要声明,直接实现即可。

引入依赖的头文件

#include "算子名称.h"     
#include <vector>
#include <string>
表1 头文件说明

头文件

目录

作用

算子名称.h

算子IR注册中实现的IR头文件

包含该头文件,可以调用此文件中注册的Operator类的对象op或者Operator类派生出来的子类op。

string

C++标准库。

包含该头文件,可使用string类构造对象并调用string相关接口。

vector

C++标准库。

包含该头文件,可使用vector类模板并调用vector相关接口。

InferShape函数实现

算子IR中InferShape的定义可以使用如下接口:

IMPLEMT_COMMON_INFERFUNC(func_name)

此接口自动生成一个类型为Operator类的对象op,开发者可直接调用Operator类接口进行InferShape的实现。其中,func_name:用户自定义。

  • 将输入描述直接赋给输出描述的实现样例如下所示:
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    IMPLEMT_COMMON_INFERFUNC(SoftmaxInferShape)
    {
      TensorDesc tensordesc_output = op.GetOutputDescByName("y");
    
      tensordesc_output.SetShape(op.GetInputDescByName("x").GetShape());
      tensordesc_output.SetDataType(op.GetInputDescByName("x").GetDataType());
      tensordesc_output.SetFormat(op.GetInputDescByName("x").GetFormat());
    
      (void)op.UpdateOutputDesc("y", tensordesc_output);
      return GRAPH_SUCCESS;
    }
    
  • 输出描述需要根据算子逻辑进行计算得到的样例如下所示:
    IMPLEMT_COMMON_INFERFUNC(NotEqualInferShape) {
      Shape x_shape = op.GetInputDescByName("x1").GetShape();
      Shape y_shape = op.GetInputDescByName("x2").GetShape();
      TensorDesc td = op.GetOutputDescByName("y");
      std::vector<int64_t> dims_x = x_shape.GetDims();
      std::vector<int64_t> dims_y = y_shape.GetDims();
      if (dims_x.size() < dims_y.size()) {
        std::vector<int64_t> dims_tmp = dims_x;
        dims_x = dims_y;
        dims_y = dims_tmp;
      }
      if (dims_x.size() != dims_y.size()) {
        int dec = dims_x.size() - dims_y.size();
        for (int i = 0; i < dec; i++) {
          dims_y.insert(dims_y.begin(), (int64_t)1);
        }
      }
      std::vector<int64_t> dim_vec;
      for (size_t i = 0; i < dims_x.size(); i++) {
        if ((dims_x[i] != dims_y[i]) && (dims_x[i] != 1) && (dims_y[i] != 1)) {
          printf( "The %s op dimensions does not match the broadcast rule(%lu %lu).",op.GetName().c_str(), dims_x[i], dims_y[i]);
        }
        int64_t dims = dims_x[i] > dims_y[i] ? dims_x[i] : dims_y[i];
        dim_vec.push_back(dims);
      }
      td.SetShape(ge::Shape(dim_vec));
      td.SetDataType(DT_BOOL);
      (void)op.UpdateOutputDesc("y", td);
      return GRAPH_SUCCESS;
    }
  • 若存在动态输入输出,输入输出按列表形式处理,样例如下所示:
    IMPLEMT_COMMON_INFERFUNC(BatchInfer)
    {
        for (size_t i = 0; i < op.GetInputsSize(); ++i) {
            Shape out_shapes;
            if (ReplaceDim(op.GetInputDesc(i).GetShape(), 
                    0, ge::UNKNOWN_DIM, out_shapes, op.GetName().c_str()) == GRAPH_FAILED) {
                return GRAPH_FAILED;
            }
            auto y_tensor_type = op.GetDynamicInputDesc("x_tensors", i).GetDataType();
            TensorDesc output_desc = op.GetDynamicOutputDesc("y_tensors", i);
            output_desc.SetShape(out_shapes);
            output_desc.SetDataType(y_tensor_type);
            op.UpdateDynamicOutputDesc("y_tensors", i, output_desc);
        }
    
        Shape scalar_shape;
        Scalar(scalar_shape);
        TensorDesc y_desc = op.GetOutputDesc("y_id");
        y_desc.SetShape(scalar_shape);
        y_desc.SetDataType(DT_INT64);
        op.UpdateOutputDesc("y_id", y_desc);
    
        std::vector<int64_t> dims = { ge::UNKNOWN_DIM, 3 };
        TensorDesc output_desc_batch_index = op.GetOutputDesc("y_index");
        output_desc_batch_index.SetShape(Shape(dims));
        output_desc_batch_index.SetDataType(DT_INT64);
        op.UpdateOutputDesc("y_index", output_desc_batch_index);
        return GRAPH_SUCCESS;
    }
    // ReplaceDim函数定义如下
    graphStatus ReplaceDim(const Shape& s, 
                           int64_t dim_index_in, 
                           int64_t new_dim, 
                           Shape& out, 
                           const char* op_name) {
      if(shape.GetDims() == UNKNOWN_RANK) {
        out = Shape(ge::UNKNOWN_SHAPE);
        return GRAPH_SUCCESS;
      }
      int64_t dim_index = dim_index_in;
      if (dim_index < 0) {
        dim_index = (int64_t)s.GetDimNum() + dim_index;
      }
      std::vector<int64_t> dims = s.GetDims();
      dims[dim_index] = new_dim;
      out = Shape(dims);
      return GRAPH_SUCCESS;
    }

Verify函数实现

算子Verify函数的实现使用如下接口:

IMPLEMT_VERIFIER (OpType, func_name)

传入的OpType为基于Operator类派生出来的子类,会自动生成一个类型为此子类的对象op,可以使用子类的成员函数获取算子的相关属性,op对象的成员函数可参见Operator类

  • OpType:自定义算子的类型。
  • func_name:自定义的verify函数名称。

Verify函数主要校验算子内在关联关系,例如对于多输入算子,多个tensor的dtype需要保持一致,此时需要校验多个输入的dtype,其他情况dtype不需要校验。

实现样例如下所示:

IMPLEMT_VERIFIER(Pow, PowVerify) {
  DataType input_type_x = op.GetInputDescByName("x").GetDataType();
  DataType input_type_y = op.GetInputDescByName("y").GetDataType();
  if (input_type_x != input_type_y) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

Infershape与Verify函数注册

调用InferShape注册宏与Verify注册宏完成InferShape方法与Verify方法的注册,如下所示:

COMMON_INFER_FUNC_REG(OpType, func_name);  
VERIFY_FUNC_REG(OpType, func_name);

func_name即为InferShape函数实现Verify函数实现中的func_name。