Select
函数功能
给定两个源操作数src0和src1,根据selMask(用于选择的Mask掩码)的比特位值选取元素,得到目的操作数dst。选择的规则为:当selMask的比特位是1时,从src0中选取,比特位是0时从src1选取。
对于tensor高维切分计算接口,支持根据mask参数对上述选取结果,再次进行过滤,有效位填入最终的dst,无效位则保持dst原始值。例如:src0为[1,2,3,4,5,6,7,8],src1为[9,10,11,12,13,14,15,16],selMask为[0,0,0,0,1,1,1,1],mask为[1,1,1,1,0,0,0,0],dst原始值为[-1,-2,-3,-4,-5,-6,-7,-8],则根据selMask的比特位选取后的结果dst_temp为:[9,10,11,12,5,6,7,8],然后再根据mask进行过滤,dst的最终输出结果为[9,10,11,12,-5,-6,-7,-8]。
本选择功能支持三种模式:
- 模式0:根据selMask在两个tensor中选取元素。selMask中有效数据的个数存在限制,当输入数据类型为8位时,selMask前256比特位有效;当输入数据类型为16位时,selMask前128比特位有效;输入数据类型为32位时,selMask前64比特位有效。在每一轮迭代中,根据selMask的有效位数据进行选择操作,每一轮迭代采用的selMask,均为相同数值,即selMask的有效数值。
- 模式1:根据selMask在1个tensor和1个scalar标量中选取元素,selMask无有效数据限制。
- 模式2:根据selMask在两个tensor中选取元素,selMask无有效数据限制:例如迭代次数为2,处理的数据类型为half,第一个迭代根据selMask第一段128bit位值选择src0或src1元素,第二个迭代根据selMask第二段128bit位值选择src0或是src1元素。
函数原型
- tensor前n个数据计算
- Select模式1:
1 2
template <typename T, typename U> __aicore__ inline void Select(const LocalTensor<T>& dstLocal, const LocalTensor<U>& selMask, const LocalTensor<T>& src0Local, T src1Local, SELMODE selMode, uint32_t calCount)
- Select模式0和Select模式2:
1 2
template <typename T, typename U> __aicore__ inline void Select(const LocalTensor<T>& dstLocal, const LocalTensor<U>& selMask, const LocalTensor<T>& src0Local, const LocalTensor<T>& src1Local, SELMODE selMode, uint32_t calCount)
- Select模式1:
- tensor高维切分计算
- Select模式1:
- mask参数使用逐bit模式,逐bit模式的具体介绍请参考参数说明中的mask参数说明:
1 2
template <typename T, typename U, bool isSetMask = true> __aicore__ inline void Select(const LocalTensor<T>& dstLocal, const LocalTensor<U>& selMask, const LocalTensor<T>& src0Local, T src1Local, SELMODE selMode, uint64_t mask[], uint8_t repeatTimes, const BinaryRepeatParams& repeatParams)
- mask参数使用连续模式,连续模式的具体介绍请参考参数说明中的mask参数说明:
1 2
template <typename T, typename U, bool isSetMask = true> __aicore__ inline void Select(const LocalTensor<T>& dstLocal, const LocalTensor<U>& selMask, const LocalTensor<T>& src0Local, T src1Local, SELMODE selMode, uint64_t mask, uint8_t repeatTimes, const BinaryRepeatParams& repeatParams)
- 不传入mask参数:(需要和SetVectorMask、SetCmpMask(ISASI)配合使用)
1 2
template <typename T, typename U> __aicore__ inline void Select(const LocalTensor<T>& dstLocal, const LocalTensor<U>& selMask, const LocalTensor<T>& src0Local, uint8_t repeatTimes, const BinaryRepeatParams& repeatParams)
- mask参数使用逐bit模式,逐bit模式的具体介绍请参考参数说明中的mask参数说明:
- Select模式0和Select模式2:
- mask参数使用逐bit模式:
1 2
template <typename T, typename U, bool isSetMask = true> __aicore__ inline void Select(const LocalTensor<T>& dstLocal, const LocalTensor<U>& selMask, const LocalTensor<T>& src0Local, const LocalTensor<T>& src1Local, SELMODE selMode, uint64_t mask[], uint8_t repeatTimes, const BinaryRepeatParams& repeatParams)
- mask参数使用连续模式:
1 2
template <typename T, typename U, bool isSetMask = true> __aicore__ inline void Select(const LocalTensor<T>& dstLocal, const LocalTensor<U>& selMask, const LocalTensor<T>& src0Local, const LocalTensor<T>& src1Local, SELMODE selMode, uint64_t mask, uint8_t repeatTimes, const BinaryRepeatParams& repeatParams)
- 不传入mask参数:(需要和SetVectorMask、SetCmpMask(ISASI)配合使用)
1 2
template <typename T, SELMODE selMode> __aicore__ inline void Select(const LocalTensor<T>& dstLocal, const LocalTensor<T>& src0Local, const LocalTensor<T>& src1Local, uint8_t repeatTimes, const BinaryRepeatParams& repeatParams)
- mask参数使用逐bit模式:
- Select模式1:
参数说明
参数名称 |
含义 |
---|---|
T |
源操作数和目的操作数的数据类型。 |
U |
selMask的数据类型。 |
isSetMask |
保留参数,保持默认值即可。如需使用在接口外部设置mask的功能,可以调用不传入mask参数的接口来实现。 |
selMode |
同表2 参数说明中的selMode参数说明。 |
参数名称 |
输入/输出 |
含义 |
---|---|---|
dstLocal |
输出 |
目的操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 LocalTensor的起始地址需要32字节对齐。
|
selMask |
输入 |
选取mask。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 支持的数据类型为:uint8_t/uint16_t/uint32_t/uint64_t。 每个bit表示1个element的选取,当selMask的比特位是1时从src0中选取,比特位是0时从src1选取。
|
src0Local |
输入 |
源操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 LocalTensor的起始地址需要32字节对齐。
|
src1Local |
输入 |
源操作数。
|
selMode |
输入 |
指令模式,有以下3种;
|
mask |
输入 |
|
repeatTimes |
输入 |
重复迭代次数。矢量计算单元,每次读取连续的256 Bytes数据进行计算,为完成对输入数据的处理,必须通过多次迭代(repeat)才能完成所有数据的读取与计算。repeatTimes表示迭代的次数。 |
repeatParams |
输入 |
控制操作数地址步长的参数。BinaryRepeatParams类型,包含操作数相邻迭代间相同datablock的地址步长,操作数同一迭代内不同datablock的地址步长等参数。 相邻迭代间的地址步长参数说明请参考repeatStride;同一迭代内datablock的地址步长参数说明请参考dataBlockStride。 |
calCount |
输入 |
输入数据元素个数。 参数取值范围和操作数的数据类型有关,数据类型不同,能够处理的元素个数最大值不同。 矢量计算单元,每个迭代读取连续256 Bytes数据进行计算,通过多次迭代完成所有数据的读取与计算。所以当操作数为8位时,calCount∈[1,256*255],255表示迭代次数的最大值,256表示每次迭代内能够处理256个8位数据;当操作数为16位时,calCount∈[1,128*255],255表示迭代次数的最大值,128表示每次迭代内能够处理128个16位数据;当操作数为32位时,calCount∈[1,64*255],64表示每次迭代内能够处理64个32位数据。 |
返回值
无
支持的型号
注意事项
- 为了节省地址空间,开发者可以定义一个Tensor,供源操作数与目的操作数同时使用(即地址重叠),相关约束如下:
- 使用tensor高维切分计算接口(repeatTimes=1)时或者使用tensor前n个数据计算接口时,源操作数与目的操作数之间要求100%完全重叠,不支持部分重叠。
- 使用tensor高维切分计算接口(repeatTimes>1),若源操作数与目的操作数之间存在依赖,即第N次迭代的目的操作数是第N+1次的源操作数,这种情况是不支持地址重叠的。
- 操作数地址偏移对齐要求请参见通用约束。
- 针对
Atlas A2 训练系列产品/Atlas 800I A2 推理产品 ,对于模式1和模式2,使用时需要预留8K的Unified Buffer空间,作为接口的临时数据存放区。 - 针对
Atlas 推理系列产品AI Core ,对于模式1和模式2,使用时需要预留8K的Unified Buffer空间,作为接口的临时数据存放区。
调用示例
本样例中只展示Compute流程中的部分代码。如果您需要运行样例代码,请将该代码段拷贝并替换样例模板中Compute函数的部分代码即可。
- Select-tensor高维切分计算样例(模式2)
1 2 3 4 5 6 7
uint64_t mask = 256/sizeof(float); int repeat = 4; AscendC::BinaryRepeatParams repeatParams = { 1, 1, 1, 8, 8, 8 }; // repeat = 4, 64 elements one repeat, 256 elements total // dstBlkStride, src0BlkStride, src1BlkStride = 1, no gap between blocks in one repeat // dstRepStride, src0RepStride, src1RepStride = 8, no gap between repeats AscendC::Select(dstLocal, maskLocal, src0Local, src1Local, AscendC::SELMODE::VSEL_TENSOR_TENSOR_MODE, mask, repeat, repeatParams);
- Select-tensor前n个数据计算样例(模式1)
1
AscendC::Select(dstLocal, maskLocal, src0Local, static_cast<float>(0), AscendC::SELMODE::VSEL_TENSOR_SCALAR_MODE, dataSize);
- Select-tensor前n个数据计算样例(模式0,每轮repeat迭代中,maskLocal只有前64bit生效)
1
AscendC::Select(dstLocal, maskLocal, src0Local, src1Local, AscendC::SELMODE::VSEL_CMPMASK_SPR, dataSize);
- Select-tensor高维切分计算样例-mask连续模式(模式0,每轮repeat迭代中,maskLocal只有前64bit生效)
1 2 3 4 5 6 7
uint64_t mask = 256/sizeof(float); int repeat = 4; AscendC::BinaryRepeatParams repeatParams = { 1, 1, 1, 8, 8, 8 }; // repeat = 4, 64 elements one repeat, 256 elements total // dstBlkStride, src0BlkStride, src1BlkStride = 1, no gap between blocks in one repeat // dstRepStride, src0RepStride, src1RepStride = 8, no gap between repeats AscendC::Select(dstLocal, maskLocal, src0Local, src1Local, AscendC::SELMODE::VSEL_CMPMASK_SPR, mask, repeat, repeatParams);
- Select-tensor高维切分计算样例-mask逐bit模式(每轮repeat迭代中,maskLocal只有前64bit生效)
1 2 3 4 5 6 7
uint64_t mask[2] = { UINT64_MAX, 0}; int repeat = 4; AscendC::BinaryRepeatParams repeatParams = { 1, 1, 1, 8, 8, 8 }; // repeat = 4, 64 elements one repeat, 256 elements total // srcBlkStride, = 1, no gap between blocks in one repeat // dstRepStride, srcRepStride = 8, no gap between repeats AscendC::Select(dstLocal, maskLocal, src0Local, src1Local, AscendC::SELMODE::VSEL_CMPMASK_SPR, mask, repeat, repeatParams);
结果示例如下:
模式2示例:
输入数据(src0_gm): [-80.4933, 52.2499, -50.6124, -72.3737, -33.7107, -83.4001, 34.3954, 61.3188, 96.5484, 27.1321, -56.8153, 9.80549, 9.11199, -53.1848, -77.2548, -0.0681466, -69.5783, 6.53722, -22.5986, 37.6655, -25.4146, 89.232, 55.4716, 21.6069, 23.0464, 28.5975, -46.3033, 50.0312, -42.6339, 41.8752, -87.0426, 37.9717, 10.4336, 10.7653, -30.6943, -65.4774, 8.38653, -89.6462, 65.1115, 42.2134, -91.1666, -84.6927, -97.9312, 98.861, 19.8888, -64.0522, -27.1243, 72.7673, -9.9489, 94.274, -72.868, 43.1349, 84.1897, 87.0729, 87.2606, 34.5548, 87.7985, 4.84555, -10.2156, -93.7445, 71.8209, -63.4942, 45.6619, 93.4737, 79.6631, 66.8743, 18.1016, -27.7082, -67.0339, -11.9576, 52.0373, -11.5452, -73.8953, -63.5915, 99.4875, -46.2296, 75.453, -67.2079, 89.8868, -19.9666, 30.5359, 42.1465, -19.8105, 82.3653, -89.2164, -0.959167, -50.0723, -30.3058, 48.1772, -27.7686, 26.1484, 94.8462, -15.5663, -87.346, -84.2826, -58.8268, -42.4957, -23.7061, 67.0375, -83.5848, 64.167, 63.3315, -33.3809, 35.1264, 96.6872, 91.8399, 33.9888, 78.5923, -30.4885, 26.3331, -62.3014, -30.3431, 96.9554, 50.3262, 66.612, -71.5939, -97.0042, 71.4549, 42.3891, 71.308, 72.3209, 59.1608, -4.57918, -81.3639, -37.2619, 28.2445, 16.7995, -46.5868, -88.6893, 82.0504, -39.3391, -33.7141, -88.6628, -94.1271, -74.7738, -80.0798, -67.5639, -69.8237, -37.5219, 11.9601, -30.3912, -30.1169, 22.7411, -85.9541, 19.5141, -37.6203, -49.5693, 5.09318, 11.4884, 18.9713, 21.1073, -84.9266, 11.9436, -22.4703, -58.5243, -24.0218, -63.2767, -2.72752, -87.8947, -91.5162, 33.4207, -85.9841, 18.743, 48.9581, 69.3992, 2.42074, -75.0209, -53.2579, -45.4509, 66.6121, 51.6616, -48.073, 74.2754, -51.1623, -89.9345, 4.15238, -4.47531, 79.6587, -31.1646, 69.4103, -83.5936, -78.7341, 56.8626, 72.8834, -27.0248, -80.3328, 71.7272, -77.979, -76.6814, -14.9994, -94.5054, -75.2802, -96.4931, -17.6781, -5.50804, -83.4637, -56.8385, 51.5406, -60.527, -11.0762, -33.3166, -54.9609, -36.9426, -70.3942, 28.3439, 5.28754, -61.4775, 96.0657, -69.1967, 70.5489, 32.817, -53.5746, 49.2601, -88.5728, -1.94822, -3.16238, 19.083, -81.5139, 87.8383, 90.714, 4.75546, 31.9277, -4.1301, -0.160932, -31.2602, 56.4225, -72.1826, 91.5082, 68.2155, -81.7476, -14.0418, -79.4093, -30.7375, 38.8967, -16.5589, -69.4351, 48.6597, -8.44998, -74.4274, -20.3394, -59.2265, 19.407, 88.1542, -15.4888, 60.9066, 59.5144, -42.6935, 20.3518, 11.7192, -31.3635, 26.0055, -26.9334, 79.3798, 46.4724, 22.984, 67.5759] 输入数据(src1_gm): [35.8789, 44.0334, 54.9997, 44.8567, -30.8579, -53.714, -59.8013, 71.1663, 46.3484, 8.56818, -59.4716, 6.07412, -39.0137, 64.5595, 17.0849, 45.2641, -63.2115, -98.7838, -52.3835, -65.9849, 50.4909, 69.9812, -22.3447, -32.3809, -97.8394, -45.4997, 63.5391, -69.3535, 43.3368, 98.8541, -77.2888, 1.02385, 20.4965, -26.7797, 98.2463, -78.9606, -62.4907, -13.5348, -49.5058, -4.06369, 77.2982, -32.8221, 84.6766, -62.3829, 58.8673, -75.8509, -95.3497, -79.7642, 67.1185, 34.4278, 34.5305, -76.1646, 53.1497, -12.3158, -42.9392, 59.2962, -4.12072, 47.1292, -17.0687, -78.0087, -59.4565, -98.9565, -54.0959, 56.5437, -74.3328, 77.2781, 52.7964, -0.932984, 70.8957, -68.1249, 85.895, 25.4119, 71.9202, -73.1287, 63.6916, 21.4303, 66.0614, 66.1438, -22.2332, 84.0665, -7.86752, -2.38648, 1.37756, -98.691, -35.847, -15.2647, -85.2363, -54.3978, -46.6612, 99.3826, -75.7728, -31.2539, 97.9558, 92.4507, 80.2871, -60.8802, -82.0434, -80.625, 19.6418, 51.0559, 35.3667, -56.306, -41.2088, 0.955906, -85.7743, 8.18112, 36.4615, -0.572343, -16.0821, 36.0277, -4.61647, 26.5385, 88.6082, 9.17454, 44.8951, -42.173, 51.5339, 5.93139, 93.7096, -68.8219, 68.2573, -67.325, -88.4579, -56.8873, -75.8117, -40.5576, -98.378, 32.3699, 64.6693, -73.1523, -57.8738, 63.1893, -40.4731, -1.00914, -80.6115, 44.2928, 76.6212, -29.3298, -58.1212, 83.3083, -20.6412, 26.8912, -82.1719, 26.8713, -56.5484, 35.4743, -8.59957, -12.4709, 27.8249, 76.6877, -27.5806, 63.2649, 66.1106, 15.8328, 9.19251, -79.6418, -9.31359, 63.7053, -8.37093, 55.6421, -99.0591, -64.1341, 91.4046, 26.7268, -92.1002, -34.7002, -6.41819, -18.15, 12.207, 48.6667, -39.4883, -21.0939, -50.3433, 58.2913, 7.64983, -82.6098, -89.6739, -25.9494, 82.4803, 20.8037, 21.483, -29.0788, 31.7695, 50.462, -83.7715, 63.4177, 52.7679, -90.2271, 16.1258, -61.4531, -61.7242, 25.0575, -97.8702, 26.9708, -23.039, -52.7595, -97.0177, -13.1399, -47.6936, -29.7551, 88.9603, -82.1242, -56.6307, 91.7884, 0.0381027, -49.0936, -43.5545, 47.3574, 97.1801, 43.4392, 22.7347, 12.6125, 63.7829, 22.3428, 53.4543, -91.4307, 45.6971, -92.1851, -81.4774, 35.8835, -33.043, -79.7464, 69.0971, -82.6252, -63.0042, -61.0205, -8.00347, -60.0369, 56.2894, -38.1932, 17.976, 5.82004, 4.41524, -52.2192, 93.1915, 21.1114, -29.3558, -18.5685, 20.7356, -4.71108, -0.947533, 73.0143, 62.5668, 96.1632, 41.4265, -89.503, -83.7747, -97.6047, -60.7304, 28.9736, -42.6681, 55.2584, 59.1584, -14.6596, -41.1826, 48.8083] 输入数据(sel_Gm): [60, 32, 10, 7, 42, 52, 38, 26, 19, 15, 18, 83, 41, 43, 91, 30, 45, 77, 80, 58, 34, 76, 44, 4, 64, 45, 48, 31, 30, 56, 43, 88] 输出数据(dst_gm): [35.8789,44.0334,-50.6124,-72.3737,-33.7107,-83.4001,-59.8013,71.1663,46.3484,8.56818,-59.4716,6.07412,-39.0137,-53.1848,17.0849,45.2641,-63.2115,6.53722,-52.3835,37.6655,50.4909,69.9812,-22.3447,-32.3809,23.0464,28.5975,-46.3033,-69.3535,43.3368,98.8541,-77.2888,1.02385,20.4965,10.7653,98.2463,-65.4774,-62.4907,-89.6462,-49.5058,-4.06369,77.2982,-32.8221,-97.9312,-62.3829,19.8888,-64.0522,-95.3497,-79.7642,67.1185,94.274,-72.868,-76.1646,53.1497,87.0729,-42.9392,59.2962,-4.12072,4.84555,-17.0687,-93.7445,71.8209,-98.9565,-54.0959,56.5437,79.6631,66.8743,52.7964,-0.932984,-67.0339,-68.1249,85.895,25.4119,-73.8953,-63.5915,99.4875,-46.2296,66.0614,66.1438,-22.2332,84.0665,-7.86752,42.1465,1.37756,-98.691,-89.2164,-15.2647,-85.2363,-54.3978,48.1772,-27.7686,-75.7728,-31.2539,-15.5663,92.4507,-84.2826,-60.8802,-42.4957,-80.625,19.6418,-83.5848,35.3667,63.3315,-41.2088,0.955906,96.6872,91.8399,36.4615,78.5923,-16.0821,26.3331,-4.61647,26.5385,96.9554,50.3262,44.8951,-71.5939,-97.0042,5.93139,42.3891,-68.8219,68.2573,59.1608,-4.57918,-81.3639,-37.2619,-40.5576,-98.378,32.3699,-88.6893,-73.1523,-39.3391,-33.7141,-40.4731,-94.1271,-80.6115,44.2928,-67.5639,-29.3298,-37.5219,11.9601,-20.6412,26.8912,22.7411,26.8713,-56.5484,35.4743,-8.59957,-12.4709,11.4884,76.6877,21.1073,63.2649,66.1106,-22.4703,9.19251,-24.0218,-63.2767,-2.72752,-8.37093,55.6421,-99.0591,-85.9841,91.4046,26.7268,-92.1002,2.42074,-6.41819,-18.15,12.207,48.6667,51.6616,-48.073,-50.3433,58.2913,-89.9345,-82.6098,-89.6739,-25.9494,-31.1646,69.4103,21.483,-78.7341,31.7695,50.462,-83.7715,63.4177,71.7272,-90.2271,16.1258,-61.4531,-61.7242,25.0575,-97.8702,26.9708,-23.039,-52.7595,-97.0177,-13.1399,-60.527,-29.7551,-33.3166,-82.1242,-36.9426,-70.3942,0.0381027,5.28754,-43.5545,47.3574,97.1801,43.4392,22.7347,12.6125,49.2601,-88.5728,53.4543,-91.4307,19.083,-81.5139,87.8383,90.714,4.75546,-79.7464,69.0971,-82.6252,-63.0042,56.4225,-72.1826,91.5082,68.2155,-38.1932,17.976,5.82004,4.41524,-52.2192,93.1915,-69.4351,48.6597,-8.44998,20.7356,-4.71108,-59.2265,19.407,62.5668,-15.4888,41.4265,59.5144,-83.7747,-97.6047,-60.7304,28.9736,-42.6681,-26.9334,79.3798,-14.6596,22.984,48.8083]
模式1示例:
输入数据(src0_gm): [-80.4933, 52.2499, -50.6124, -72.3737, -33.7107, -83.4001, 34.3954, 61.3188, 96.5484, 27.1321, -56.8153, 9.80549, 9.11199, -53.1848, -77.2548, -0.0681466, -69.5783, 6.53722, -22.5986, 37.6655, -25.4146, 89.232, 55.4716, 21.6069, 23.0464, 28.5975, -46.3033, 50.0312, -42.6339, 41.8752, -87.0426, 37.9717, 10.4336, 10.7653, -30.6943, -65.4774, 8.38653, -89.6462, 65.1115, 42.2134, -91.1666, -84.6927, -97.9312, 98.861, 19.8888, -64.0522, -27.1243, 72.7673, -9.9489, 94.274, -72.868, 43.1349, 84.1897, 87.0729, 87.2606, 34.5548, 87.7985, 4.84555, -10.2156, -93.7445, 71.8209, -63.4942, 45.6619, 93.4737, 79.6631, 66.8743, 18.1016, -27.7082, -67.0339, -11.9576, 52.0373, -11.5452, -73.8953, -63.5915, 99.4875, -46.2296, 75.453, -67.2079, 89.8868, -19.9666, 30.5359, 42.1465, -19.8105, 82.3653, -89.2164, -0.959167, -50.0723, -30.3058, 48.1772, -27.7686, 26.1484, 94.8462, -15.5663, -87.346, -84.2826, -58.8268, -42.4957, -23.7061, 67.0375, -83.5848, 64.167, 63.3315, -33.3809, 35.1264, 96.6872, 91.8399, 33.9888, 78.5923, -30.4885, 26.3331, -62.3014, -30.3431, 96.9554, 50.3262, 66.612, -71.5939, -97.0042, 71.4549, 42.3891, 71.308, 72.3209, 59.1608, -4.57918, -81.3639, -37.2619, 28.2445, 16.7995, -46.5868, -88.6893, 82.0504, -39.3391, -33.7141, -88.6628, -94.1271, -74.7738, -80.0798, -67.5639, -69.8237, -37.5219, 11.9601, -30.3912, -30.1169, 22.7411, -85.9541, 19.5141, -37.6203, -49.5693, 5.09318, 11.4884, 18.9713, 21.1073, -84.9266, 11.9436, -22.4703, -58.5243, -24.0218, -63.2767, -2.72752, -87.8947, -91.5162, 33.4207, -85.9841, 18.743, 48.9581, 69.3992, 2.42074, -75.0209, -53.2579, -45.4509, 66.6121, 51.6616, -48.073, 74.2754, -51.1623, -89.9345, 4.15238, -4.47531, 79.6587, -31.1646, 69.4103, -83.5936, -78.7341, 56.8626, 72.8834, -27.0248, -80.3328, 71.7272, -77.979, -76.6814, -14.9994, -94.5054, -75.2802, -96.4931, -17.6781, -5.50804, -83.4637, -56.8385, 51.5406, -60.527, -11.0762, -33.3166, -54.9609, -36.9426, -70.3942, 28.3439, 5.28754, -61.4775, 96.0657, -69.1967, 70.5489, 32.817, -53.5746, 49.2601, -88.5728, -1.94822, -3.16238, 19.083, -81.5139, 87.8383, 90.714, 4.75546, 31.9277, -4.1301, -0.160932, -31.2602, 56.4225, -72.1826, 91.5082, 68.2155, -81.7476, -14.0418, -79.4093, -30.7375, 38.8967, -16.5589, -69.4351, 48.6597, -8.44998, -74.4274, -20.3394, -59.2265, 19.407, 88.1542, -15.4888, 60.9066, 59.5144, -42.6935, 20.3518, 11.7192, -31.3635, 26.0055, -26.9334, 79.3798, 46.4724, 22.984, 67.5759] 输入数据(sel_Gm): [60, 32, 10, 7, 42, 52, 38, 26, 19, 15, 18, 83, 41, 43, 91, 30, 45, 77, 80, 58, 34, 76, 44, 4, 64, 45, 48, 31, 30, 56, 43, 88 ] 输出数据(dst_gm): [0,0,-50.6124,-72.3737,-33.7107,-83.4001,0,0,0,0,0,0,0,-53.1848,0,0,0,6.53722,0,37.6655,0,0,0,0,23.0464,28.5975,-46.3033,0,0,0,0,0,0,10.7653,0,-65.4774,0,-89.6462,0,0,0,0,-97.9312,0,19.8888,-64.0522,0,0,0,94.274,-72.868,0,0,87.0729,0,0,0,4.84555,0,-93.7445,71.8209,0,0,0,79.6631,66.8743,0,0,-67.0339,0,0,0,-73.8953,-63.5915,99.4875,-46.2296,0,0,0,0,0,42.1465,0,0,-89.2164,0,0,0,48.1772,-27.7686,0,0,-15.5663,0,-84.2826,0,-42.4957,0,0,-83.5848,0,63.3315,0,0,96.6872,91.8399,0,78.5923,0,26.3331,0,0,96.9554,50.3262,0,-71.5939,-97.0042,0,42.3891,0,0,59.1608,-4.57918,-81.3639,-37.2619,0,0,0,-88.6893,0,-39.3391,-33.7141,0,-94.1271,0,0,-67.5639,0,-37.5219,11.9601,0,0,22.7411,0,0,0,0,0,11.4884,0,21.1073,0,0,-22.4703,0,-24.0218,-63.2767,-2.72752,0,0,0,-85.9841,0,0,0,2.42074,0,0,0,0,51.6616,-48.073,0,0,-89.9345,0,0,0,-31.1646,69.4103,0,-78.7341,0,0,0,0,71.7272,0,0,0,0,0,0,0,0,0,0,0,-60.527,0,-33.3166,0,-36.9426,-70.3942,0,5.28754,0,0,0,0,0,0,49.2601,-88.5728,0,0,19.083,-81.5139,87.8383,90.714,4.75546,0,0,0,0,56.4225,-72.1826,91.5082,68.2155,0,0,0,0,0,0,-69.4351,48.6597,-8.44998,0,0,-59.2265,19.407,0,-15.4888,0,59.5144,0,0,0,0,0,-26.9334,79.3798,0,22.984,0]
输入数据(src0_gm): [-80.4933, 52.2499, -50.6124, -72.3737, -33.7107, -83.4001, 34.3954, 61.3188, 96.5484, 27.1321, -56.8153, 9.80549, 9.11199, -53.1848, -77.2548, -0.0681466, -69.5783, 6.53722, -22.5986, 37.6655, -25.4146, 89.232, 55.4716, 21.6069, 23.0464, 28.5975, -46.3033, 50.0312, -42.6339, 41.8752, -87.0426, 37.9717, 10.4336, 10.7653, -30.6943, -65.4774, 8.38653, -89.6462, 65.1115, 42.2134, -91.1666, -84.6927, -97.9312, 98.861, 19.8888, -64.0522, -27.1243, 72.7673, -9.9489, 94.274, -72.868, 43.1349, 84.1897, 87.0729, 87.2606, 34.5548, 87.7985, 4.84555, -10.2156, -93.7445, 71.8209, -63.4942, 45.6619, 93.4737, 79.6631, 66.8743, 18.1016, -27.7082, -67.0339, -11.9576, 52.0373, -11.5452, -73.8953, -63.5915, 99.4875, -46.2296, 75.453, -67.2079, 89.8868, -19.9666, 30.5359, 42.1465, -19.8105, 82.3653, -89.2164, -0.959167, -50.0723, -30.3058, 48.1772, -27.7686, 26.1484, 94.8462, -15.5663, -87.346, -84.2826, -58.8268, -42.4957, -23.7061, 67.0375, -83.5848, 64.167, 63.3315, -33.3809, 35.1264, 96.6872, 91.8399, 33.9888, 78.5923, -30.4885, 26.3331, -62.3014, -30.3431, 96.9554, 50.3262, 66.612, -71.5939, -97.0042, 71.4549, 42.3891, 71.308, 72.3209, 59.1608, -4.57918, -81.3639, -37.2619, 28.2445, 16.7995, -46.5868, -88.6893, 82.0504, -39.3391, -33.7141, -88.6628, -94.1271, -74.7738, -80.0798, -67.5639, -69.8237, -37.5219, 11.9601, -30.3912, -30.1169, 22.7411, -85.9541, 19.5141, -37.6203, -49.5693, 5.09318, 11.4884, 18.9713, 21.1073, -84.9266, 11.9436, -22.4703, -58.5243, -24.0218, -63.2767, -2.72752, -87.8947, -91.5162, 33.4207, -85.9841, 18.743, 48.9581, 69.3992, 2.42074, -75.0209, -53.2579, -45.4509, 66.6121, 51.6616, -48.073, 74.2754, -51.1623, -89.9345, 4.15238, -4.47531, 79.6587, -31.1646, 69.4103, -83.5936, -78.7341, 56.8626, 72.8834, -27.0248, -80.3328, 71.7272, -77.979, -76.6814, -14.9994, -94.5054, -75.2802, -96.4931, -17.6781, -5.50804, -83.4637, -56.8385, 51.5406, -60.527, -11.0762, -33.3166, -54.9609, -36.9426, -70.3942, 28.3439, 5.28754, -61.4775, 96.0657, -69.1967, 70.5489, 32.817, -53.5746, 49.2601, -88.5728, -1.94822, -3.16238, 19.083, -81.5139, 87.8383, 90.714, 4.75546, 31.9277, -4.1301, -0.160932, -31.2602, 56.4225, -72.1826, 91.5082, 68.2155, -81.7476, -14.0418, -79.4093, -30.7375, 38.8967, -16.5589, -69.4351, 48.6597, -8.44998, -74.4274, -20.3394, -59.2265, 19.407, 88.1542, -15.4888, 60.9066, 59.5144, -42.6935, 20.3518, 11.7192, -31.3635, 26.0055, -26.9334, 79.3798, 46.4724, 22.984, 67.5759] 输入数据(src1_gm): [35.8789, 44.0334, 54.9997, 44.8567, -30.8579, -53.714, -59.8013, 71.1663, 46.3484, 8.56818, -59.4716, 6.07412, -39.0137, 64.5595, 17.0849, 45.2641, -63.2115, -98.7838, -52.3835, -65.9849, 50.4909, 69.9812, -22.3447, -32.3809, -97.8394, -45.4997, 63.5391, -69.3535, 43.3368, 98.8541, -77.2888, 1.02385, 20.4965, -26.7797, 98.2463, -78.9606, -62.4907, -13.5348, -49.5058, -4.06369, 77.2982, -32.8221, 84.6766, -62.3829, 58.8673, -75.8509, -95.3497, -79.7642, 67.1185, 34.4278, 34.5305, -76.1646, 53.1497, -12.3158, -42.9392, 59.2962, -4.12072, 47.1292, -17.0687, -78.0087, -59.4565, -98.9565, -54.0959, 56.5437, -74.3328, 77.2781, 52.7964, -0.932984, 70.8957, -68.1249, 85.895, 25.4119, 71.9202, -73.1287, 63.6916, 21.4303, 66.0614, 66.1438, -22.2332, 84.0665, -7.86752, -2.38648, 1.37756, -98.691, -35.847, -15.2647, -85.2363, -54.3978, -46.6612, 99.3826, -75.7728, -31.2539, 97.9558, 92.4507, 80.2871, -60.8802, -82.0434, -80.625, 19.6418, 51.0559, 35.3667, -56.306, -41.2088, 0.955906, -85.7743, 8.18112, 36.4615, -0.572343, -16.0821, 36.0277, -4.61647, 26.5385, 88.6082, 9.17454, 44.8951, -42.173, 51.5339, 5.93139, 93.7096, -68.8219, 68.2573, -67.325, -88.4579, -56.8873, -75.8117, -40.5576, -98.378, 32.3699, 64.6693, -73.1523, -57.8738, 63.1893, -40.4731, -1.00914, -80.6115, 44.2928, 76.6212, -29.3298, -58.1212, 83.3083, -20.6412, 26.8912, -82.1719, 26.8713, -56.5484, 35.4743, -8.59957, -12.4709, 27.8249, 76.6877, -27.5806, 63.2649, 66.1106, 15.8328, 9.19251, -79.6418, -9.31359, 63.7053, -8.37093, 55.6421, -99.0591, -64.1341, 91.4046, 26.7268, -92.1002, -34.7002, -6.41819, -18.15, 12.207, 48.6667, -39.4883, -21.0939, -50.3433, 58.2913, 7.64983, -82.6098, -89.6739, -25.9494, 82.4803, 20.8037, 21.483, -29.0788, 31.7695, 50.462, -83.7715, 63.4177, 52.7679, -90.2271, 16.1258, -61.4531, -61.7242, 25.0575, -97.8702, 26.9708, -23.039, -52.7595, -97.0177, -13.1399, -47.6936, -29.7551, 88.9603, -82.1242, -56.6307, 91.7884, 0.0381027, -49.0936, -43.5545, 47.3574, 97.1801, 43.4392, 22.7347, 12.6125, 63.7829, 22.3428, 53.4543, -91.4307, 45.6971, -92.1851, -81.4774, 35.8835, -33.043, -79.7464, 69.0971, -82.6252, -63.0042, -61.0205, -8.00347, -60.0369, 56.2894, -38.1932, 17.976, 5.82004, 4.41524, -52.2192, 93.1915, 21.1114, -29.3558, -18.5685, 20.7356, -4.71108, -0.947533, 73.0143, 62.5668, 96.1632, 41.4265, -89.503, -83.7747, -97.6047, -60.7304, 28.9736, -42.6681, 55.2584, 59.1584, -14.6596, -41.1826, 48.8083] 输入数据(sel_Gm): [60, 32, 10, 7, 42, 52, 38, 26, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 19, 15, 18, 83, 41, 43, 91, 30, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 45, 77, 80, 58, 34, 76, 44, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 64, 45, 48, 31, 30, 56, 43, 88, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 输出数据(dst_gm): [35.8789, 44.0334, -50.6124, -72.3737, -33.7107, -83.4001, -59.8013, 71.1663, 46.3484, 8.56818, -59.4716, 6.07412, -39.0137, -53.1848, 17.0849, 45.2641, -63.2115, 6.53722, -52.3835, 37.6655, 50.4909, 69.9812, -22.3447, -32.3809, 23.0464, 28.5975, -46.3033, -69.3535, 43.3368, 98.8541, -77.2888, 1.02385, 20.4965, 10.7653, 98.2463, -65.4774, -62.4907, -89.6462, -49.5058, -4.06369, 77.2982, -32.8221, -97.9312, -62.3829, 19.8888, -64.0522, -95.3497, -79.7642, 67.1185, 94.274, -72.868, -76.1646, 53.1497, 87.0729, -42.9392, 59.2962, -4.12072, 4.84555, -17.0687, -93.7445, 71.8209, -98.9565, -54.0959, 56.5437, -74.3328, 77.2781, 18.1016, -27.7082, -67.0339, -11.9576, 85.895, 25.4119, 71.9202, -73.1287, 63.6916, 21.4303, 66.0614, -67.2079, -22.2332, 84.0665, -7.86752, 42.1465, 1.37756, 82.3653, -35.847, -15.2647, -85.2363, -54.3978, 48.1772, -27.7686, 26.1484, -31.2539, 97.9558, 92.4507, 80.2871, -60.8802, -82.0434, -23.7061, 19.6418, -83.5848, 35.3667, 63.3315, -41.2088, 0.955906, -85.7743, 8.18112, 33.9888, -0.572343, -30.4885, 26.3331, -4.61647, 26.5385, 88.6082, 50.3262, 66.612, -42.173, 51.5339, 71.4549, 93.7096, -68.8219, 68.2573, 59.1608, -88.4579, -81.3639, -37.2619, -40.5576, -98.378, 32.3699, 64.6693, -73.1523, -39.3391, -33.7141, -88.6628, -94.1271, -80.6115, 44.2928, 76.6212, -29.3298, -58.1212, 83.3083, -20.6412, -30.1169, -82.1719, 26.8713, -56.5484, -37.6203, -8.59957, 5.09318, 27.8249, 76.6877, -27.5806, 63.2649, 11.9436, -22.4703, -58.5243, -79.6418, -9.31359, 63.7053, -8.37093, 55.6421, -99.0591, -85.9841, 91.4046, 48.9581, -92.1002, 2.42074, -6.41819, -18.15, 12.207, 48.6667, 51.6616, -21.0939, 74.2754, -51.1623, 7.64983, -82.6098, -89.6739, 79.6587, -31.1646, 20.8037, 21.483, -78.7341, 31.7695, 50.462, -83.7715, -80.3328, 52.7679, -77.979, -76.6814, -61.4531, -61.7242, 25.0575, -97.8702, 26.9708, -5.50804, -83.4637, -56.8385, 51.5406, -47.6936, -29.7551, 88.9603, -82.1242, -56.6307, 91.7884, 0.0381027, 5.28754, -43.5545, 47.3574, 97.1801, 70.5489, 22.7347, -53.5746, 63.7829, 22.3428, 53.4543, -91.4307, 19.083, -81.5139, 87.8383, 35.8835, -33.043, -79.7464, 69.0971, -82.6252, -63.0042, 56.4225, -8.00347, 91.5082, 56.2894, -81.7476, 17.976, 5.82004, 4.41524, -52.2192, -16.5589, 21.1114, 48.6597, -8.44998, 20.7356, -4.71108, -0.947533, 19.407, 88.1542, 96.1632, 41.4265, 59.5144, -83.7747, -97.6047, -60.7304, -31.3635, -42.6681, -26.9334, 79.3798, -14.6596, -41.1826, 48.8083]
样例模板
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
#include "kernel_operator.h" class KernelSelect { public: __aicore__ inline KernelSelect() {} __aicore__ inline void Init(__gm__ uint8_t* src0Gm, __gm__ uint8_t* src1Gm, __gm__ uint8_t* selGm, __gm__ uint8_t* dstGm) { src0Global.SetGlobalBuffer((__gm__ float*)src0Gm); src1Global.SetGlobalBuffer((__gm__ float*)src1Gm); selMaskGlobal.SetGlobalBuffer((__gm__ uint8_t*)selGm); dstGlobal.SetGlobalBuffer((__gm__ float*)dstGm); pipe.InitBuffer(inQueueSrc0, 1, dataSize * sizeof(float)); pipe.InitBuffer(inQueueSrc1, 1, dataSize * sizeof(float)); pipe.InitBuffer(inQueueSelMask, 1, selDataSize * sizeof(uint8_t)); pipe.InitBuffer(outQueueDst, 1, dataSize * sizeof(float)); } __aicore__ inline void Process() { CopyIn(); Compute(); CopyOut(); } private: __aicore__ inline void CopyIn() { AscendC::LocalTensor<float> src0Local = inQueueSrc0.AllocTensor<float>(); AscendC::LocalTensor<float> src1Local = inQueueSrc1.AllocTensor<float>(); AscendC::LocalTensor<uint8_t> selMaskLocal = inQueueSelMask.AllocTensor<uint8_t>(); AscendC::DataCopy(src0Local, src0Global, dataSize); AscendC::DataCopy(src1Local, src1Global, dataSize); AscendC::DataCopy(selMaskLocal, selMaskGlobal, selDataSize); inQueueSrc0.EnQue(src0Local); inQueueSrc1.EnQue(src1Local); inQueueSelMask.EnQue(selMaskLocal); } __aicore__ inline void Compute() { AscendC::LocalTensor<float> src0Local = inQueueSrc0.DeQue<float>(); AscendC::LocalTensor<float> src1Local = inQueueSrc1.DeQue<float>(); AscendC::LocalTensor<uint8_t> maskLocal = inQueueSelMask.DeQue<uint8_t>(); AscendC::LocalTensor<float> dstLocal = outQueueDst.AllocTensor<float>(); AscendC::Select(dstLocal, maskLocal, src0Local, src1Local, AscendC::SELMODE::VSEL_CMPMASK_SPR, dataSize); outQueueDst.EnQue<float>(dstLocal); inQueueSrc0.FreeTensor(src0Local); inQueueSrc1.FreeTensor(src1Local); inQueueSelMask.FreeTensor(maskLocal); } __aicore__ inline void CopyOut() { AscendC::LocalTensor<float> dstLocal = outQueueDst.DeQue<float>(); AscendC::DataCopy(dstGlobal, dstLocal, dataSize); outQueueDst.FreeTensor(dstLocal); } private: AscendC::TPipe pipe; AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQueueSrc0, inQueueSrc1, inQueueSelMask; AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQueueDst; AscendC::GlobalTensor<float> src0Global, src1Global, dstGlobal; AscendC::GlobalTensor<uint8_t> selMaskGlobal; uint32_t dataSize = 256; uint32_t oneSelectDataSize = 256/ sizeof(float); uint32_t selDataSize = dataSize / oneSelectDataSize * 32; // (模式1和模式2时,uint32_t selDataSize = dataSize / 8;) }; extern "C" __global__ __aicore__ void main_sel_demo(__gm__ uint8_t* src0Gm, __gm__ uint8_t* src1Gm, __gm__ uint8_t* selGm, __gm__ uint8_t* dstGm) { KernelSelect op; op.Init(src0Gm, src1Gm, selGm, dstGm); op.Process(); } |