SoftmaxFlashV3
功能说明
SoftmaxFlash增强版本,对应Softmax PASA算法。将输入tensor[m0, m1, ..., mt, n](t大于或等于0)的非尾轴长度m0, m1, ..., mt相乘的结果看作m,则输入tensor的shape看作[m, n]。对输入tensor x的尾轴进行切分,分块个数为splitMeanCnt,切分后的tensor为x_cnti。按如下公式进行计算,其中x、inmax、insum、inmean为输入,M、S、E均为输出。
本接口当前只支持ND格式的输入,内部的reduce过程按last轴处理。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 | def softmax_flash_3(src, height, width, loopCnt, alpha, baseK, inmax=None, insum=None, inmean=None, update=False): scalar = alapha / (1 - alapha) #(m,n)->(m,64) tmpbuffer0 = BlockReduceSum(repeatSize, repeatSize, elementNumPerBlk) remain = int(width / repeatSize - BlkcntPerRepeat) tmpbuffer0 = Add(tmpbuffer0, src, remain, repeatSize * elementNumPerBlk, width) #(m,64)->(m,8) tmpbuffer0 = BlockReduceSum(1, relementNumPerBlk, elementNumPerBlk) #width = baseK * splitMeanCnt rowMeanLocal = tmpbuffer0 / baseK rowMeanGlobal = np.mean(src, axis=(-1), keepdims=True) rowMeanGlobalTmp = (rowMeanGlobal - rowMeanLocal) * scalar src = src - rowMeanGlobalTmp if update == False: x_mean = rowMeanGlobal maxTmp = np.max(src, axis=-1, keepdims=True) shiftCurr = (rowMeanGlobal - x_mean) * scalar x_max = shiftCurr + maxTmp maxTmp = x_max - shiftCurr x_sub = src - maxTmp dst = np.exp(x_sub) x_sum = np.sum(dst, axis=-1, keepdims=True) exp_max = None return dst, x_max, x_sum, x_mean, exp_max else: x_mean = (rowMeanGlobal + inmean * (loopCnt - 1)) / loopCnt maxTmp = np.max(src, axis=-1, keepdims=True) shiftCurr = (rowMeanGlobal - x_mean) * scalar shiftPrev = (inmean - x_mean) * scalar x_max = shiftCurr + maxTmp maxTmp = shiftPrev + inmax x_max = np.max(np.concatenate((x_max, maxTmp), axis=(-1)), axis=(-1), keepdims=True) maxTmp = x_max - shiftCurr x_sub = src - maxTmp dst = np.exp(x_sub) exp_max = np.exp(inmax - x_max + shiftPrev) x_sum = np.sum(x_exp, axis=-1, keepdims=True) x_sum = exp_max * insum + x_sum return x_exp, x_max, x_sum, x_mean, exp_max |
函数原型
- 接口框架申请临时空间
1 2
template <typename T, typename U, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false, bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG> __aicore__ inline void SoftmaxFlashV3(const LocalTensor<T>& dstTensor, const LocalTensor<U>& meanTensor, const LocalTensor<U>& expSumTensor, const LocalTensor<U>& maxTensor, const LocalTensor<T>& srcTensor, const LocalTensor<T>& expMaxTensor, const LocalTensor<U>& inMeanTensor, const LocalTensor<U>& inExpSumTensor, const LocalTensor<U>& inMaxTensor, const SoftMaxTiling& tiling, const SoftMaxParams& params)
- 通过sharedTmpBuffer入参传入临时空间
1 2
template <typename T, typename U, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false, bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG> __aicore__ inline void SoftmaxFlashV3(const LocalTensor<T>& dstTensor, const LocalTensor<U>& meanTensor,const LocalTensor<U>& expSumTensor, const LocalTensor<U>& maxTensor, const LocalTensor<T>& srcTensor,const LocalTensor<T>& expMaxTensor, const LocalTensor<U>& inMeanTensor, const LocalTensor<U>& inExpSumTensor, const LocalTensor<U>& inMaxTensor, const LocalTensor<uint8_t>& sharedTmpBuffer, const SoftMaxTiling& tiling, const SoftMaxParams& params)
由于该接口的内部实现中涉及复杂的计算,需要额外的临时空间来存储计算过程中的中间变量。临时空间支持接口框架申请和开发者通过sharedTmpBuffer入参传入两种方式。
- 接口框架申请临时空间,开发者无需申请,但是需要预留临时空间的大小。
- 通过sharedTmpBuffer入参传入,使用该tensor作为临时空间进行处理,接口框架不再申请。该方式开发者可以自行管理sharedTmpBuffer内存空间,并在接口调用完成后,复用该部分内存,内存不会反复申请释放,灵活性较高,内存利用率也较高。
接口框架申请的方式,开发者需要预留临时空间;通过sharedTmpBuffer传入的情况,开发者需要为tensor申请空间。临时空间大小BufferSize的获取方式如下:通过SoftmaxFlashV3 Tiling接口中提供的GetSoftMaxFlashV3MaxMinTmpSize接口获取所需最小和最大临时空间大小,最小空间可以保证功能正确,最大空间用于提升性能。
参数说明
参数名 |
描述 |
---|---|
T |
输入srcTensor及输出dstTensor、expMaxTensor操作数的数据类型。 |
U |
输入inMeanTensor、inExpSumTensor、inMaxTensor及输出meanTensor、expSumTensor、maxTensor操作数的数据类型。 |
isUpdate |
是否使能update为true的计算。 |
isReuseSource |
预留参数,暂未启用,必须使用默认值false。 |
isBasicBlock |
预留参数,暂未启用,必须使用默认值false。 |
isDataFormatNZ |
预留参数,暂未启用,必须使用默认值false。 |
config |
预留参数,暂未启用,必须使用默认值SOFTMAX_DEFAULT_CFG。 |
参数名 |
输入/输出 |
描述 |
||
---|---|---|---|---|
dstTensor |
输出 |
目的操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 dstTensor的shape和源操作数srcTensor一致。 |
||
meanTensor |
输出 |
目的操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 用于保存softmax计算过程中平均值的结果。
|
||
expSumTensor |
输出 |
目的操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 用于保存softmax计算过程中reducesum的结果。
|
||
maxTensor |
输出 |
目的操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 用于保存softmax计算过程中reducemax的结果。
|
||
srcTensor |
输入 |
源操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 last轴长度需要32Byte对齐。 |
||
expMaxTensor |
输出 |
目的操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。
|
||
inMeanTensor |
输入 |
源操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 softmax计算所需要的mean值。
|
||
inExpSumTensor |
输入 |
源操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 softmax计算所需要的sum值。
|
||
inMaxTensor |
输入 |
源操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 softmax计算所需要的max值。
|
||
sharedTmpBuffer |
输入 |
临时空间。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 该操作数的数据类型固定uint8_t。 接口内部复杂计算时用于存储中间变量,由开发者提供。 临时空间大小BufferSize的获取方式请参考SoftmaxFlashV3 Tiling接口。 |
||
tiling |
输入 |
SoftmaxFlashV3接口计算所需Tiling信息,Tiling信息的获取请参考SoftmaxFlashV3 Tiling接口。 |
||
params |
输入 |
srcTensor的shape信息和计算相关参数。SoftMaxParams类型,具体定义如下:
注意,当前本接口不支持非对齐场景,因此参数srcM与oriSrcM相等,参数srcK与oriSrcK相等。 |
返回值
无
支持的型号
约束说明
- 操作数地址偏移对齐要求请参见通用约束。
- 对于输入srcTensor需要满足:尾轴长度n大于等于512,同时n是64的倍数;非尾轴长度的乘积m为8的倍数。
- srcTensor和dstTensor的Tensor的空间可以复用,meanTensor和inMeanTensor的空间可以复用,maxTensor和inMaxTensor的空间可以复用,expSumTensor和inExpSumTensor的空间可以复用。
- meanTensor、expSumTensor、maxTensor、expMaxTensor、inMeanTensor、inExpSumTensor、inMaxTensor的Tensor空间,last轴长度必须是32Byte。
调用示例
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 | #include "kernel_operator.h" template <typename T, typename U> class KernelSoftmaxFlashV3 { public: __aicore__ inline KernelSoftmaxFlashV3() {} __aicore__ inline void Init(__gm__ uint8_t *srcGm, __gm__ uint8_t *inMaxGm, __gm__ uint8_t *inSumGm, __gm__ uint8_t *inMeanGm, __gm__ uint8_t *dstGm, const SoftMaxTiling &tilingData) { srcGlobal.SetGlobalBuffer((__gm__ T *)srcGm); dstGlobal.SetGlobalBuffer((__gm__ T *)dstGm); maxGlobal.SetGlobalBuffer((__gm__ U *)inMaxGm); sumGlobal.SetGlobalBuffer((__gm__ U *)inSumGm); meanGlobal.SetGlobalBuffer((__gm__ U *)inMeanGm); pipe.InitBuffer(inQueueSrc, 1, height * width * sizeof(T)); elementNumPerBlk1 = 32 / sizeof(U); pipe.InitBuffer(maxQueue, 1, height * elementNumPerBlk1 * sizeof(U)); pipe.InitBuffer(sumQueue, 1, height * elementNumPerBlk1 * sizeof(U)); pipe.InitBuffer(meanQueue, 1, height * elementNumPerBlk1 * sizeof(U)); elementNumPerBlk2 = 32 / sizeof(T); pipe.InitBuffer(expMaxQueue, 1, height * elementNumPerBlk2 * sizeof(T)); pipe.InitBuffer(outQueueDst, 1, height * width * sizeof(T)); tiling = tilingData; } __aicore__ inline void Process() { CopyIn(); Compute(); CopyOut(); } private: __aicore__ inline void CopyIn() { AscendC::LocalTensor<T> srcLocal = inQueueSrc.AllocTensor<T>(); AscendC::LocalTensor<U> insumLocal = sumQueue.AllocTensor<U>(); AscendC::LocalTensor<U> inmaxLocal = maxQueue.AllocTensor<U>(); AscendC::LocalTensor<U> inmeanLocal = meanQueue.AllocTensor<U>(); AscendC::DataCopy(srcLocal, srcGlobal, height * width); AscendC::DataCopy(insumLocal, sumGlobal, height * elementNumPerBlk1); AscendC::DataCopy(inmaxLocal, maxGlobal, height * elementNumPerBlk1); AscendC::DataCopy(inmeanLocal, meanGlobal, height * elementNumPerBlk1); inQueueSrc.EnQue(srcLocal); sumQueue.EnQue(insumLocal); maxQueue.EnQue(inmaxLocal); meanQueue.EnQue(inmeanLocal); } __aicore__ inline void Compute() { AscendC::LocalTensor<T> srcLocal = inQueueSrc.DeQue<T>(); AscendC::LocalTensor<U> insumLocal = sumQueue.DeQue<U>(); AscendC::LocalTensor<U> inmaxLocal = maxQueue.DeQue<U>(); AscendC::LocalTensor<U> inmeanLocal = meanQueue.DeQue<U>(); AscendC::LocalTensor<T> expMaxTensor = expMaxQueue.AllocTensor<T>(); AscendC::LocalTensor<T> dstLocal = outQueueDst.AllocTensor<T>(); AscendC::SoftMaxParams params = {height, width, height, width, loopCnt, splitMeanCnt, alpha}; AscendC::SoftmaxFlashV3<T, U, true>(dstLocal, inmeanLocal, insumLocal, inmaxLocal, srcLocal, expMaxTensor, inmeanLocal, insumLocal, inmaxLocal, tiling, params); outQueueDst.EnQue<T>(dstLocal); maxQueue.FreeTensor(inmaxLocal); sumQueue.FreeTensor(insumLocal); meanQueue.FreeTensor(inmeanLocal); inQueueSrc.FreeTensor(srcLocal); } __aicore__ inline void CopyOut() { AscendC::LocalTensor<T> dstLocal = outQueueDst.DeQue<T>(); AscendC::DataCopy(dstGlobal, dstLocal, height * width); outQueueDst.FreeTensor(dstLocal); } private: AscendC::TPipe pipe; AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueSrc; AscendC::TQue<AscendC::TPosition::VECIN, 1> meanQueue; AscendC::TQue<AscendC::TPosition::VECIN, 1> maxQueue; AscendC::TQue<AscendC::TPosition::VECIN, 1> sumQueue; AscendC::TQue<AscendC::TPosition::VECIN, 1> expMaxQueue; AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueueDst; AscendC::GlobalTensor<T> srcGlobal, dstGlobal; AscendC::GlobalTensor<U> meanGlobal, maxGlobal, sumGlobal; uint32_t elementNumPerBlk1 = 0; uint32_t elementNumPerBlk2 = 0; uint32_t width = 1024; uint32_t height = 8; uint32_t loopCnt = 2; uint32_t splitMeanCnt = 8; float alpha = 0.9375; SoftMaxTiling tiling; }; extern "C" __global__ __aicore__ void softmax_flashv3_kernel(__gm__ uint8_t *srcGm, __gm__ uint8_t *inMaxGm, __gm__ uint8_t *inSumGm, __gm__ uint8_t *inMeanGm, __gm__ uint8_t *dstGm, __gm__ uint8_t *tiling) { GET_TILING_DATA(tilingData, tiling); KernelSoftmaxFlashV3<half, float> op; op.Init(srcGm, inMaxGm, inSumGm, inMeanGm, dstGm, tilingData.softmaxTilingData); op.Process(); } |