Shape推导与校验
简介
IR实现的cc文件中主要实现如下两个功能:
- 算子参数的校验,实现程序健壮性并提高定位效率。
- 根据算子的输入张量描述、算子逻辑及算子属性,推理出算子的输出张量描述,包括张量的形状、数据类型及数据排布格式等信息。这样算子构图准备阶段就可以为所有的张量静态分配内存,避免动态内存分配带来的开销。
在“op_proto/算子名称.cc”实现Verify和InferShape方法时不需要声明,直接实现即可。
引入依赖的头文件
#include "算子名称.h" #include <vector> #include <string>
头文件 |
目录 |
作用 |
---|---|---|
算子名称.h |
算子IR注册中实现的IR头文件 |
包含该头文件,可以调用此文件中注册的Operator类的对象op或者Operator类派生出来的子类op。 |
string |
C++标准库。 |
包含该头文件,可使用string类构造对象并调用string相关接口。 |
vector |
C++标准库。 |
包含该头文件,可使用vector类模板并调用vector相关接口。 |
InferShape函数实现
算子IR中InferShape的定义可以使用如下接口:
IMPLEMT_COMMON_INFERFUNC(func_name)
此接口自动生成一个类型为Operator类的对象op,开发者可直接调用Operator类接口进行InferShape的实现。其中,func_name:用户自定义。
- 将输入描述直接赋给输出描述的实现样例如下所示:
1 2 3 4 5 6 7 8 9 10 11
IMPLEMT_COMMON_INFERFUNC(SoftmaxInferShape) { TensorDesc tensordesc_output = op.GetOutputDescByName("y"); tensordesc_output.SetShape(op.GetInputDescByName("x").GetShape()); tensordesc_output.SetDataType(op.GetInputDescByName("x").GetDataType()); tensordesc_output.SetFormat(op.GetInputDescByName("x").GetFormat()); (void)op.UpdateOutputDesc("y", tensordesc_output); return GRAPH_SUCCESS; }
- 输出描述需要根据算子逻辑进行计算得到的样例如下所示:
IMPLEMT_COMMON_INFERFUNC(NotEqualInferShape) { Shape x_shape = op.GetInputDescByName("x1").GetShape(); Shape y_shape = op.GetInputDescByName("x2").GetShape(); TensorDesc td = op.GetOutputDescByName("y"); std::vector<int64_t> dims_x = x_shape.GetDims(); std::vector<int64_t> dims_y = y_shape.GetDims(); if (dims_x.size() < dims_y.size()) { std::vector<int64_t> dims_tmp = dims_x; dims_x = dims_y; dims_y = dims_tmp; } if (dims_x.size() != dims_y.size()) { int dec = dims_x.size() - dims_y.size(); for (int i = 0; i < dec; i++) { dims_y.insert(dims_y.begin(), (int64_t)1); } } std::vector<int64_t> dim_vec; for (size_t i = 0; i < dims_x.size(); i++) { if ((dims_x[i] != dims_y[i]) && (dims_x[i] != 1) && (dims_y[i] != 1)) { printf( "The %s op dimensions does not match the broadcast rule(%lu %lu).",op.GetName().c_str(), dims_x[i], dims_y[i]); } int64_t dims = dims_x[i] > dims_y[i] ? dims_x[i] : dims_y[i]; dim_vec.push_back(dims); } td.SetShape(ge::Shape(dim_vec)); td.SetDataType(DT_BOOL); (void)op.UpdateOutputDesc("y", td); return GRAPH_SUCCESS; }
- 若存在动态输入输出,输入输出按列表形式处理,样例如下所示:
IMPLEMT_COMMON_INFERFUNC(BatchInfer) { for (size_t i = 0; i < op.GetInputsSize(); ++i) { Shape out_shapes; if (ReplaceDim(op.GetInputDesc(i).GetShape(), 0, ge::UNKNOWN_DIM, out_shapes, op.GetName().c_str()) == GRAPH_FAILED) { return GRAPH_FAILED; } auto y_tensor_type = op.GetDynamicInputDesc("x_tensors", i).GetDataType(); TensorDesc output_desc = op.GetDynamicOutputDesc("y_tensors", i); output_desc.SetShape(out_shapes); output_desc.SetDataType(y_tensor_type); op.UpdateDynamicOutputDesc("y_tensors", i, output_desc); } Shape scalar_shape; Scalar(scalar_shape); TensorDesc y_desc = op.GetOutputDesc("y_id"); y_desc.SetShape(scalar_shape); y_desc.SetDataType(DT_INT64); op.UpdateOutputDesc("y_id", y_desc); std::vector<int64_t> dims = { ge::UNKNOWN_DIM, 3 }; TensorDesc output_desc_batch_index = op.GetOutputDesc("y_index"); output_desc_batch_index.SetShape(Shape(dims)); output_desc_batch_index.SetDataType(DT_INT64); op.UpdateOutputDesc("y_index", output_desc_batch_index); return GRAPH_SUCCESS; } // ReplaceDim函数定义如下 graphStatus ReplaceDim(const Shape& s, int64_t dim_index_in, int64_t new_dim, Shape& out, const char* op_name) { if(shape.GetDims() == UNKNOWN_RANK) { out = Shape(ge::UNKNOWN_SHAPE); return GRAPH_SUCCESS; } int64_t dim_index = dim_index_in; if (dim_index < 0) { dim_index = (int64_t)s.GetDimNum() + dim_index; } std::vector<int64_t> dims = s.GetDims(); dims[dim_index] = new_dim; out = Shape(dims); return GRAPH_SUCCESS; }
Verify函数实现
算子Verify函数的实现使用如下接口:
IMPLEMT_VERIFIER (OpType, func_name)
传入的OpType为基于Operator类派生出来的子类,会自动生成一个类型为此子类的对象op,可以使用子类的成员函数获取算子的相关属性,op对象的成员函数可参见Operator类。
- OpType:自定义算子的类型。
- func_name:自定义的verify函数名称。
Verify函数主要校验算子内在关联关系,例如对于多输入算子,多个tensor的dtype需要保持一致,此时需要校验多个输入的dtype,其他情况dtype不需要校验。
实现样例如下所示:
IMPLEMT_VERIFIER(Pow, PowVerify) { DataType input_type_x = op.GetInputDescByName("x").GetDataType(); DataType input_type_y = op.GetInputDescByName("y").GetDataType(); if (input_type_x != input_type_y) { return GRAPH_FAILED; } return GRAPH_SUCCESS; }
Infershape与Verify函数注册
调用InferShape注册宏与Verify注册宏完成InferShape方法与Verify方法的注册,如下所示:
COMMON_INFER_FUNC_REG(OpType, func_name); VERIFY_FUNC_REG(OpType, func_name);
func_name即为InferShape函数实现与Verify函数实现中的func_name。
父主题: 算子原型定义