下载
中文
注册

Sort32

函数功能

排序函数,一次迭代可以完成32个数的排序,数据需要按如下描述结构进行保存:

score和 index分别存储在src0Local和src1Local中,按score进行排序(score大的排前面),排序好的score与其对应的index一起以(score, index)的结构存储在 dstLocal中。不论 score 为 half 还是 float 类型,dstLocal 中的(score, index)结构总是占据8 Bytes空间。

如下所示:

  • 当score为float,index为uint32类型时,计算结果中index存储在高4Bytes,score存储在低4Bytes。

  • 当score为half,index为uint32类型时,计算结果中index存储在高4Bytes,score存储在低2Bytes, 中间的2Bytes保留。

函数原型

template <typename T> __aicore__ inline void Sort32(const LocalTensor<T>& dstLocal, const LocalTensor<T>& src0Local, const LocalTensor<uint32_t>& src1Local, const int32_t repeatTimes)

参数说明

表1 参数说明

参数名称

输入/输出

含义

dstLocal

输出

目的操作数。

类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:half/float

src0Local

输入

源操作数。

类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

此源操作数的数据类型需要与目的操作数保持一致。

Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:half/float

src1Local

输入

源操作数。

类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

此源操作数固定为uint32_t数据类型。

repeatTimes

输入

重复迭代次数,int32_t类型,每次迭代完成32个元素的排序,下次迭代src0Local和src1Local各跳过32个elements,dstLocal跳过32*8 Byte空间。取值范围:repeatTimes∈[0,255]。

返回值

支持的型号

Atlas A2训练系列产品/Atlas 800I A2推理产品

约束说明

  • 当存在score[i]与score[j]相同时,如果i>j,则score[j]将首先被选出来,排在前面。
  • 每次迭代内的数据会进行排序,不同迭代间的数据不会进行排序。
  • 操作数地址偏移对齐要求请参见通用约束

调用示例

  • 接口使用样例
    // repeatTimes = 4, 对128个数分成4组进行排序,每次完成1组32个数的排序
    Sort32<float>(dstLocal, srcLocal0, srcLocal1, 4);
  • 完整样例
    #include "kernel_operator.h"
    
    namespace AscendC {
    class KernelSort32 {
    public:
        __aicore__ inline KernelSort32() {}
        __aicore__ inline void Init(__gm__ uint8_t* src0Gm, __gm__ uint8_t* src1Gm, __gm__ uint8_t* dstGm)
        {
            srcGlobal0.SetGlobalBuffer((__gm__ float*)src0Gm);
            srcGlobal1.SetGlobalBuffer((__gm__ uint32_t*)src1Gm);
            dstGlobal.SetGlobalBuffer((__gm__ float*)dstGm);
    
            repeat = srcDataSize / 32;
            pipe.InitBuffer(inQueueSrc0, 1, srcDataSize * sizeof(float));
            pipe.InitBuffer(inQueueSrc1, 1, srcDataSize * sizeof(uint32_t));
            pipe.InitBuffer(outQueueDst, 1, dstDataSize * sizeof(float));
        }
        __aicore__ inline void Process()
        {
            CopyIn();
            Compute();
            CopyOut();
        }
    
    private:
        __aicore__ inline void CopyIn()
        {
            LocalTensor<float> srcLocal0 = inQueueSrc0.AllocTensor<float>();
            DataCopy(srcLocal0, srcGlobal0, srcDataSize);
            inQueueSrc0.EnQue(srcLocal0);
            LocalTensor<uint32_t> srcLocal1 = inQueueSrc1.AllocTensor<uint32_t>();
            DataCopy(srcLocal1, srcGlobal1, srcDataSize);
            inQueueSrc1.EnQue(srcLocal1);
        }
        __aicore__ inline void Compute()
        {
            LocalTensor<float> srcLocal0 = inQueueSrc0.DeQue<float>();
            LocalTensor<uint32_t> srcLocal1 = inQueueSrc1.DeQue<uint32_t>();
            LocalTensor<float> dstLocal = outQueueDst.AllocTensor<float>();
    
            Sort32<float>(dstLocal, srcLocal0, srcLocal1, repeat);
     
            outQueueDst.EnQue<float>(dstLocal);
            inQueueSrc0.FreeTensor(srcLocal0);
            inQueueSrc1.FreeTensor(srcLocal1);
        }
        __aicore__ inline void CopyOut()
        {
            LocalTensor<float> dstLocal = outQueueDst.DeQue<float>();
            DataCopy(dstGlobal, dstLocal, dstDataSize);
            outQueueDst.FreeTensor(dstLocal);
        }
    
    private:
        TPipe pipe;
        TQue<QuePosition::VECIN, 1> inQueueSrc0;
        TQue<QuePosition::VECIN, 1> inQueueSrc1;
        TQue<QuePosition::VECOUT, 1> outQueueDst;
        GlobalTensor<float> srcGlobal0, dstGlobal;
        GlobalTensor<uint32_t> srcGlobal1;
        int srcDataSize = 128;
        int dstDataSize = 256;
        int repeat = 0;
    };
    } // namespace AscendC
    
    extern "C" __global__ __aicore__ void vec_sort32_kernel(__gm__ uint8_t* src0Gm, __gm__ uint8_t* src1Gm,
        __gm__ uint8_t* dstGm)
    {
        AscendC::KernelSort32 op;
        op.Init(src0Gm, src1Gm, dstGm);
        op.Process();
    }
    示例结果
    输入数据(src0Gm):128个float类型数据
    [7.867878  9.065992  9.374247  1.0911566 9.262053  2.035779  3.747487
     2.9315646 5.237765  5.176559  7.965426  3.2341435 7.203623  1.5736973
     3.386001  5.077001  4.593656  1.8485032 7.8554387 5.1269145 7.223478
     8.259627  5.5502934 8.795028  9.626377  7.7227993 9.505127  6.683293
     6.232041  2.1760664 4.504409  2.906819  9.425597  9.467169  4.990563
     4.609341  1.8662999 3.6319377 3.5542917 8.382838  5.133566  3.1391478
     5.244712  9.330158  2.0394793 5.9761605 4.937267  6.076068  7.5449195
     6.5085726 1.8132887 2.5047603 3.3350103 2.7831945 3.0417829 5.0608244
     3.4855423 2.8485715 4.853921  6.364753  3.1402998 6.052516  3.6143537
     4.0714087 6.8068676 8.625871  8.040528  1.9881475 4.618402  7.0302424
     6.0751796 5.877218  9.256125  4.193431  5.2048235 6.9774013 2.8765092
     5.8294353 8.618196  8.619784  3.9252923 4.491909  6.0063663 2.3781579
     5.8828945 7.269731  6.1864734 8.32413   5.2518435 9.184813  7.9312286
     3.8841062 8.540505  7.611145  8.204335  2.110103  4.1796618 7.2383223
     3.9992998 4.750733  8.650443  7.6469994 6.6126637 8.993322  8.920976
     7.143699  7.0797443 3.3189814 7.3707795 3.26992   8.58087   5.6882014
     2.0333889 6.711474  4.353861  7.946233  4.5678067 6.3354545 4.092168
     2.416961  3.6823056 4.6000533 2.4727547 4.7993317 1.159995  8.025275
     3.3826146 3.8543346]
    输入数据(src1Gm):
    [0,0,0,0,0...0]
    输出数据(dstGm):
    [9.626377  0.        9.505127  0.        9.374247  0.        9.262053
     0.        9.065992  0.        8.795028  0.        8.259627  0.
     7.965426  0.        7.867878  0.        7.8554387 0.        7.7227993
     0.        7.223478  0.        7.203623  0.        6.683293  0.
     6.232041  0.        5.5502934 0.        5.237765  0.        5.176559
     0.        5.1269145 0.        5.077001  0.        4.593656  0.
     4.504409  0.        3.747487  0.        3.386001  0.        3.2341435
     0.        2.9315646 0.        2.906819  0.        2.1760664 0.
     2.035779  0.        1.8485032 0.        1.5736973 0.        1.0911566
     0.        9.467169  0.        9.425597  0.        9.330158  0.
     8.382838  0.        7.5449195 0.        6.5085726 0.        6.364753
     0.        6.076068  0.        6.052516  0.        5.9761605 0.
     5.244712  0.        5.133566  0.        5.0608244 0.        4.990563
     0.        4.937267  0.        4.853921  0.        4.609341  0.
     4.0714087 0.        3.6319377 0.        3.6143537 0.        3.5542917
     0.        3.4855423 0.        3.3350103 0.        3.1402998 0.
     3.1391478 0.        3.0417829 0.        2.8485715 0.        2.7831945
     0.        2.5047603 0.        2.0394793 0.        1.8662999 0.
     1.8132887 0.        9.256125  0.        9.184813  0.        8.625871
     0.        8.619784  0.        8.618196  0.        8.540505  0.
     8.32413   0.        8.204335  0.        8.040528  0.        7.9312286
     0.        7.611145  0.        7.269731  0.        7.0302424 0.
     6.9774013 0.        6.8068676 0.        6.1864734 0.        6.0751796
     0.        6.0063663 0.        5.8828945 0.        5.877218  0.
     5.8294353 0.        5.2518435 0.        5.2048235 0.        4.618402
     0.        4.491909  0.        4.193431  0.        3.9252923 0.
     3.8841062 0.        2.8765092 0.        2.3781579 0.        2.110103
     0.        1.9881475 0.        8.993322  0.        8.920976  0.
     8.650443  0.        8.58087   0.        8.025275  0.        7.946233
     0.        7.6469994 0.        7.3707795 0.        7.2383223 0.
     7.143699  0.        7.0797443 0.        6.711474  0.        6.6126637
     0.        6.3354545 0.        5.6882014 0.        4.7993317 0.
     4.750733  0.        4.6000533 0.        4.5678067 0.        4.353861
     0.        4.1796618 0.        4.092168  0.        3.9992998 0.
     3.8543346 0.        3.6823056 0.        3.3826146 0.        3.3189814
     0.        3.26992   0.        2.4727547 0.        2.416961  0.
     2.0333889 0.        1.159995  0.       ]