更多场景

动态shape场景

算子实现章节,已经介绍了简单的固定shape矢量算子的kernel侧实现,算子的shape、数据类型都是固定的;在实际的算子开发场景中,这些信息是支持动态变化的,场景会更加灵活和复杂。下文重点进行动态shape与固定shape差异点的介绍。

最主要的区别是:动态Shape场景下,输入的Shape是未知的。一些与输入Shape相关的变量(比如每次搬运的块大小等),需要通过Tiling计算出来,然后传递到kernel侧,kernel侧使用该参数进行后续的计算。

shape非对齐场景

针对一些非对齐shape,比如算子的输入shape为(1,1999),支持的数据类型为half类型,既无法对齐到一个block的大小(32B),也无法平均分配到每个核上,需要一些特殊的Tiling处理方法。本示例非对齐的AddCustom算子使用了5个tiling参数:formerNum, tailNum, formerLength, tailLength,alignNum,计算这些Tiling参数的伪代码如下:
constexpr uint32_t BLOCK_DIM = 8;
constexpr uint32_t SIZE_OF_HALF = 2;
constexpr uint32_t BLOCK_SIZE = 32;
// shape需要对齐到的最小单位
constexpr uint32_t ALIGN_NUM = BLOCK_SIZE / SIZE_OF_HALF;
...
uint8_t *GenerateTiling()
{
    ...
    // 如果是非对齐的shape,需要向上对齐到最小单位
    uint32_t totalLengthAligned = ((totalLength + ALIGN_NUM - 1) / ALIGN_NUM) * ALIGN_NUM;
    // 把所有的数据尽可能均匀地分配到每个核上,如果不能均分的话,那么会有部分核多算一个最小单位ALIGN_NUM
    // 通过模的计算,可以得到多算一个最小单位的核的数量,也可以得到少算一个最小单位的核的数量
    // eg:1999 对齐后的总数据量为2000个数,核心数为8,数据块的最小单位是16,那么:
    // 1、最小单位数据块的总数:2000 / 16 = 125
    // 2、有5个核会分到16个最小单位的数据块:125 % 8 =5,可以称之为大块
    // 3、有3个核会分到15个最小单位的数据块:8 - 5 = 3,可以称之为小块
    uint32_t formerNum = (totalLengthAligned / ALIGN_NUM) % BLOCK_DIM;
    uint32_t tailNum = BLOCK_DIM - formerNum;
    // 计算大块和小块的数据量
    uint32_t formerLength = ((totalLengthAligned / BLOCK_DIM + ALIGN_NUM - 1) / ALIGN_NUM) * ALIGN_NUM;
    uint32_t tailLength = (totalLengthAligned / BLOCK_DIM / ALIGN_NUM) * ALIGN_NUM;
    ...
}

相对应的,在Kernel侧,使用获取到的信息计算得到每个核上的偏移量、每个分块大小的样例如下。

__aicore__ inline void Init(GM_ADDR x, GM_ADDR y, GM_ADDR z, uint32_t formerNum,
        uint32_t tailNum, uint32_t formerLength, uint32_t tailLength, uint32_t alignNum)
    {
        // 由于不同的核有不同的数据量,所以不同的核有不同的offset
        if (GetBlockIdx() < formerNum) {
            this->tileLength = formerLength;
            xGm.SetGlobalBuffer((__gm__ DTYPE_X *)x + formerLength * GetBlockIdx(), formerLength);
            yGm.SetGlobalBuffer((__gm__ DTYPE_Y *)y + formerLength * GetBlockIdx(), formerLength);
            zGm.SetGlobalBuffer((__gm__ DTYPE_Z *)z + formerLength * GetBlockIdx(), formerLength);
        } else {
            this->tileLength = tailLength;
            xGm.SetGlobalBuffer(
                (__gm__ DTYPE_X *)x + formerLength * formerNum + tailLength * (GetBlockIdx() - formerNum), tailLength);
            yGm.SetGlobalBuffer(
                (__gm__ DTYPE_Y *)y + formerLength * formerNum + tailLength * (GetBlockIdx() - formerNum), tailLength);
            zGm.SetGlobalBuffer(
                (__gm__ DTYPE_Z *)z + formerLength * formerNum + tailLength * (GetBlockIdx() - formerNum), tailLength);
        }
        ASSERT(alignNum != 0 && "align num can not be zero!");
        // 切分后有些数量不满足32B对齐,所以需要对length向上对齐到32B的数据量
        pipe.InitBuffer(inQueueX, BUFFER_NUM,
            (((this->tileLength + alignNum - 1) / alignNum) * alignNum) * sizeof(half));
        pipe.InitBuffer(inQueueY, BUFFER_NUM,
            (((this->tileLength + alignNum - 1) / alignNum) * alignNum) * sizeof(half));
        pipe.InitBuffer(outQueueZ, BUFFER_NUM,
            (((this->tileLength + alignNum - 1) / alignNum) * alignNum) * sizeof(half));
}