MrgSort
函数功能
将已经排好序的最多4 条队列,合并排列成 1 条队列,结果按照 score 域由大到小排序。
函数原型
template <typename T> __aicore__ inline void MrgSort(const LocalTensor<T>& dstLocal, const MrgSortSrcList<T>& srcLocal, const MrgSort4Info& params)
参数说明
参数名称 |
输入/输出 |
含义 |
---|---|---|
dstLocal |
输出 |
目的操作数,存储经过排序后的数据。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:half/float |
srcLocal |
输入 |
源操作数,4个队列,并且每个队列都已经排好序,类型为MrgSortSrcList结构体,具体请参考表2。 Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:half/float |
params |
输入 |
排序所需参数,类型为MrgSort4Info结构体,具体请参考表3。 |
参数名称 |
输入/输出 |
含义 |
---|---|---|
src1 |
输入 |
源操作数,第一个已经排好序的队列。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 数据类型与目的操作数保持一致。 Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:half/float |
src2 |
输入 |
源操作数,第二个已经排好序的队列。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 数据类型与目的操作数保持一致。 Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:half/float |
src3 |
输入 |
源操作数,第三个已经排好序的队列。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 数据类型与目的操作数保持一致。 Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:half/float |
src4 |
输入 |
源操作数,第四个已经排好序的队列。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 数据类型与目的操作数保持一致。 Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:half/float |
参数名称 |
输入/输出 |
含义 |
---|---|---|
elementLengths |
输入 |
四个源队列的长度(8Bytes结构的数目),类型为长度为4的uint16_t数据类型的数组,理论上每个元素取值范围[0, 4095],但不能超出UB的存储空间。 |
ifExhaustedSuspension |
输入 |
某条队列耗尽后,指令是否需要停止,类型为bool,默认false。 |
validBit |
输入 |
有效队列个数,取值如下:
|
repeatTimes |
输入 |
迭代次数,每一次源操作数和目的操作数跳过四个队列总长度。取值范围:repeatTimes∈[1,255]。
repeatTimes参数生效是有条件的,需要同时满足以下四个条件:
|
返回值
无
支持的型号
Atlas A2训练系列产品/Atlas 800I A2推理产品
约束说明
- 当存在score[i]与score[j]相同时,如果i>j,则score[j]将首先被选出来,排在前面。
- 每次迭代内的数据会进行排序,不同迭代间的数据不会进行排序。
- 需要注意此函数排序的队列非region proposal结构。
- 操作数地址偏移对齐要求请参见通用约束。
调用示例
- 接口使用样例
// 对8个已排好序的队列进行合并排序,repeatTimes = 2,数据连续存放 // 每个队列包含32个(score,index)的8Bytes结构 // 最后输出对score域的256个数完成排序后的结果 MrgSort4Info params; params.elementLengths[0] = 32; params.elementLengths[1] = 32; params.elementLengths[2] = 32; params.elementLengths[3] = 32; params.ifExhaustedSuspension = false; params.validBit = 0b1111; params.repeatTimes = 2; MrgSortSrcList<float> srcList; srcList.src1 = workLocal[0]; srcList.src2 = workLocal[64]; // workLocal为float类型,每个队列占据256Bytes空间 srcList.src3 = workLocal[128]; srcList.src4 = workLocal[192]; MrgSort<float>(dstLocal, srcList, params);
- 完整样例
#include "kernel_operator.h" namespace AscendC { class KernelMrgSort { public: __aicore__ inline KernelMrgSort() {} __aicore__ inline void Init(__gm__ uint8_t* src0Gm, __gm__ uint8_t* src1Gm, __gm__ uint8_t* dstGm) { srcGlobal0.SetGlobalBuffer((__gm__ float*)src0Gm); srcGlobal1.SetGlobalBuffer((__gm__ uint32_t*)src1Gm); dstGlobal.SetGlobalBuffer((__gm__ float*)dstGm); repeat = srcDataSize / 32; pipe.InitBuffer(inQueueSrc0, 1, srcDataSize * sizeof(float)); pipe.InitBuffer(inQueueSrc1, 1, srcDataSize * sizeof(uint32_t)); pipe.InitBuffer(workQueue, 1, dstDataSize * sizeof(float)); pipe.InitBuffer(outQueueDst, 1, dstDataSize * sizeof(float)); } __aicore__ inline void Process() { CopyIn(); Compute(); CopyOut(); } private: __aicore__ inline void CopyIn() { LocalTensor<float> srcLocal0 = inQueueSrc0.AllocTensor<float>(); DataCopy(srcLocal0, srcGlobal0, srcDataSize); inQueueSrc0.EnQue(srcLocal0); LocalTensor<uint32_t> srcLocal1 = inQueueSrc1.AllocTensor<uint32_t>(); DataCopy(srcLocal1, srcGlobal1, srcDataSize); inQueueSrc1.EnQue(srcLocal1); } __aicore__ inline void Compute() { LocalTensor<float> srcLocal0 = inQueueSrc0.DeQue<float>(); LocalTensor<uint32_t> srcLocal1 = inQueueSrc1.DeQue<uint32_t>(); LocalTensor<float> workLocal = workQueue.AllocTensor<float>(); LocalTensor<float> dstLocal = outQueueDst.AllocTensor<float>(); Sort32<float>(workLocal, srcLocal0, srcLocal1, repeat); // 先构造4个排好序的队列 MrgSort4Info params; params.elementLengths[0] = 32; params.elementLengths[1] = 32; params.elementLengths[2] = 32; params.elementLengths[3] = 32; params.ifExhaustedSuspension = false; params.validBit = 0b1111; params.repeatTimes = 1; MrgSortSrcList<float> srcList; srcList.src1 = workLocal[0]; srcList.src2 = workLocal[32 * 1 * 2]; srcList.src3 = workLocal[32 * 2 * 2]; srcList.src4 = workLocal[32 * 3 * 2]; MrgSort<float>(dstLocal, srcList, params); outQueueDst.EnQue<float>(dstLocal); inQueueSrc0.FreeTensor(srcLocal0); inQueueSrc1.FreeTensor(srcLocal1); workQueue.FreeTensor(workLocal); } __aicore__ inline void CopyOut() { LocalTensor<float> dstLocal = outQueueDst.DeQue<float>(); DataCopy(dstGlobal, dstLocal, dstDataSize); outQueueDst.FreeTensor(dstLocal); } private: TPipe pipe; TQue<QuePosition::VECIN, 1> inQueueSrc0; TQue<QuePosition::VECIN, 1> inQueueSrc1; TQue<QuePosition::VECIN, 1> workQueue; TQue<QuePosition::VECOUT, 1> outQueueDst; GlobalTensor<float> srcGlobal0, dstGlobal; GlobalTensor<uint32_t> srcGlobal1; int srcDataSize = 128; int dstDataSize = 256; int repeat = 0; }; } // namespace AscendC extern "C" __global__ __aicore__ void vec_mrgsort_kernel(__gm__ uint8_t* src0Gm, __gm__ uint8_t* src1Gm, __gm__ uint8_t* dstGm) { AscendC::KernelMrgSort op; op.Init(src0Gm, src1Gm, dstGm); op.Process(); } 示例结果 输入数据(src0Gm): 128个float类型数据 [2.9447467 7.546607 5.083544 1.6373866 3.4730997 5.488915 6.2410192 6.5340915 9.534971 8.217815 7.922645 9.9135275 9.34575 8.0759535 6.40329 7.2240252 8.792965 4.9348564 7.726399 2.3075738 5.8587966 3.3077633 1.5605974 5.582237 9.38379 8.583278 3.2116296 7.5197206 1.3169404 9.355466 3.6663866 6.3373866 4.188842 1.1831555 6.3235407 7.0127134 1.9593428 9.316625 5.7821383 4.980949 4.4211564 1.0478534 9.626102 4.52559 5.151449 3.4274218 9.874416 8.040044 5.049376 3.8079789 9.16666 7.803004 9.288373 5.497965 2.2784562 8.752271 1.2586805 7.161625 5.807935 2.9983459 4.980592 1.1796398 8.89327 9.35524 5.0074706 2.108345 8.4992285 2.7219095 9.544726 4.4516068 6.940215 1.424632 5.473264 7.7971754 6.730119 3.3760135 1.3578739 8.965629 5.5441265 1.9234481 6.1590824 3.62707 8.257497 6.5762696 3.6241028 1.870233 8.303693 7.5986104 7.211784 9.259263 2.9631793 5.9183855 1.911052 8.445708 3.1592433 5.434683 5.2764387 2.013458 2.5766358 1.3793333 6.4866495 6.957988 8.711433 4.1000323 1.973415 1.5109203 6.830736 7.871973 6.130566 2.5669708 9.317494 4.4140983 8.086401 3.1740563 9.000416 6.2852535 2.170213 4.6842256 5.939913 1.3967329 9.959876 7.9772205 5.874416 4.4834223 3.6719642 8.462775 2.3629668 2.886413 ] 输入数据(src1Gm): [0,0,0,0,...,0] 输出数据(dstGm): [9.959876 0. 9.9135275 0. 9.874416 0. 9.626102 0. 9.544726 0. 9.534971 0. 9.38379 0. 9.355466 0. 9.35524 0. 9.34575 0. 9.317494 0. 9.316625 0. 9.288373 0. 9.259263 0. 9.16666 0. 9.000416 0. 8.965629 0. 8.89327 0. 8.792965 0. 8.752271 0. 8.711433 0. 8.583278 0. 8.4992285 0. 8.462775 0. 8.445708 0. 8.303693 0. 8.257497 0. 8.217815 0. 8.086401 0. 8.0759535 0. 8.040044 0. 7.9772205 0. 7.922645 0. 7.871973 0. 7.803004 0. 7.7971754 0. 7.726399 0. 7.5986104 0. 7.546607 0. 7.5197206 0. 7.2240252 0. 7.211784 0. 7.161625 0. 7.0127134 0. 6.957988 0. 6.940215 0. 6.830736 0. 6.730119 0. 6.5762696 0. 6.5340915 0. 6.4866495 0. 6.40329 0. 6.3373866 0. 6.3235407 0. 6.2852535 0. 6.2410192 0. 6.1590824 0. 6.130566 0. 5.939913 0. 5.9183855 0. 5.874416 0. 5.8587966 0. 5.807935 0. 5.7821383 0. 5.582237 0. 5.5441265 0. 5.497965 0. 5.488915 0. 5.473264 0. 5.434683 0. 5.2764387 0. 5.151449 0. 5.083544 0. 5.049376 0. 5.0074706 0. 4.980949 0. 4.980592 0. 4.9348564 0. 4.6842256 0. 4.52559 0. 4.4834223 0. 4.4516068 0. 4.4211564 0. 4.4140983 0. 4.188842 0. 4.1000323 0. 3.8079789 0. 3.6719642 0. 3.6663866 0. 3.62707 0. 3.6241028 0. 3.4730997 0. 3.4274218 0. 3.3760135 0. 3.3077633 0. 3.2116296 0. 3.1740563 0. 3.1592433 0. 2.9983459 0. 2.9631793 0. 2.9447467 0. 2.886413 0. 2.7219095 0. 2.5766358 0. 2.5669708 0. 2.3629668 0. 2.3075738 0. 2.2784562 0. 2.170213 0. 2.108345 0. 2.013458 0. 1.973415 0. 1.9593428 0. 1.9234481 0. 1.911052 0. 1.870233 0. 1.6373866 0. 1.5605974 0. 1.5109203 0. 1.424632 0. 1.3967329 0. 1.3793333 0. 1.3578739 0. 1.3169404 0. 1.2586805 0. 1.1831555 0. 1.1796398 0. 1.0478534 0. ]