Gather
功能说明
给定输入的张量和一个地址偏移张量,Gather指令根据偏移地址将输入张量按元素收集到结果张量中。
函数原型
- tensor前n个数据计算
1 2
template <typename T> __aicore__ inline void Gather(const LocalTensor<T>& dstLocal, const LocalTensor<T>& srcLocal, const LocalTensor<uint32_t>& srcOffsetLocal, const uint32_t srcBaseAddr, const uint32_t count)
- tensor高维切分计算
- mask逐bit模式
1 2
template <typename T> __aicore__ inline void Gather(const LocalTensor<T>& dstLocal, const LocalTensor<T>& srcLocal, const LocalTensor<uint32_t>& srcOffsetLocal, const uint32_t srcBaseAddr, const uint64_t mask[2], const uint8_t repeatTimes, const uint16_t dstRepStride)
- mask连续模式
1 2
template <typename T> __aicore__ inline void Gather(const LocalTensor<T>& dstLocal, const LocalTensor<T>& srcLocal, const LocalTensor<uint32_t>& srcOffsetLocal, const uint32_t srcBaseAddr, const uint64_t mask, const uint8_t repeatTimes, const uint16_t dstRepStride)
- mask逐bit模式
参数说明
参数名称 |
输入/输出 |
含义 |
---|---|---|
dstLocal |
输出 |
目的操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:int16_t/uint16_t/int32_t/uint32_t/float/half/bfloat16_t Atlas推理系列产品AI Core,支持的数据类型为:int16_t/uint16_t/int32_t/uint32_t/float/half Atlas 200/500 A2推理产品,支持的数据类型为:uint8_t/int8_t/uint16_t/int16_t/half/uint32_t/int32_t/float |
srcLocal |
输入 |
源操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 数据类型和dstLocal保持一致。 |
srcOffsetLocal |
输入 |
每个元素在src中对应的地址偏移,地址偏移要大于等于0。该偏移量相对于src的起始基地址而言。每个数值的单位为Bytes。取值应保证src元素类型位宽对齐,否则会导致非预期行为;同时需要保证偏移地址后不能超出UB大小数据的范围。 针对以下型号,地址偏移的取值范围不超出uint32_t的范围即可。 Atlas推理系列产品AI Core Atlas A2训练系列产品/Atlas 800I A2推理产品 针对以下型号,地址偏移的取值范围如下:当操作数为8位时,取值范围为[0, 216-1];当操作数为16位时,取值范围为[0, 217-1],当操作数为32位或者64位时,不超过uint32_t的范围即可,超出取值范围可能导致非预期输出。 Atlas 200/500 A2推理产品 |
srcBaseAddr |
输入 |
srcLocal的起始基地址,单位为Bytes。取值应保证src元素类型位宽对齐,否则会导致非预期行为。 |
count |
输入 |
执行处理的数据个数,不得超过srcLocal和srcOffsetLocal的元素个数。 |
mask |
输入 |
|
repeatTimes |
输入 |
指令迭代次数,每次迭代完成8个datablock(32Bytes)的数据收集,数据范围:repeatTimes∈[0,255]。 |
dstRepStride |
输入 |
相邻迭代间的地址步长,单位是datablock(32Bytes)。 |
支持的型号
Atlas推理系列产品AI Core
Atlas 200/500 A2推理产品
Atlas A2训练系列产品/Atlas 800I A2推理产品
约束说明
- 操作数地址偏移对齐要求请参见通用约束。
调用示例
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
#include "kernel_operator.h" template <typename T> class GatherTest { public: __aicore__ inline GatherTest() {} __aicore__ inline void Init(__gm__ uint8_t* dstGm, __gm__ uint8_t* srcGm, __gm__ uint8_t* srcOffsetGm, const uint32_t count) { m_elementCount = count; m_dstGlobal.SetGlobalBuffer((__gm__ T*)dstGm); m_srcGlobal.SetGlobalBuffer((__gm__ T*)srcGm); m_srcOffsetGlobal.SetGlobalBuffer((__gm__ uint32_t*)srcOffsetGm); m_pipe.InitBuffer(m_queIn, 2, m_elementCount * sizeof(uint32_t)); m_pipe.InitBuffer(m_queOut, 2, m_elementCount * sizeof(uint32_t)); } __aicore__ inline void Process() { CopyIn(); Compute(); CopyOut(); } private: __aicore__ inline void CopyIn() { AscendC::LocalTensor<T> srcLocal = m_queIn.AllocTensor<T>(); AscendC::DataCopy(srcLocal, m_srcGlobal, m_elementCount); m_queIn.EnQue(srcLocal); AscendC::LocalTensor<uint32_t> srcOffsetLocal = m_queIn.AllocTensor<uint32_t>(); AscendC::DataCopy(srcOffsetLocal, m_srcOffsetGlobal, m_elementCount); m_queIn.EnQue(srcOffsetLocal); } __aicore__ inline void Compute() { AscendC::LocalTensor<T> srcLocal = m_queIn.DeQue<T>(); AscendC::LocalTensor<uint32_t> srcOffsetLocal = m_queIn.DeQue<uint32_t>(); AscendC::LocalTensor<T> dstLocal = m_queOut.AllocTensor<T>(); srcLocal.SetSize(m_elementCount); AscendC::Gather(dstLocal, srcLocal, srcOffsetLocal, (uint32_t)0, m_elementCount); m_queIn.FreeTensor(srcLocal); m_queIn.FreeTensor(srcOffsetLocal); m_queOut.EnQue(dstLocal); } __aicore__ inline void CopyOut() { AscendC::LocalTensor<T> dstLocal = m_queOut.DeQue<T>(); AscendC::DataCopy(m_dstGlobal, dstLocal, m_elementCount); m_queOut.FreeTensor(dstLocal); } private: AscendC::TPipe m_pipe; AscendC::TQue<AscendC::QuePosition::VECIN, 1> m_queCalc; AscendC::GlobalTensor<T> m_valueGlobal; uint32_t m_concatRepeatTimes; uint32_t m_sortRepeatTimes; uint32_t m_extractRepeatTimes; uint32_t m_elementCount; AscendC::GlobalTensor<uint32_t> m_srcOffsetGlobal; AscendC::GlobalTensor<T> m_srcGlobal; AscendC::GlobalTensor<T> m_dstGlobal; AscendC::TQue<AscendC::QuePosition::VECIN, 2> m_queIn; AscendC::TQue<AscendC::QuePosition::VECOUT, 2> m_queOut; }; // class GatherTest #define KERNEL_GATHER(T, count) \ extern "C" __global__ __aicore__ void kernel_gather_##T##_##count(GM_ADDR dstGm, GM_ADDR srcGm, \ GM_ADDR srcOffsetGm) \ { \ GatherTest<T> op; \ op.Init(dstGm, srcGm, srcOffsetGm, count); \ op.Process(); \ } |
输入数据srcOffsetLocal: [254 252 250 ... 4 2 0] 输入数据srcLocal(128个half类型数据): [0 1 2 ... 125 126 127] 输出数据dstGlobal: [127 126 125 ... 2 1 0]