开发流程
该开发流程以工程化算子开发为基础,除了需要提供工程化算子开发中的算子实现文件外,还需要额外交付算子入图的代码文件。本节仅提供算子入图代码文件的开发指导。
假设下图是我们需要使用的网络模型,您可能会想直接逐个算子调用,根据输入tensor得到输出tensor就可以完成网络的运行,但在图模式场景下,实际的网络模型生成过程中,会先进行tensor shape以及datatype的推导。这样可以让我们在图执行之前,就知道各tensor的数据类型和形状,提前校验其正确性;同时提前推理出算子的输出张量描述,包括张量的形状、数据类型及数据排布格式等信息,算子构图准备阶段就可以为所有的张量静态分配内存,避免动态内存分配带来的开销。
下面的网络模型经过shape和datatype推导之后,可以得到灰色底纹框中的推导信息:
除了tiling实现外,算子入图时需要额外提供的实现代码有以下几种:
- datatype推导:根据算子的输入datatype、算子逻辑及算子属性等信息,推理出算子的输出张量datatype。
- shape推导:根据算子的输入shape、算子逻辑及算子属性等信息,推理出算子的输出张量shape。
- 声明数据依赖:部分算子在InferShape时,需要依赖某个输入的具体值才可以进行,这类算子被称为“数据依赖算子”,对应的输入被称为“数据依赖输入”。该类算子在注册时,需要声明其数据依赖输入。
下表列出了不同类型的算子对上述实现代码的要求。
分类 |
对入图实现代码的要求 |
---|---|
根据输入shape可以推导出输出shape。 |
|
依赖输入的value才能推导出输出shape,即数据依赖算子。 如Reshape算子,依赖shape输入的value才能推导出输出shape。 |
|
实际开发时通过固定的datatype和shape推导原型实现推导函数,然后再通过SetInferShape、SetInferDataType接口来关联对应的shape推导函数,样例如下。
namespace ge { static graphStatus InferShape(gert::InferShapeContext *context) { ... return GRAPH_SUCCESS; } static graphStatus InferDataType(gert::InferDataTypeContext *context) { ... return ge::GRAPH_SUCCESS; } } // namespace ge namespace ops { class AddCustom : public OpDef { public: AddCustom(const char* name) : OpDef(name) { this->Input("x") .ParamType(REQUIRED) .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("y") .ParamType(REQUIRED) .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Output("z") .ParamType(REQUIRED) .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); // 根据用户的算子调用方式决定需不需要注册 图模式调用方式下需要 this->SetInferShape(ge::InferShape); this->SetInferDataType(ge::InferDataType); this->AICore() .SetTiling(optiling::TilingFunc); // 请替换为实际的昇腾AI处理器型号 this->AICore().AddConfig("ascendxxx"); } }; OP_ADD(AddCustom); } // namespace ops
datatype推导
以AddCustom算子为例,InferDataType的实现如下所示。该样例中输出tensor的数据类型与输入tensor的数据类型相同,所以直接将任意一个输入tensor的数据类型赋给输出tensor即可。
namespace ge { static graphStatus InferDataType(gert::InferDataTypeContext* context) { const auto inputDataType = context->GetInputDataType(0); context->SetOutputDataType(0, inputDataType); return ge::GRAPH_SUCCESS; } } // namespace ge
如下示例则给出了更灵活的datatype推导样例,当输入的数据类型为DT_INT4时,其输出的数据类型为DT_INT32。
ge::graphStatus InferDataTypeForFoo(gert::InferDataTypeContext* context) { if (context->GetInputDataType(0) == DT_INT4) { context->SetOutputDataType(0, DT_INT32); } }
shape推导
简单的shape推导逻辑可以使用Follow接口来表达,比如输出shape和输入shape相同的情况。示例如下:输出“y1”Follow输入“x1”场景,指定Follow模式为SHAPE,此时“y1”的shape将会和“x1”保持一致。
this->Input("x1") .ParamType(REQUIRED) .DataType({ge::DT_FLOAT, ge::DT_FLOAT}) .Format({ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("x2") .ParamType(REQUIRED) .DataType({ge::DT_FLOAT, ge::DT_FLOAT}) .Format({ge::FORMAT_ND, ge::FORMAT_ND}); this->Output("y1") .ParamType(REQUIRED) .DataType({ge::DT_FLOAT, ge::DT_FLOAT}) .Format({ge::FORMAT_ND, ge::FORMAT_ND}) .Follow("x1", FollowType::SHAPE);
无法在原型定义中通过Follow表达的情况需要开发者编写InferShape函数,其原型是确定的,接受一个InferShapeContext作为输入,从此context上可以获取到输入、输出的shape指针等内容。输入shape为const类型,因此InferShape时,输入shape是只读、不允许修改的。InferShape成功后,返回ge::GRAPH_SUCCESS,其他返回值被认为推导失败。推导失败后,执行过程结束退出。
以ReShape算子为例,InferShape的实现如下所示。根据第1个输入(shape输入)的值,Reshape算子将第0个输入(x输入)的shape做变换,并输出到其第0个输出(y输出)上。Reshape的InferShape实现为:
ge::graphStatus InferShapeForReshape(InferShapeContext *context) { const gert::Shape *x_shape = context->GetInputShape(0); // 获取第0个输入的shape const gert::Tensor *shape_tensor = context->GetInputTensor(1); // 获取第1个输入的tensor gert::Shape *output_shape = context->GetOutputShape(0); if (x_shape == nullptr || shape_tensor == nullptr || output_shape == nullptr) { // 防御式编程,不应该出现的场景,打印错误并返回失败 return ge::GRAPH_FAILED; } auto reshape_size = static_cast<int32_t>(shape_tensor->GetShapeSize()); if (reshape_size < 1) { // 防御式编程,不应该出现的场景,打印错误并返回失败 return ge::GRAPH_FAILED; } // 根据原型信息,Reshape的shape输入支持INT32与INT64两类,根据不同的类型进入对应的模板函数中做真正的shape变换操作 if (shape_tensor->GetDataType() == ge::DT_INT32) { int32_t *reshape_data = shape_tensor->GetData<int32_t>(); return ReshapeInferShapeImpl<int32_t>(reshape_data, *x_shape, *output_shape, reshape_size); } else { int64_t *reshape_data = shape_tensor->GetData<int64_t>(); return ReshapeInferShapeImpl<int64_t>(reshape_data, *x_shape, *output_shape, reshape_size); } }
InferShapeContextpublic继承自ExtendedKernelContext,因此ExtendedKernelContext中提供的方法如获取算子type、name、属性等接口均可以在InferShapeContext实例中调用。
- InferShape推导函数和使用Follow接口去Follow shape不能混用,即不支持部分输出采用Infershape推导、部分输出采用Follow推导的情况。若用户同时使用了InferShape函数和Follow接口,以用户的InferShape函数为准,需要保证在InferShape函数中能够推导出所有的输出shape。
- 为了效率考虑,调用InferShape函数时,框架不会为输出shape做初始化,因此,在InferShape函数中,可以认为输出是未初始化的状态。如果在InferShape时,希望通过Append方式操作输出shape,需要先将输出shape的DimNum清零,以防止出现未定义行为。
InferShape时获取属性、输入
在InferShape、Tiling时,可以通过context实例获取算子IR属性值,所谓IR属性,是指在IR注册时定义的属性,以TransData算子为例:
namespace ops { class TransData : public OpDef { public: explicit TransData(const char *name) : OpDef(name) { this->Input("src") ... this->Output("dst") ... this->Attr("src_format") .AttrType(REQUIRED) .String(); this->Attr("dst_format") .AttrType(REQUIRED) .String(); this->Attr("group") .AttrType(OPTIONAL) .Int(1); ... } }; OP_ADD(TransData); } // namespace ops
其原型定义中声明了src_format、dst_format、group三个属性,可以通过如下方式获取算子属性:
ge::graphStatus ExampleGetTransDataAttr(TilingContext *context) { // 获取所有属性 const RuntimeAttrs *attrs = context->GetAttrs(); ASSERT_NOT_NULL(attrs); // 按照在原型定义中的顺序,使用index获取属性,index从0开始计数 const char *src_format = attrs->GetAttrPointer<char>(0); // 获取src_format,src_format是第一个属性,因此index为0 const char *dst_format = attrs->GetAttrPointer<char>(1); // 获取dst_format,dst_format是第二个属性,因此index为1 const int64_t group = attrs->GetAttrPointer<int64_t>(2); // 获取group,group是第三个属性,因此index为2 return ge::GRAPH_SUCCESS; }
通过index而不是字符串name来索引输入输出,对于带有OPTIONAL、DYNAMIC类型输入的算子,可能出现实例化后,单纯通过index无法索引到具体输入的问题,以DynamicRNNV3算子为例:
namespace ops { class DynamicRNNV3 : public OpDef { public: explicit DynamicRNNV3(const char *name) : OpDef(name) { this->Input("x") .ParamType(REQUIRED) ... this->Input("w") .ParamType(REQUIRED) ... this->Input("b") .ParamType(REQUIRED) ... this->Input("seq_length") .ParamType(OPTIONAL) ... this->Input("init_h") .ParamType(OPTIONAL) ... this->Input("init_c") .ParamType(OPTIONAL) ... this->Input("wci") .ParamType(OPTIONAL) ... this->Input("wcf") .ParamType(OPTIONAL) ... this->Input("mask") .ParamType(OPTIONAL) ... this->Input("mask") .ParamType(OPTIONAL) ... this->Input("project") .ParamType(OPTIONAL) ... ... } }; OP_ADD(DynamicRNNV3); } // namespace ops
由于DynamicRNNV3算子有连续的多个optional输入,这导致init_h及其后面的输入的实例化后index都是不确定的,对于这种类型的算子,可以通过GetOptionalInputShape传入原型对应的index来获取对应的输入shape等数据,以InferShape为例:
ge::graphStatus InferShapeForDynamicRNNV3(InferShapeContext *context) { // 对于前两个输入,不受到optional或dynamic的影响,可以按照常规方法获取输入shape auto x_shape = context->GetInputShape(0); auto w_shape = context->GetInputShape(1); if (x_shape == nullptr || w_shape == nullptr) { return ge::GRAPH_FAILED; } int64_t state_size = 0; // 在原型定义上,project是第11个输入(从0开始计数) constexpr int64_t kProjectInputIndex = 11; // 受到前面optional输入影响的,project实例化后输入的index是不确定的,通过GetOptionalInputShape来获取对应的输入shape, // GetOptionalInputShape的入参为原型上对应的index auto project_shape = context->GetOptionalInputShape(kProjectInputIndex); if (project_shape != nullptr) { if (project_shape->GetDimNum() < 2) { return ge::GRAPH_FAILED; } state_size = project_shape->GetDim(1); } // 更多的infershape逻辑... return ge::GRAPH_SUCCESS; }
对于dynamic类型的输入,实例化后的输入可能是一到多个,对于此类输入,获取方式为:
// ir_index:此输入在原型定义中的index,从0开始计数 // relative_index:该输入实例化后的相对index,从0开始计数,例如某个DYNAMIC_INPUT实例化了3个,要取第二个,那么relatvie_index = 1 auto shape = context->GetDynamicInputShape(ir_index, relative_index);
本节举例的获取optional、dynamic输入的方式,在InferShape、Tiling函数中均可以调用。
数据依赖
一般来说,具备输入shape后,算子可以通过InferShape推导出输出shape。然而部分算子在InferShape时,需要依赖某个输入的具体值才可以进行,这类算子被称为“数据依赖算子”,对应的输入被称为“数据依赖输入”。以Reshape算子为例,其依据shape输入的描述,对输入的shape做调整,因此Reshape算子依赖shape输入的值。这类算子需要在原型定义时通过ValueDepend接口声明对应的输入为数据依赖输入。
namespace ops { class Reshape : public OpDef { public: explicit Reshape(const char *name) : OpDef(name) { ... this->Input("shape") .ParamType(REQUIRED) ... .ValueDepend(REQUIRED) // 声明 ReShape算子的shape输入为数据依赖输入 ... } }; OP_ADD(Reshape); } // namespace ops
根据第1个输入(shape输入)的值,Reshape算子将第0个输入(x输入)的shape做变换,并输出到其第0个输出(y输出)上。Reshape的InferShape实现为:
ge::graphStatus InferShapeForReshape(InferShapeContext *context) { const gert::Shape *x_shape = context->GetInputShape(0); // 获取第0个输入的shape const gert::Tensor *shape_tensor = context->GetInputTensor(1); // 获取第1个输入的tensor gert::Shape *output_shape = context->GetOutputShape(0); if (x_shape == nullptr || shape_tensor == nullptr || output_shape == nullptr) { // 防御式编程,不应该出现的场景,打印错误并返回失败 return ge::GRAPH_FAILED; } auto reshape_size = static_cast<int32_t>(shape_tensor->GetShapeSize()); if (reshape_size < 1) { // 防御式编程,不应该出现的场景,打印错误并返回失败 return ge::GRAPH_FAILED; } // 根据原型信息,Reshape的shape输入支持INT32与INT64两类,根据不同的类型进入对应的模板函数中做真正的shape变换操作 if (shape_tensor->GetDataType() == ge::DT_INT32) { int32_t *reshape_data = shape_tensor->GetData<int32_t>(); return ReshapeInferShapeImpl<int32_t>(reshape_data, *x_shape, *output_shape, reshape_size); } else { int64_t *reshape_data = shape_tensor->GetData<int64_t>(); return ReshapeInferShapeImpl<int64_t>(reshape_data, *x_shape, *output_shape, reshape_size); } }
- 只有声明过数据依赖的输入,才可以在InferShape时调用GetInputTensor等获取tensor的接口获取其对应的tensor数据。若对一个未声明数据依赖的输入调用GetInputTensor等获取tensor的接口,只能在tensor中获取到正确的shape、format、datatype信息,无法获取到真实的tensor数据地址(获取到的地址为nullptr)。
- 从tensor中获取tensor_data时(GetData<int32_t>或GetData<int64_t>),使用者需要保证获取的数据类型是正确的,否则行为是未定义的。