文档
注册

Kernel侧算子实现

自动生成kernel侧算子实现模板

在算子工程目录下的“op_kernel/xxx.cpp”文件中实现算子的核函数。核函数的定义模板已通过msOpGen工具自动生成,样例如下所示注意这里参数的顺序按照“输入、输出、workspace、tiling”的顺序排布,开发者不要调整其顺序。

#include "kernel_operator.h"
extern "C" __global__ __aicore__ void add_custom(GM_ADDR x, GM_ADDR y, GM_ADDR z, GM_ADDR workspace, GM_ADDR tiling) {
    GET_TILING_DATA(tiling_data, tiling);// 获取Tiling参数,详见下文介绍
    // TODO: user kernel impl
}
算子原型定义中的输入和输出同名的情况下,自动生成的核函数中,输出参数增加ref后缀予以区分。示例如下:
extern "C" __global__ __aicore__ void add_custom(GM_ADDR x, GM_ADDR y, GM_ADDR x_ref, GM_ADDR workspace, GM_ADDR tiling) {
    ...
}

GET_TILING_DATA获取Tiling参数

提供GET_TILING_DATA,用于获取算子kernel入口函数传入的tiling信息,并填入注册的Tiling结构体中,此函数会以宏展开的方式进行编译。注意,对应的算子host实现中需要定义TilingData结构体,实现并注册计算TilingData的Tiling函数。具体请参考Host侧Tiling实现

核函数中调用GET_TILING_DATA获取TilingData的样例如下:
extern "C" __global__ __aicore__ void add_custom(GM_ADDR x, GM_ADDR y, GM_ADDR z, GM_ADDR workspace, GM_ADDR tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    KernelAdd op;
    op.Init(x, y, z, tilingData.totalLength, tilingData.tileNum);
    if (TILING_KEY_IS(1)) {
        op.Process();
    }
}

核函数内推导输入数据类型和格式

算子工程在核函数内提供了DTYPE_<Arg>、ORIG_DTYPE_<Arg>、FORMAT_<Arg>三种宏用于推导核函数入参的数据类型、原始数据类型和数据格式。其中<Arg>会自动大写。样例如下:
template<class T> func() {}
extern "C" __global__ __aicore__ void add_custom(GM_ADDR x, GM_ADDR y, GM_ADDR z, GM_ADDR workspace, GM_ADDR tiling)
{
    DTYPE_X temp;
    func<DTYPE_Z>();
    if (FORMAT_Y == FORMAT_ND) {
        ...
    }
}
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词