下文将以Matmul+LeakyRelu融合算子的实现为例,介绍Mix融合算子的设计和实现流程。该样例仅支持在Atlas A2训练系列产品上运行。
算子的设计过程分为算子分析、数据流分析、Tiling策略设计三部分。
算子分析是指明确算子的数学表达式、输入、输出,核函数的名称等信息。
c = LeakyRelu(a * b + bias, alpha);
算子类型(OpType) |
MATMUL_LEAKYRELU |
|||
---|---|---|---|---|
算子输入 |
name |
shape |
data type |
format |
a |
[M, K] |
half |
ND |
|
b |
[K, N] |
half |
ND |
|
bias |
[1, N] |
float32 |
- |
|
alpha |
- |
float32 |
- |
|
算子输出 |
c |
[M, N] |
float32 |
ND |
核函数名称 |
matmul_leakyrelu_custom |
进行算子的数据流分析:数据流向为在Cube核上完成Matmul计算后将数据搬运至Vector核进行LeakyRelu计算。根据上述数据流并结合融合算子的编程范式,规划并行的流水任务。如下图所示:
Leaky ReLU(带泄露线性整流函数)激活函数,是人工神经网络中一种常用的激活函数,其数学表达式和函数图像如下所示:
前三步的内容都封装在Matmul高阶API内,本样例中可以简化为3个stage。如下图所示:
根据上述分析,明确实现过程中会使用到Matmul高阶API接口,LeakyRelu Vector计算接口、DataCopy接口、EnQue、DeQue接口。
Tiling策略的设计主要包括多核切分和核内切分策略。
切分策略示意图如下,更多切分策略相关原理请参考数据分块(Tiling)。
在矩阵编程章节,我们得知Ascend C提供一组Matmul高阶API,封装了常用的切分和数据搬运、计算的算法逻辑,方便用户快速实现Matmul矩阵乘法的运算操作。融合算子中的矩阵编程的部分实现与之类似,开发者在host侧通过调用API自动获取Tiling参数,该参数传递到kernel侧后,在初始化操作时传入,通过几个简单的API即可完成矩阵乘操作。再结合上文的融合算子的编程范式,融合算子实现的步骤如下:
kernel侧实现的代码框架如下,在完成Matmul对象的初始化、左矩阵A、右矩阵B、Bias的设置后,通过单次Iterate叠加while循环的方式完成后续的Matmul计算、LeakyRelu计算、CopyOut流程。
template<typename aType, typename bType, typename cType, typename biasType> __aicore__ inline void MatmulLeakyKernel<aType, bType, cType, biasType>::Process(){ uint32_t computeRound = 0; // Matmul对象初始化 REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), matmulObj); // 设置Matmul的输入(包括左矩阵、右矩阵、bias) matmulObj.Init(&tiling); matmulObj.SetTensorA(aGlobal); matmulObj.SetTensorB(bGlobal); matmulObj.SetBias(biasGlobal); // 调用matmul iterate获取一块[baseM, baseN]的计算结果 while (matmulObj.template Iterate<true>()) { MatmulCompute(); LeakyReluCompute(); CopyOut(computeRound); computeRound++; } matmulObj.End(); }
Matmul计算、LeakyRelu计算、CopyOut的具体实现代码如下:
template<typename aType, typename bType, typename cType, typename biasType> __aicore__ inline void MatmulLeakyKernel<aType, bType, cType, biasType>::Process(){ uint32_t computeRound = 0; // ... // 调用matmul iterate获取一块[baseM, baseN]的计算结果 while (matmulObj.template Iterate<true>()) { MatmulCompute(); // ... computeRound++; } matmulObj.End(); } template<typename aType, typename bType, typename cType, typename biasType> __aicore__ inline void MatmulLeakyKernel<aType, bType, cType, biasType>::MatmulCompute(){ reluOutLocal = reluOutQueue_.AllocTensor<cType>(); // 调用GetTensorC将Matmul的计算结果搬运到Vector核。 matmulObj.template GetTensorC<true>(reluOutLocal, false, true); }
// 调用LeakyRule接口进行计算 template<typename aType, typename bType, typename cType, typename biasType> __aicore__ inline void MatmulLeakyKernel<aType, bType, cType, biasType>::LeakyReluCompute(){ LeakyRelu(reluOutLocal, reluOutLocal, (cType)alpha, tiling.baseM * tiling.baseN); reluOutQueue_.EnQue(reluOutLocal); }
// 将结果搬出到GM template<typename aType, typename bType, typename cType, typename biasType> __aicore__ inline void MatmulLeakyKernel<aType, bType, cType, biasType>::CopyOut(uint32_t count){ reluOutQueue_.DeQue<cType>(); const uint32_t roundM = tiling.singleCoreM / tiling.baseM; const uint32_t roundN = tiling.singleCoreN / tiling.baseN; uint32_t startOffset = (count % roundM * tiling.baseM * tiling.N + count / roundM * tiling.baseN); DataCopyParams copyParam = {(uint16_t)tiling.baseM, (uint16_t)(tiling.baseN * sizeof(cType) / DEFAULT_C0_SIZE), 0, (uint16_t)((tiling.N - tiling.baseN) * sizeof(cType) / DEFAULT_C0_SIZE)}; DataCopy(cGlobal[startOffset], reluOutLocal, copyParam); reluOutQueue_.FreeTensor(reluOutLocal); }
host侧自动获取Tiling参数的关键步骤介绍如下:
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); MultiCoreMatmulTiling cubeTiling(ascendcPlatform);
创建对象时需要传入需要传入硬件平台信息,硬件平台信息可以通过GetPlatformInfo获取。
cubeTiling.SetAType(TPosition::GM, CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16); cubeTiling.SetBType(TPosition::GM, CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16); cubeTiling.SetCType(TPosition::LCM, CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT); cubeTiling.SetBiasType(TPosition::GM, CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT);
cubeTiling.SetShape(M, N, K); cubeTiling.SetOrgShape(M, N, K);
cubeTiling.SetBufferSpace(-1, -1, -1);
cubeTiling.SetBias(true);
MatmulLeakyreluCustomTilingData tiling; if (cubeTiling.GetTiling(tiling.cubeTilingData) == -1){ return ge::GRAPH_FAILED; }
size_t userWorkspaceSize = 0; size_t systemWorkspaceSize = ascendcPlatform.GetLibApiWorkSpaceSize(); size_t *currentWorkspace = context->GetWorkspaceSizes(1); currentWorkspace[0] = userWorkspaceSize + systemWorkspaceSize;
// 使用Matmul时必须设置workspace空间 SetSysWorkspace(workspace); if (GetSysWorkSpacePtr() == nullptr) { return; }