Shape推导等函数实现

假设下图是我们需要使用的网络模型,您可能会想直接逐个算子调用,根据输入Tensor得到输出Tensor就可以完成网络的运行,但在实际的网络模型生成过程中,会先进行Tensor shape以及dtype的推导。这样可以让我们在图执行之前,就知道各Tensor的数据类型和形状,提前校验其正确性;同时提前推理出算子的输出张量描述,包括张量的形状、数据类型及数据排布格式等信息,算子构图准备阶段就可以为所有的张量静态分配内存,避免动态内存分配带来的开销。

下面的网络模型经过infershape和dtype推导之后,可以得到灰色底纹框中的推导信息:

图1 shape推导示意图

以AddCustom算子为例,InferShape的实现如下所示。该样例中输出Tensor的描述信息与输入Tensor的描述信息相同,所以直接将任意一个输入Tensor的描述赋给输出Tensor即可。

namespace ge {
static graphStatus InferShape(gert::InferShapeContext* context)
{
    const auto inputShape = context->GetInputShape(0);
    auto outputShape = context->GetOutputShape(0);
    *outputShape = *inputShape;
    return GRAPH_SUCCESS;
}
} // namespace ge