本章节介绍的Shape与DataType推导函数实现属于可选内容,仅在图模式下相关的调用方式时,需要提供该函数实现。
假设下图是我们需要使用的网络模型,您可能会想直接逐个算子调用,根据输入Tensor得到输出Tensor就可以完成网络的运行,但在实际的网络模型生成过程中,会先进行Tensor shape以及data type的推导。这样可以让我们在图执行之前,就知道各Tensor的数据类型和形状,提前校验其正确性;同时提前推理出算子的输出张量描述,包括张量的形状、数据类型及数据排布格式等信息,算子构图准备阶段就可以为所有的张量静态分配内存,避免动态内存分配带来的开销。
下面的网络模型经过infershape和data type推导之后,可以得到灰色底纹框中的推导信息:
以AddCustom算子为例,InferShape的实现如下所示。该样例中输出Tensor的描述信息与输入Tensor的描述信息相同,所以直接将任意一个输入Tensor的描述赋给输出Tensor即可。
namespace ge { static graphStatus InferShape(gert::InferShapeContext* context) { const auto inputShape = context->GetInputShape(0); auto outputShape = context->GetOutputShape(0); *outputShape = *inputShape; return GRAPH_SUCCESS; } } // namespace ge
InferDataType的实现如下所示。该样例中输出Tensor的数据类型与输入Tensor的数据类型相同,所以直接将任意一个输入Tensor的数据类型赋给输出Tensor即可。
namespace ge { static graphStatus InferDataType(gert::InferDataTypeContext* context) { const auto inputDataType = context->GetInputDataType(0); context->SetOutputDataType(0, inputDataType); return ge::GRAPH_SUCCESS; } } // namespace ge