文档
注册

kernel侧算子实现

矢量编程章节,主要介绍了简单的固定shape矢量算子的kernel侧实现,算子的shape、数据类型都是固定的;在实际的算子开发场景中,这些信息是支持动态变化的,场景会更加灵活和复杂。本节将介绍更为复杂场景的kernel侧算子实现,而对kernel侧核函数定义和算子类实现的基本流程,不再过多描述。

动态shape算子的kernel侧实现

本节以实现一个支持动态shape的Add算子为例,重点进行动态shape与固定shape差异点的介绍。

动态Shape场景下,输入的Shape是未知的。一些与输入Shape相关的变量(比如每次搬运的块大小等),也就是TilingData,需要通过Tiling函数计算出来。Tiling函数的相关介绍请参考Tiling实现

Tiling函数是在host侧实现的代码,与之相对应的,在kernel侧算子实现代码中,您只需通过调用GET_TILING_DATA即可获取TilingData结构体参数,并使用具体的参数进行后续的计算。

通过Add算子举例来说明,固定shape和动态shape kernel侧算子实现的区别。

  • 矢量编程章节中固定shape的算子样例中,TILE_NUM(每个核上总计算数据分块个数)、BLOCK_LENGTH(每个核上总计算数据大小)、TILE_LENGTH(每个分块大小)等是固定的数值。
    constexpr int32_t TOTAL_LENGTH = 8 * 2048;                            // total length of data
    constexpr int32_t USE_CORE_NUM = 8;                                   // num of core used
    constexpr int32_t BLOCK_LENGTH = TOTAL_LENGTH / USE_CORE_NUM;         // length computed of each core
    constexpr int32_t TILE_NUM = 8;                                       // split data into 8 tiles for each core
    constexpr int32_t BUFFER_NUM = 2;                                     // tensor num for each queue
    constexpr int32_t TILE_LENGTH = BLOCK_LENGTH / TILE_NUM / BUFFER_NUM; // each tile length is seperated to 2 part, due to double buffer
  • 动态shape的实现中,需要在核函数中通过GET_TILING_DATA获取Tiling参数,再基于Tiling参数计算得到singleCoreSize(每个核上总计算数据大小)、tileNum(每个核上总计算数据分块个数)、singleTileLength(每个分块大小)等变量。注意,对应的算子host实现中需要定义TilingData结构体,实现并注册计算TilingData的Tiling函数。具体请参考Tiling实现
    核函数中调用GET_TILING_DATA获取TilingData的样例如下:
    extern "C" __global__ __aicore__ void add_custom(GM_ADDR x, GM_ADDR y, GM_ADDR z, GM_ADDR workspace, GM_ADDR tiling)
    {
        GET_TILING_DATA(tilingData, tiling);
        KernelAdd op;
        op.Init(x, y, z, tilingData.totalLength, tilingData.tileNum);
        if (TILING_KEY_IS(1)) {
            op.Process();
        }
    }

    算子类的Init函数中,使用获取到的TilingData计算得到singleCoreSize、tileNum、singleTileLength等变量的样例如下。

    __aicore__ inline void Init(GM_ADDR x, GM_ADDR y, GM_ADDR z, uint32_t totalLength, uint32_t tileNum)
    {
        ASSERT(GetBlockNum() != 0 && "block dim can not be zero!");
        this->blockLength = totalLength / GetBlockNum();
        this->tileNum = tileNum;
        ASSERT(tileNum != 0 && "tile num can not be zero!");
        this->tileLength = this->blockLength / tileNum / BUFFER_NUM;
        // ...
    
    }

核函数内推导输入数据和格式

算子工程在核函数内提供了DTYPE_<Arg>、ORIG_DTYPE_<Arg>、FORMAT_<Arg>三种宏用于推导核函数入参的数据类型、原始数据类型和数据格式。其中<Arg>会自动大写。样例如下:

template<class T> func() {}
extern "C" __global__ __aicore__ void add_custom(GM_ADDR x, GM_ADDR y, GM_ADDR z, GM_ADDR workspace, GM_ADDR tiling)
{
    DTYPE_X temp;
    func<DTYPE_Z>();
    if (FORMAT_Y == FORMAT_ND) {
        ...
    }
}

非对齐shape算子的kernel侧实现

非对齐shape的实现中,同样需要在核函数中通过GET_TILING_DATA获取Tiling参数,包括formerNum(分配到较多数据量的核心数), tailNum(分配到较少数据量的核心数), formerLength(大块的长度), tailLength(小块的长度),alignNum(需要对齐到的最小数据量)。

核函数中调用GET_TILING_DATA获取TilingData的样例如下:
extern "C" __global__ __aicore__ void add_custom_unalign(GM_ADDR x, GM_ADDR y, GM_ADDR z, GM_ADDR workspace, GM_ADDR tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    KernelAdd op;
    op.Init(x, y, z, tilingData.formerNum, tilingData.tailNum, tilingData.formerLength,
        tilingData.tailLength, tilingData.alignNum);
    if (TILING_KEY_IS(1)) {
        op.Process();
    }
}

算子类的Init函数中,使用获取到的TilingData计算得到每个核上的偏移量、每个分块大小的样例如下。

__aicore__ inline void Init(GM_ADDR x, GM_ADDR y, GM_ADDR z, uint32_t formerNum,
        uint32_t tailNum, uint32_t formerLength, uint32_t tailLength, uint32_t alignNum)
    {
        // 由于不同的核有不同的数据量,所以不同的核有不同的offset
        if (GetBlockIdx() < formerNum) {
            this->tileLength = formerLength;
            xGm.SetGlobalBuffer((__gm__ DTYPE_X *)x + formerLength * GetBlockIdx(), formerLength);
            yGm.SetGlobalBuffer((__gm__ DTYPE_Y *)y + formerLength * GetBlockIdx(), formerLength);
            zGm.SetGlobalBuffer((__gm__ DTYPE_Z *)z + formerLength * GetBlockIdx(), formerLength);
        } else {
            this->tileLength = tailLength;
            xGm.SetGlobalBuffer(
                (__gm__ DTYPE_X *)x + formerLength * formerNum + tailLength * (GetBlockIdx() - formerNum), tailLength);
            yGm.SetGlobalBuffer(
                (__gm__ DTYPE_Y *)y + formerLength * formerNum + tailLength * (GetBlockIdx() - formerNum), tailLength);
            zGm.SetGlobalBuffer(
                (__gm__ DTYPE_Z *)z + formerLength * formerNum + tailLength * (GetBlockIdx() - formerNum), tailLength);
        }
        ASSERT(alignNum != 0 && "align num can not be zero!");
        // 切分后有些数量不满足32B对齐,所以需要对length向上对齐到32B的数据量
        pipe.InitBuffer(inQueueX, BUFFER_NUM,
            (((this->tileLength + alignNum - 1) / alignNum) * alignNum) * sizeof(half));
        pipe.InitBuffer(inQueueY, BUFFER_NUM,
            (((this->tileLength + alignNum - 1) / alignNum) * alignNum) * sizeof(half));
        pipe.InitBuffer(outQueueZ, BUFFER_NUM,
            (((this->tileLength + alignNum - 1) / alignNum) * alignNum) * sizeof(half));
}
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词