内部关联接口
在进行算子原型注册、Tiling实现、shape推导过程中,开发者会使用上文的原型定义、Tiling实现、算子推导等外部开放接口,这些外部开放接口中会调用到的辅助数据结构和接口,称之为内部关联接口。开发者不会直接调用内部关联接口,此处仅作简单介绍。
OpImplRegisterV2类
本类作为注册接口类,提供了一系列算子原型注册接口,供用户注册指定算子类型的Tiling函数、Infershape函数、私有属性等信息。用户调用原型注册接口(OP_ADD)进行原型注册时会间接使用到该类。
接口名 |
接口功能 |
---|---|
InferShape |
注册算子的InferShape函数。 |
InferShapeRange |
注册算子的InferShapeRange函数。 |
InferDataType |
注册算子的InferDataType函数。 |
Tiling |
注册算子的Tiling函数。 |
GenSimplifiedKey |
生成二进制简化匹配key。 |
PrivateAttr |
用于注册算子的私有属性。 |
TilingParse |
注册算子的TilingParse函数。 |
InputsDataDependency |
设置算子计算依赖第几个输入tensor的值。 |
OpExecuteFunc |
单个算子包含多kernel组合执行逻辑的场景下,算子可以通过该接口设置算子级的回调函数,回调函数内实现多kernel的下发。该功能为预留特性,暂不支持。 |
HostInputs |
当算子输入中存在标量输入时,需要传入host侧地址。该接口用于标记算子的第几个输入的地址是host侧地址。 |
TilingInputsDataDependency |
标记Tiling计算时,需要依赖算子第几个输入tensor的值。 |
OutputShapeDependOnCompute |
注册算子哪几个输出的shape依赖算子的计算结果。 |
KernelContext类
本类是对底层数据结构KernelRunContext的包装,包含kernel执行时所需的算子输入、输出个数信息以及输入输出数据信息、compute_node_info和kernel_extend_info。
接口名 |
接口功能 |
---|---|
GetInput |
获取输入的Chain指针。Chain是一个可以用来保存任意类型数据的类。 |
GetInputNum |
获取kernel的输入数量。 |
GetInputPointer |
获取输入数据的指针,本函数首先获取输入Chain,然后从输入Chain中获取指针。 |
GetInputStrPointer |
获取输入字符串的指针,本函数首先获取输入Chain,然后从输入Chain中获取指针。 |
GetInputValue |
获取输入数据的值,本函数首先获取输入Chain,然后从输入Chain中获取值。 |
GetOutput |
获取输出的Chain指针。 |
GetOutput2 |
获取输出的Chain指针。 |
GetOutputNum |
获取kernel的输出数量。 |
GetOutputPointer |
获取输出数据的指针,本函数首先获取输出Chain,然后从Chain中获取指针。 |
GetContext |
获取底层的context结构体。 |
GetKernelExtend |
获取kernel扩展信息的指针。 |
GetComputeNodeExtend |
获取计算节点信息的指针。 |
MutableInput |
获取输入的Chain指针。 |
MutableInputPointer |
获取输入数据的指针,本函数首先获取输入Chain,然后从输入Chain中获取指针。 |
IsInlineSize |
根据数据的长度判断一个数据是否会被inline存储,所谓inline存储是指此数据保存在context中时不需要单独分配内存。 |
TilingParseContext类
用于TilingParse所需的信息保存在本类中,本类为编写算子的TilingParse函数时提供上下文信息,在TilingParse时可以从本类中获取所需的信息。
接口名 |
接口功能 |
---|---|
GetCompiledJson |
获取算子编译产生的json字符串。 |
GetCompiledInfo |
获取CompiledInfo实例。 |
GetPlatformInfo |
获取fe::PlatFormInfos指针。 |
Chain类
Chain是一个可以用来保存任意类型数据的类。通过其Set接口保存数据,通过GetPointer或者GetValue方法获取保存的数据。
接口名 |
接口功能 |
---|---|
GetPointer |
获取指向Chain中保存数据的指针。 |
GetValue |
获取Chain中保存的数据的值。 |
Set |
将数据设置到Chain中。设置数据指针时,会尝试调用deleter将原有保存在Chain的数据删除。 |
SetWithDefaultDeleter |
将数据设置到Chain中。设置数据指针时,会尝试调用deleter将原有保存在Chain的数据删除。 |
HasDeleter |
判断当前Chain中保存的数据是否有deleter。 |
ContinuousBuffer类
ContinuousBuffer类用于连续存储任意类型的数据,每个数据被转成uint8_t数组进行存储,每个存储单元称为一个buffer。ContinuousBuffer使用两个属性来描述被存储的数据,分别是存储buffer的个数以及每个buffer对应的内存偏移量offsets_。每个buffer在内存上是连续的。
接口名 |
接口功能 |
---|---|
GetNum |
获取buffer的数量。 |
GetTotalLength |
获取本实例的总长度。 |
Get |
获取buffer指针、长度信息。 |
KernelExtendInfo类
本类是用于保存的kernel运行时的额外信息,包括kernel的名字、类型、用于profiling注册的kernel type的index以及计算节点名的index。
接口名 |
接口功能 |
---|---|
GetKernelName |
获取kernel name。 |
SetKernelName |
设置kernel name。 |
GetKernelType |
获取kernel type。一个算子可以对应多个kernel,也就对应多个kernel type。 |
SetKernelType |
设置kernel type。 |
SetKernelTypeIdx |
设置kernel_type_idx_,用于profiling。profiling时设置的kernel type字符串对应的idx,用idx代替kernel type字符串作为标识,提升速度。 |
SetNodeNameIdx |
设置compute_node_name_idx_,用于profiling,profiling时,用idx代替node name字符串作为标识,提升速度。 |
GetNodeNameIdx |
获取compute_node_name_idx_,用于profiling。 |
GetKernelTypeIdx |
获取kernel_type_idx_,用于profiling。 |
TensorOperateType
enum TensorOperateType { kGetTensorAddress, ///< 获取Tensor的地址 kFreeTensor, ///< 释放Tensor kPlusShareCount, ///< 共享Tensor kTensorOperateType };
StructSizeInfoBase类
StructSizeInfoBase类用来存放、获取Ascend C高阶API和用户自定义的tiling结构体大小,通过单例实现一个结构体信息库。
函数名称 |
含义 |
---|---|
GetInstance |
获取StructSizeInfoBase类的单例 |
SetStructSize |
设置高阶API/用户自定义tiling结构大小 |
GetStructSize |
获取高阶API/用户自定义tiling结构大小 |
TilingDataStructBase类
TilingDataStructBase的类,用于记录Tiling结构体构造过程信息。提供以下接口:
函数名称 |
含义 |
---|---|
GetInstance |
获取TilingDataStructBase类的单例 |
RecordTilingStruct |
用于框架检查不同算子注册同名不同结构tiling结构体情况,若出现则会打印warning信息。 |
CTilingDataClassFactory
CTilingDataClassFactory类用于注册Ascend C高阶API和用户自定义tiling结构,通过单例实现的工厂类,在TilingData结构注册中REGISTER_TILING_DATA_CLASS中使用。
函数名称 |
含义 |
---|---|
GetInstance |
获取CTilingDataClassFactory类的单例 |
RegisterTilingData |
注册op_type的tiling结构constructor至CTilingDataClassFactory工厂类 |
CreateTilingDataInstance |
根据算子名获取tiling结构 |
数值计算校验相关接口
框架内部使用的数值计算校验相关接口。
函数名称 |
含义 |
---|---|
MulOverflow |
判断乘法数值运算是否溢出。其中溢出的判断条件为乘积后的数值大于ret类型TRet所能表示的最大值。 |
AddOverflow |
判断加法数值运算是否溢出。其中溢出的判断条件为相加后的数值大于ret类型TRet所能表示的最大值。 |
Compat |
判断入参的数值是否超过指定类型T所能表示的数值范围。 |
OpExecuteContext类
本类用于保存图模式下调用单算子执行API场景下的上下文。
函数名称含义 |
含义 |
---|---|
GetInputTensor |
获取指定索引输入的Tensor |
GetOptionalInputTensor |
获取指定索引的可选输入的Tensor |
GetDynamicInputTensor |
获取指定索引的动态个数输入的Tensor |
GetRequiredInputTensor |
获取指定索引的必选输入的Tensor |
GetOutputTensor |
获取指定索引输出的Tensor |
GetDynamicOutputTensor |
获取指定索引的动态个数输出的Tensor |
GetRequiredOutputTensor |
获取指定索引的必选输出的Tensor |
GetStream |
获取下发算子的流信息 |
GetOpExecuteFunc |
获取多kernel算子的执行回调函数 |
MallocWorkspace |
分配workspace内存 |
FreeWorkspace |
释放workspace内存 |
GetDeterministic |
获取确定性计算配置选项 |
GetAllowHf32 |
获取allow_hf32配置选项 |
GetPrecisionMode |
获取精度模式配置选项 |
OpDef相关接口
原型注册相关类的一些非对外开放接口,包括OpDef, OpParamDef, OpAICoreConfig等,下述接口均为框架生成相关工程所需,用户无需关心。
类名 |
接口名 |
接口功能 |
---|---|---|
OpAICoreDef |
GetTiling |
获取Tiling信息。 |
GetCheckSupport |
获取参数校验。 |
|
GetOpSelectFormat |
获取推导函数。 |
|
GetOpSupportInfo |
获取算子切分信息函数。 |
|
GetOpSpecInfo |
获取算子额外信息供FE解析函数。 |
|
GetParamGeneralize |
获取算子泛化函数。 |
|
SetOpSupportInfo |
设置算子切分信息函数。 |
|
SetOpSpecInfo |
设置算子额外信息供FE解析函数。 |
|
SetParamGeneralize |
设置算子泛化函数。 |
|
OpDef |
GetInferShape |
获取Shape推导函数。 |
GetInferShapeRange |
获取ShapeRange推导函数。 |
|
GetInferDataType |
获取DataType推导函数。 |
|
OpDefFactory |
OpDefRegister |
注册算子。 |
OpCheckFuncHelper |
- |
算子相关函数检查,用来检测输入、输出shape等,如算子参数校验,推导回调函数等。 |
Tiling定义辅助接口
以下接口为Tiling定义辅助接口。AscendC提供了一系列Tiling类型定义宏,包含BEGIN_TILING_DATA_DEF、TILING_DATA_FIELD_DEF、TILING_DATA_FIELD_DEF_ARR、TILING_DATA_FIELD_DEF_STRUCT、REGISTER_TILING_DATA_CLASS。用户调用该系列宏会调用以下接口对tiling信息进一步处理,包含结构信息保留、字节对齐等。
// TilingDef类 void SaveToBuffer(void *pdata, size_t capacity); std::vector<FieldInfo> GetFieldInfo() const; const char *GetTilingClassName() const; size_t GetDataSize() const; void SetDataPtr(void *dataPtr); void CheckAlignAndGenPlaceHolder(const char *name, size_t typeSize); // FieldInfo类 FieldInfo(const char *dtype, const char *name) FieldInfo(const char *dtype, const char *name, size_t arrSize) FieldInfo(const char *dtype, const char *name, const char *structType,size_t structSize)