下载
EN
注册

原理

概述

算子原型定义规定了在昇腾AI处理器上可运行算子的约束,主要体现算子的数学含义,包含定义算子输入、输出和属性信息,基本参数的校验和shape的推导,原型定义的信息会被注册到GE的算子原型库中。网络模型生成时,GE会调用算子原型库的校验接口进行基本参数的校验,校验通过后,会根据原型库中的推导函数推导每个节点的输出shape与dtype,进行输出tensor的静态内存的分配。

算子原型库在整个网络模型生成流程的作用如图1所示。

图1 算子原型在GE注册的流程

其中算子注册包括OP注册、InferShape与Verify的注册,注册的时候算子类型(OpType)作为Key。

  1. 首先GE接收到第三方框架的原始网络模型,并进行初始化,网络模型的拓扑图简称为图。
  2. 算子原型库管理模块从CANN算子库的“opp/built-in/op_proto/lib”目录下加载对应操作系统和架构的算子原型库so文件。
  3. 算子原型库管理模块根据so文件中的信息在OperatorFactory中进行算子信息注册,包括算子基本信息注册、InferShape注册、Verify注册,这三部分分别以算子类型(OpType)作为key注册到三个map文件中进行保存。
  4. 图准备阶段,GE向Graph发送调用InferShape函数与Verify函数的请求。其中InferShape函数用于进行输出shape的推导,从而进行静态内存的分配;Verify函数用于进行参数的基本校验。
  5. Graph会遍历图中所有节点。
  6. 每个节点都会向OpDesc发送调用InferShape与Verify函数的请求。
  7. OpDesc从OperatorFactory中根据OpType取出对应的InferShape函数与Verify函数。
  8. OpDesc执行Verify函数进行校验,如果校验成功,则继续往下执行;如果校验失败,则直接返回。
  9. OpDesc执行InferShape函数,进行输出tensor的shape推导。
  10. OpDesc向GE返回InferShape的结果,GE后续根据InferShape结果分配输出tensor的静态内存。
  11. GE进行其他处理。

InferShape流程介绍

介绍InferShape流程前先了解如下背景知识。

  • 图(Graph)是承载编译优化阶段的基本结构,图由节点OP(Operator)连接而成,每个节点间的边代表执行时要传递的数据,即张量(Tensor),每个Tensor都包含三个主要属性:dtype(数据类型)、shape(形状)、format(数据排布格式)。
  • 图中每个OP都可能会有多个输入与多个输出,输入与输出的个数分别由输入TensorDesc与输出TensorDesc的个数来表示,TensorDesc的介绍请参见表1
  • 图中的两个OP通过某条边直连时,前导OP的输出TensorDesc与后继OP的输入TensorDesc应该完全一致。另外,每个OP在获取了输入TensorDesc后,可以在编译阶段就推导出所有输出TensorDesc(例外场景:有些OP需要根据输入的实际值来推导输出TensorDesc,这类OP如果依赖的输入不是常量节点就无法推导出确切的输出TensorDesc)。

由上述背景介绍可知,只要全图所有首节点的TensorDesc确定了,就可以逐个向下传播,再由算子自身实现的Shape推导能力,就可以将全图所有OP的输入输出TensorDesc推导出来,这就是InferShape的推导流程。

InferShape有以下注意点:

  • 可进行InferShape的前提是图中所有首节点的TensorDesc都可确定,因此对于输入数据是动态shape的网络,静态编译的InferShape是无法生效的。
  • GE的InferShape流程负责推导TensorDesc中的dtype与shape,推导结束后,全图的dtype与shape的规格就完全连续了,如果生成网络模型时产生的GE Dump图“ge_proto_000000xx_after_infershape.txt”中存在dtype与shape规格不连续的情况,说明InferShape处理有错误。

    若要生成GE的Dump图,生成网络模型时需要设置如下环境变量:

    export DUMP_GE_GRAPH=1

  • InferShape流程里,GE将前一个算子的输出Tensor刷新到下一个算子的输入Tensor,算子根据输入推导并更新输出Tensor,若算子未自定义实现InferShape,则保留原图的输出Tesnor。
  • 某些算子在进行输出TensorDesc的推导时依赖算子输入的实际值,此时就需要算子的输入必须是常量节点(Const)。如果算子的输入不为常量节点,可能会造成推导输出Tensor的shape出错。

实现文件说明

开发者进行算子原型定义开发时需要实现如下两个文件:
  • 在头文件中进行算子IR原型的注册
  • 在cc文件中进行校验函数与shape推导函数的实现