kernel侧算子实现
在矢量编程章节,主要介绍了简单的固定shape矢量算子的kernel侧实现,算子的shape、数据类型都是固定的;在实际的算子开发场景中,这些信息是支持动态变化的,场景会更加灵活和复杂。本节将介绍更为复杂场景的kernel侧算子实现,而对kernel侧核函数定义和算子类实现的基本流程,不再过多描述。
动态shape算子的kernel侧实现
本节以实现一个支持动态shape的Add算子为例,重点进行动态shape与固定shape差异点的介绍。
动态Shape场景下,输入的Shape是未知的。一些与输入Shape相关的变量(比如每次搬运的块大小等),也就是TilingData,需要通过Tiling函数计算出来。Tiling函数的相关介绍请参考Tiling实现。
Tiling函数是在host侧实现的代码,与之相对应的,在kernel侧算子实现代码中,您只需通过调用GET_TILING_DATA即可获取TilingData结构体参数,并使用具体的参数进行后续的计算。
通过Add算子举例来说明,固定shape和动态shape kernel侧算子实现的区别。
- 矢量编程章节中固定shape的算子样例中,TILE_NUM(每个核上总计算数据分块个数)、BLOCK_LENGTH(每个核上总计算数据大小)、TILE_LENGTH(每个分块大小)等是固定的数值。
constexpr int32_t TOTAL_LENGTH = 8 * 2048; // total length of data constexpr int32_t USE_CORE_NUM = 8; // num of core used constexpr int32_t BLOCK_LENGTH = TOTAL_LENGTH / USE_CORE_NUM; // length computed of each core constexpr int32_t TILE_NUM = 8; // split data into 8 tiles for each core constexpr int32_t BUFFER_NUM = 2; // tensor num for each queue constexpr int32_t TILE_LENGTH = BLOCK_LENGTH / TILE_NUM / BUFFER_NUM; // each tile length is seperated to 2 part, due to double buffer
- 动态shape的实现中,需要在核函数中通过GET_TILING_DATA获取Tiling参数,再基于Tiling参数计算得到singleCoreSize(每个核上总计算数据大小)、tileNum(每个核上总计算数据分块个数)、singleTileLength(每个分块大小)等变量。注意,对应的算子host实现中需要定义TilingData结构体,实现并注册计算TilingData的Tiling函数。具体请参考Tiling实现。
核函数中调用GET_TILING_DATA获取TilingData的样例如下:
extern "C" __global__ __aicore__ void add_custom(GM_ADDR x, GM_ADDR y, GM_ADDR z, GM_ADDR workspace, GM_ADDR tiling) { GET_TILING_DATA(tilingData, tiling); KernelAdd op; op.Init(x, y, z, tilingData.totalLength, tilingData.tileNum); if (TILING_KEY_IS(1)) { op.Process(); } }
算子类的Init函数中,使用获取到的TilingData计算得到singleCoreSize、tileNum、singleTileLength等变量的样例如下。
__aicore__ inline void Init(GM_ADDR x, GM_ADDR y, GM_ADDR z, uint32_t totalLength, uint32_t tileNum) { ASSERT(GetBlockNum() != 0 && "block dim can not be zero!"); this->blockLength = totalLength / GetBlockNum(); this->tileNum = tileNum; ASSERT(tileNum != 0 && "tile num can not be zero!"); this->tileLength = this->blockLength / tileNum / BUFFER_NUM; // ... }
核函数内推导输入数据和格式
算子工程在核函数内提供了DTYPE_<Arg>、ORIG_DTYPE_<Arg>、FORMAT_<Arg>三种宏用于推导核函数入参的数据类型、原始数据类型和数据格式。其中<Arg>会自动大写。样例如下:
template<class T> func() {} extern "C" __global__ __aicore__ void add_custom(GM_ADDR x, GM_ADDR y, GM_ADDR z, GM_ADDR workspace, GM_ADDR tiling) { DTYPE_X temp; func<DTYPE_Z>(); if (FORMAT_Y == FORMAT_ND) { ... } }
非对齐shape算子的kernel侧实现
非对齐shape的实现中,同样需要在核函数中通过GET_TILING_DATA获取Tiling参数,包括formerNum(分配到较多数据量的核心数), tailNum(分配到较少数据量的核心数), formerLength(大块的长度), tailLength(小块的长度),alignNum(需要对齐到的最小数据量)。
extern "C" __global__ __aicore__ void add_custom_unalign(GM_ADDR x, GM_ADDR y, GM_ADDR z, GM_ADDR workspace, GM_ADDR tiling) { GET_TILING_DATA(tilingData, tiling); KernelAdd op; op.Init(x, y, z, tilingData.formerNum, tilingData.tailNum, tilingData.formerLength, tilingData.tailLength, tilingData.alignNum); if (TILING_KEY_IS(1)) { op.Process(); } }
算子类的Init函数中,使用获取到的TilingData计算得到每个核上的偏移量、每个分块大小的样例如下。
__aicore__ inline void Init(GM_ADDR x, GM_ADDR y, GM_ADDR z, uint32_t formerNum, uint32_t tailNum, uint32_t formerLength, uint32_t tailLength, uint32_t alignNum) { // 由于不同的核有不同的数据量,所以不同的核有不同的offset if (GetBlockIdx() < formerNum) { this->tileLength = formerLength; xGm.SetGlobalBuffer((__gm__ DTYPE_X *)x + formerLength * GetBlockIdx(), formerLength); yGm.SetGlobalBuffer((__gm__ DTYPE_Y *)y + formerLength * GetBlockIdx(), formerLength); zGm.SetGlobalBuffer((__gm__ DTYPE_Z *)z + formerLength * GetBlockIdx(), formerLength); } else { this->tileLength = tailLength; xGm.SetGlobalBuffer( (__gm__ DTYPE_X *)x + formerLength * formerNum + tailLength * (GetBlockIdx() - formerNum), tailLength); yGm.SetGlobalBuffer( (__gm__ DTYPE_Y *)y + formerLength * formerNum + tailLength * (GetBlockIdx() - formerNum), tailLength); zGm.SetGlobalBuffer( (__gm__ DTYPE_Z *)z + formerLength * formerNum + tailLength * (GetBlockIdx() - formerNum), tailLength); } ASSERT(alignNum != 0 && "align num can not be zero!"); // 切分后有些数量不满足32B对齐,所以需要对length向上对齐到32B的数据量 pipe.InitBuffer(inQueueX, BUFFER_NUM, (((this->tileLength + alignNum - 1) / alignNum) * alignNum) * sizeof(half)); pipe.InitBuffer(inQueueY, BUFFER_NUM, (((this->tileLength + alignNum - 1) / alignNum) * alignNum) * sizeof(half)); pipe.InitBuffer(outQueueZ, BUFFER_NUM, (((this->tileLength + alignNum - 1) / alignNum) * alignNum) * sizeof(half)); }