下载
中文
注册

内部关联接口

在进行算子原型注册、Tiling实现、shape推导过程中,开发者会使用上文的原型定义Tiling实现算子推导等外部开放接口,这些外部开放接口中会调用到的辅助数据结构和接口,称之为内部关联接口。开发者不会直接调用内部关联接口,此处仅作简单介绍。

OpImplRegisterV2类

本类作为注册接口类,提供了一系列算子原型注册接口,供用户注册指定算子类型的Tiling函数、Infershape函数、私有属性等信息。用户调用原型注册接口(OP_ADD)进行原型注册时会间接使用到该类。

表1 OpImplRegisterV2类成员函数

接口名

接口功能

InferShape

注册算子的InferShape函数。

InferShapeRange

注册算子的InferShapeRange函数。

InferDataType

注册算子的InferDataType函数。

Tiling

注册算子的Tiling函数。

GenSimplifiedKey

生成二进制简化匹配key。

PrivateAttr

用于注册算子的私有属性。

TilingParse

注册算子的TilingParse函数。

InputsDataDependency

设置算子计算依赖第几个输入tensor的值。

OpExecuteFunc

单个算子包含多kernel组合执行逻辑的场景下,算子可以通过该接口设置算子级的回调函数,回调函数内实现多kernel的下发。该功能为预留特性,暂不支持。

HostInputs

当算子输入中存在标量输入时,需要传入host侧地址。该接口用于标记算子的第几个输入的地址是host侧地址。

TilingInputsDataDependency

标记Tiling计算时,需要依赖算子第几个输入tensor的值。

OutputShapeDependOnCompute

注册算子哪几个输出的shape依赖算子的计算结果。

KernelContext类

本类是对底层数据结构KernelRunContext的包装,包含kernel执行时所需的算子输入、输出个数信息以及输入输出数据信息、compute_node_info和kernel_extend_info。

表2 KernelContext类成员函数

接口名

接口功能

GetInput

获取输入的Chain指针。Chain是一个可以用来保存任意类型数据的类。

GetInputNum

获取kernel的输入数量。

GetInputPointer

获取输入数据的指针,本函数首先获取输入Chain,然后从输入Chain中获取指针。

GetInputStrPointer

获取输入字符串的指针,本函数首先获取输入Chain,然后从输入Chain中获取指针。

GetInputValue

获取输入数据的值,本函数首先获取输入Chain,然后从输入Chain中获取值。

GetOutput

获取输出的Chain指针。

GetOutput2

获取输出的Chain指针。

GetOutputNum

获取kernel的输出数量。

GetOutputPointer

获取输出数据的指针,本函数首先获取输出Chain,然后从Chain中获取指针。

GetContext

获取底层的context结构体。

GetKernelExtend

获取kernel扩展信息的指针。

GetComputeNodeExtend

获取计算节点信息的指针。

MutableInput

获取输入的Chain指针。

MutableInputPointer

获取输入数据的指针,本函数首先获取输入Chain,然后从输入Chain中获取指针。

IsInlineSize

根据数据的长度判断一个数据是否会被inline存储,所谓inline存储是指此数据保存在context中时不需要单独分配内存。

TilingParseContext类

用于TilingParse所需的信息保存在本类中,本类为编写算子的TilingParse函数时提供上下文信息,在TilingParse时可以从本类中获取所需的信息。

表3 TilingParseContext类成员函数

接口名

接口功能

GetCompiledJson

获取算子编译产生的json字符串。

GetCompiledInfo

获取CompiledInfo实例。

GetPlatformInfo

获取fe::PlatFormInfos指针。

Chain类

Chain是一个可以用来保存任意类型数据的类。通过其Set接口保存数据,通过GetPointer或者GetValue方法获取保存的数据。

表4 Chain类成员函数

接口名

接口功能

GetPointer

获取指向Chain中保存数据的指针。

GetValue

获取Chain中保存的数据的值。

Set

将数据设置到Chain中。设置数据指针时,会尝试调用deleter将原有保存在Chain的数据删除。

SetWithDefaultDeleter

将数据设置到Chain中。设置数据指针时,会尝试调用deleter将原有保存在Chain的数据删除。

HasDeleter

判断当前Chain中保存的数据是否有deleter。

ContinuousBuffer类

ContinuousBuffer类用于连续存储任意类型的数据,每个数据被转成uint8_t数组进行存储,每个存储单元称为一个buffer。ContinuousBuffer使用两个属性来描述被存储的数据,分别是存储buffer的个数以及每个buffer对应的内存偏移量offsets_。每个buffer在内存上是连续的。

表5 ContinuousBuffer类成员函数

接口名

接口功能

GetNum

获取buffer的数量。

GetTotalLength

获取本实例的总长度。

Get

获取buffer指针、长度信息。

KernelExtendInfo类

本类是用于保存的kernel运行时的额外信息,包括kernel的名字、类型、用于profiling注册的kernel type的index以及计算节点名的index。

表6 KernelExtendInfo类成员函数

接口名

接口功能

GetKernelName

获取kernel name。

SetKernelName

设置kernel name。

GetKernelType

获取kernel type。一个算子可以对应多个kernel,也就对应多个kernel type。

SetKernelType

设置kernel type。

SetKernelTypeIdx

设置kernel_type_idx_,用于profiling。profiling时设置的kernel type字符串对应的idx,用idx代替kernel type字符串作为标识,提升速度。

SetNodeNameIdx

设置compute_node_name_idx_,用于profiling,profiling时,用idx代替node name字符串作为标识,提升速度。

GetNodeNameIdx

获取compute_node_name_idx_,用于profiling。

GetKernelTypeIdx

获取kernel_type_idx_,用于profiling。

TensorOperateType

enum TensorOperateType {
  kGetTensorAddress,  ///< 获取Tensor的地址
  kFreeTensor,        ///< 释放Tensor
  kPlusShareCount,    ///< 共享Tensor
  kTensorOperateType
};

StructSizeInfoBase类

StructSizeInfoBase类用来存放、获取Ascend C高阶API和用户自定义的tiling结构体大小,通过单例实现一个结构体信息库。

表7 StructSizeInfoBase类成员函数

函数名称

含义

GetInstance

获取StructSizeInfoBase类的单例

SetStructSize

设置高阶API/用户自定义tiling结构大小

GetStructSize

获取高阶API/用户自定义tiling结构大小

TilingDataStructBase类

TilingDataStructBase的类,用于记录Tiling结构体构造过程信息。提供以下接口:

表8 TilingDataStructBase类成员函数

函数名称

含义

GetInstance

获取TilingDataStructBase类的单例

RecordTilingStruct

用于框架检查不同算子注册同名不同结构tiling结构体情况,若出现则会打印warning信息。

CTilingDataClassFactory

CTilingDataClassFactory类用于注册Ascend C高阶API和用户自定义tiling结构,通过单例实现的工厂类,在TilingData结构注册中REGISTER_TILING_DATA_CLASS中使用。

表9 CTilingDataClassFactory成员函数

函数名称

含义

GetInstance

获取CTilingDataClassFactory类的单例

RegisterTilingData

注册op_type的tiling结构constructor至CTilingDataClassFactory工厂类

CreateTilingDataInstance

根据算子名获取tiling结构

数值计算校验相关接口

框架内部使用的数值计算校验相关接口。

表10 接口说明

函数名称

含义

MulOverflow

判断乘法数值运算是否溢出。其中溢出的判断条件为乘积后的数值大于ret类型TRet所能表示的最大值。

AddOverflow

判断加法数值运算是否溢出。其中溢出的判断条件为相加后的数值大于ret类型TRet所能表示的最大值。

Compat

判断入参的数值是否超过指定类型T所能表示的数值范围。

OpExecuteContext类

本类用于保存图模式下调用单算子执行API场景下的上下文。

表11 OpExecuteContext类成员函数

函数名称含义

含义

GetInputTensor

获取指定索引输入的Tensor

GetOptionalInputTensor

获取指定索引的可选输入的Tensor

GetDynamicInputTensor

获取指定索引的动态个数输入的Tensor

GetRequiredInputTensor

获取指定索引的必选输入的Tensor

GetOutputTensor

获取指定索引输出的Tensor

GetDynamicOutputTensor

获取指定索引的动态个数输出的Tensor

GetRequiredOutputTensor

获取指定索引的必选输出的Tensor

GetStream

获取下发算子的流信息

GetOpExecuteFunc

获取多kernel算子的执行回调函数

MallocWorkspace

分配workspace内存

FreeWorkspace

释放workspace内存

GetDeterministic

获取确定性计算配置选项

GetAllowHf32

获取allow_hf32配置选项

GetPrecisionMode

获取精度模式配置选项

OpDef相关接口

原型注册相关类的一些非对外开放接口,包括OpDef, OpParamDef, OpAICoreConfig等,下述接口均为框架生成相关工程所需,用户无需关心。

表12 OpDef相关类及其成员函数

类名

接口名

接口功能

OpAICoreDef

GetTiling

获取Tiling信息。

GetCheckSupport

获取参数校验。

GetOpSelectFormat

获取推导函数。

GetOpSupportInfo

获取算子切分信息函数。

GetOpSpecInfo

获取算子额外信息供FE解析函数。

GetParamGeneralize

获取算子泛化函数。

SetOpSupportInfo

设置算子切分信息函数。

SetOpSpecInfo

设置算子额外信息供FE解析函数。

SetParamGeneralize

设置算子泛化函数。

OpDef

GetInferShape

获取Shape推导函数。

GetInferShapeRange

获取ShapeRange推导函数。

GetInferDataType

获取DataType推导函数。

OpDefFactory

OpDefRegister

注册算子。

OpCheckFuncHelper

-

算子相关函数检查,用来检测输入、输出shape等,如算子参数校验,推导回调函数等。

Tiling定义辅助接口

以下接口为Tiling定义辅助接口。AscendC提供了一系列Tiling类型定义宏,包含BEGIN_TILING_DATA_DEF、TILING_DATA_FIELD_DEF、TILING_DATA_FIELD_DEF_ARR、TILING_DATA_FIELD_DEF_STRUCT、REGISTER_TILING_DATA_CLASS。用户调用该系列宏会调用以下接口对tiling信息进一步处理,包含结构信息保留、字节对齐等。

// TilingDef类
void SaveToBuffer(void *pdata, size_t capacity);
std::vector<FieldInfo> GetFieldInfo() const;
const char *GetTilingClassName() const;
size_t GetDataSize() const;
void SetDataPtr(void *dataPtr);
void CheckAlignAndGenPlaceHolder(const char *name, size_t typeSize);
// FieldInfo类
FieldInfo(const char *dtype, const char *name)
FieldInfo(const char *dtype, const char *name, size_t arrSize)
FieldInfo(const char *dtype, const char *name, const char *structType,size_t structSize)