文档
注册

Select

函数功能

给定两个源操作数src0和src1,根据selMask(用于选择的Mask掩码)的比特位值选取元素,得到目的操作数dst。选择的规则为:当selMask的比特位是1时,从src0中选取,比特位是0时从src1选取。

对于tensor高维切分计算接口,支持根据mask参数对上述选取结果,再次进行过滤,有效位填入最终的dst,无效位则保持dst原始值。例如:src0为[1,2,3,4,5,6,7,8],src1为[9,10,11,12,13,14,15,16],selMask为[0,0,0,0,1,1,1,1],mask为[1,1,1,1,0,0,0,0],dst原始值为[-1,-2,-3,-4,-5,-6,-7,-8],则根据selMask的比特位选取后的结果dst_temp为:[9,10,11,12,5,6,7,8],然后再根据mask进行过滤,dst的最终输出结果为[9,10,11,12,-5,-6,-7,-8]。

本选择功能支持三种模式:

  • 模式0:根据selMask在两个tensor中选取元素,SelMask中有效数据的个数存在限制,当输入数据类型为16位时,SelMask前128比特位有效,输入数据类型为32位时,SelMask前64比特位有效,在每一轮迭代中,根据SelMask的有效位数据进行选择操作,每一轮迭代采用的selMask,均为相同数值,即SelMask的有效数值。
  • 模式1:根据selMask在1个tensor和1个scalar标量中选取元素,SelMask无有效数据限制。
  • 模式2:根据selMask在两个tensor中选取元素,SelMask无有效数据限制:例如迭代次数为2,处理的数据类型为half,第一个迭代根据SelMask第一段128bit位值选择src0或src1元素,第二个迭代根据SelMask第二段128bit位值选择src0或是src1元素。

函数原型

  • tensor前n个数据计算
    • Select模式1:

      template <typename T, typename U> __aicore__ inline void Select(const LocalTensor<T>& dstLocal, const LocalTensor<U>& selMask, const LocalTensor<T>& src0Local, T src1Local, SELMODE selMode, uint32_t calCount)

    • Select模式0和Select模式2:

      template <typename T, typename U> __aicore__ inline void Select(const LocalTensor<T>& dstLocal, const LocalTensor<U>& selMask, const LocalTensor<T>& src0Local, const LocalTensor<T>& src1Local, SELMODE selMode, uint32_t calCount)

  • tensor高维切分计算
    • Select模式1:
      • mask参数使用逐bit模式,逐bit模式的具体介绍请参考参数说明中的mask参数说明:

        template <typename T, typename U>

        __aicore__ inline void Select(const LocalTensor<T>& dstLocal, const LocalTensor<U>& selMask, const LocalTensor<T>& src0Local, T src1Local, SELMODE selMode, uint64_t mask[2], uint8_t repeatTimes, const BinaryRepeatParams& repeatParams)

      • mask参数使用连续模式,连续模式的具体介绍请参考参数说明中的mask参数说明:

        template <typename T, typename U>

        __aicore__ inline void Select(const LocalTensor<T>& dstLocal, const LocalTensor<U>& selMask, const LocalTensor<T>& src0Local, T src1Local, SELMODE selMode, uint64_t mask, uint8_t repeatTimes, const BinaryRepeatParams& repeatParams)

    • Select模式0和Select模式2:
      • mask参数使用逐bit模式:

        template <typename T, typename U>

        __aicore__ inline void Select(const LocalTensor<T>& dstLocal, const LocalTensor<U>& selMask, const LocalTensor<T>& src0Local, const LocalTensor<T>& src1Local, SELMODE selMode, uint64_t mask[2], uint8_t repeatTimes, const BinaryRepeatParams& repeatParams)

      • mask参数使用连续模式:

        template <typename T, typename U>

        __aicore__ inline void Select(const LocalTensor<T>& dstLocal, const LocalTensor<U>& selMask, const LocalTensor<T>& src0Local, const LocalTensor<T>& src1Local, SELMODE selMode, uint64_t mask, uint8_t repeatTimes, const BinaryRepeatParams& repeatParams)

参数说明

表1 参数说明

参数名称

输入/输出

含义

dstLocal

输出

目的操作数。

类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

Atlas 训练系列产品,支持的数据类型为:half/float

Atlas推理系列产品(Ascend 310P处理器)AI Core,支持的数据类型为:half/float

Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:half/float

Atlas 200/500 A2推理产品,支持的数据类型为:half/float

selMask

输入

选取mask,每个bit表示1个element的选取,当selMask的比特位是1时从src0中选取,比特位是0时从src1选取。

mode为模式1/2时,多次迭代连续消耗,支持的数据类型为:Tensor(uint8/uint16/uint32/uint64);

mode为模式0时,源操作数为half类型,selMask仅输入数据的前128bit有效,源操作数为float类型,selMask仅输入数据的前64bit有效,多次迭代,selMask的生效数值不变。

Atlas 训练系列产品,仅支持mode=0

Atlas 训练系列产品,支持的数据类型为:uint8_t/uint16_t/uint32_t/uint64_t

Atlas推理系列产品(Ascend 310P处理器)AI Core,支持的数据类型为:uint8_t/uint16_t/uint32_t/uint64_t

Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:uint8_t/uint16_t/uint32_t/uint64_t

Atlas 200/500 A2推理产品,支持的数据类型为:uint8_t/uint16_t/uint32_t/uint64_t

src0Local

输入

源操作数。

类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

Atlas 训练系列产品,支持的数据类型为:half/float

Atlas推理系列产品(Ascend 310P处理器)AI Core,支持的数据类型为:half/float

Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:half/float

Atlas 200/500 A2推理产品,支持的数据类型为:half/float

src1Local

输入

源操作数。

  • 当 selMode 为 模式0,模式2时:

    类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

  • 当 selMode 为 模式1 时,类型为 T,配置为Scalar/立即数(half/float)。

Atlas 训练系列产品,支持的数据类型为:half/float

Atlas推理系列产品(Ascend 310P处理器)AI Core,支持的数据类型为:half/float

Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:half/float

Atlas 200/500 A2推理产品,支持的数据类型为:half/float

selMode

输入

指令模式,有以下3种;

  • 模式0:取值为SELMODE::VSEL_CMPMASK_SPR。根据selMask在两个tensor中选取元素,SelMask有位数限制,不管迭代多少次,每次迭代都只根据截取后的固定位数的SelMask进行选择。输入数据为half数据类型时,选取SelMask的前128bit,输入数据为float类型时,选取SelMask的前64bit。
  • 模式1:取值为VSEL_TENSOR_SCALAR_MODE。根据selMask在1个tensor和1个scalar标量中选取元素。支持多次迭代,选取方式为,根据selMask的设置值,如果selMask比特值为1,则选择src0Local内的同位置数值,如果selMask比特值为0,则选择标量值。selMask连续存放,当输入数据类型为float时,一次比较获取selMask64bit长度的数据,当数据类型为half时,一次比较获取selMask128bit长度的数据。
  • 模式2:取值为VSEL_TENSOR_TENSOR_MODE。根据selMask在两个tensor中选取元素。支持多次迭代,选取方式为,根据selMask的设置值,如果selMask比特值为1,则选择src0Local内的同位置数值,如果selMask比特值为0,则选择src1Local内的同位置数值。selMask连续存放,当输入数据类型为float时,一次比较获取selMask64bit长度的数据,当数据类型为half时,一次比较获取selMask128bit长度的数据。

Atlas 训练系列产品,仅支持模式0。

Atlas推理系列产品(Ascend 310P处理器)AI Core,支持模式0、1、2。

Atlas A2训练系列产品/Atlas 800I A2推理产品,支持模式0、1、2。

Atlas 200/500 A2推理产品,支持模式0、1、2。

mask

输入

mask用于控制每次迭代内参与计算的元素。

  • 连续模式:表示前面连续的多少个元素参与计算。数据类型为uint64。取值范围和操作数的数据类型有关,数据类型不同,每次迭代内能够处理的元素个数最大值不同。当操作数为16位时,mask∈[1, 128];当操作数为32位时,mask∈[1, 64]。
  • 逐bit模式:可以按位控制哪些元素参与计算,bit位的值为1表示参与计算,0表示不参与。参数类型为长度为2的uint64_t类型数组。

    例如,mask=[8, 0],8=0b1000,表示仅第4个元素参与计算。

    参数取值范围和操作数的数据类型有关,数据类型不同,每次迭代内能够处理的元素个数最大值不同。当操作数为16位时,mask[0]、mask[1]∈[0, 264-1]并且不同时为0;当操作数为32位时,mask[1]为0,mask[0]∈(0, 264-1];当操作数为64位时,mask[1]为0,mask[0]∈(0, 232-1]。

repeatTimes

输入

重复迭代次数。矢量计算单元,每次读取连续的256 Bytes数据进行计算,为完成对输入数据的处理,必须通过多次迭代(repeat)才能完成所有数据的读取与计算。repeatTimes表示迭代的次数。

关于该参数的具体描述请参考通用参数说明

BinaryRepeatParams

输入

控制操作数地址步长的数据结构。结构体内包含操作数相邻迭代间相同block的地址步长,操作数同一迭代内不同block的地址步长等参数。

该数据结构的定义请参考BinaryRepeatParams

相邻迭代间的地址步长参数说明请参考Repeat stride(相邻迭代间相同datablock的地址步长);同一迭代内datablock的地址步长参数说明请参考Block stride(同一迭代内不同datablock的地址步长)

calCount

输入

输入数据元素个数。

参数取值范围和操作数的数据类型有关,数据类型不同,能够处理的元素个数最大值不同。

矢量计算单元,每个迭代读取连续256 Bytes数据进行计算,通过多次迭代完成所有数据的读取与计算。所以当操作数为16位时,calCount∈[1,128*255],255表示迭代次数的最大值,128表示每次迭代内能够处理128个16位数据;当操作数为32位时,calCount∈[1,64*255],64表示每次迭代内能够处理64个32位数据。

返回值

支持的型号

Atlas 训练系列产品

Atlas推理系列产品(Ascend 310P处理器)AI Core

Atlas A2训练系列产品/Atlas 800I A2推理产品

Atlas 200/500 A2推理产品

注意事项

  • 为了节省地址空间,开发者可以定义一个Tensor,供源操作数与目的操作数同时使用(即地址重叠),相关约束如下:
    • 使用tensor高维切分计算接口(repeatTimes=1)时或者使用tensor前n个数据计算接口时,源操作数与目的操作数之间要求100%完全重叠,不支持部分重叠。
    • 使用tensor高维切分计算接口(repeatTimes>1),若源操作数与目的操作数之间存在依赖,即第N次迭代的目的操作数是第N+1次的源操作数,这种情况是不支持地址重叠的。
  • 操作数地址偏移对齐要求请参见通用约束
  • 针对Atlas A2训练系列产品/Atlas 800I A2推理产品,对于模式1和模式2,使用时需要预留8K的Unified Buffer空间,作为接口的临时数据存放区。
  • 针对Atlas推理系列产品(Ascend 310P处理器)AI Core,对于模式1和模式2,使用时需要预留8K的Unified Buffer空间,作为接口的临时数据存放区。

调用示例

本样例中只展示Compute流程中的部分代码。如果您需要运行样例代码,请将该代码段拷贝并替换样例模板中Compute函数的部分代码即可。

  • Select-tensor高维切分计算样例(模式2)
    uint64_t mask = 256/sizeof(float);
    int repeat = 4;
    BinaryRepeatParams repeatParams = { 1, 1, 1, 8, 8, 8 };
    // repeat = 4, 64 elements one repeat, 256 elements total
    // dstBlkStride, src0BlkStride, src1BlkStride = 1, no gap between blocks in one repeat
    // dstRepStride, src0RepStride, src1RepStride = 8, no gap between repeats
    Select(dstLocal, maskLocal, src0Local, src1Local, SELMODE::VSEL_TENSOR_TENSOR_MODE, mask, repeat, repeatParams);
  • Select-tensor前n个数据计算样例(模式1)
    Select(dstLocal, maskLocal, src0Local, static_cast<float>(0), SELMODE::VSEL_TENSOR_SCALAR_MODE, dataSize);
  • Select-tensor前n个数据计算样例(模式0,每轮repeat迭代中,maskLocal只有前64bit生效)
    Select(dstLocal, maskLocal, src0Local, src1Local, SELMODE::VSEL_CMPMASK_SPR, dataSize);
  • Select-tensor高维切分计算样例-mask连续模式(模式0,每轮repeat迭代中,maskLocal只有前64bit生效)
    uint64_t mask = 256/sizeof(float);
    int repeat = 4;
    BinaryRepeatParams repeatParams = { 1, 1, 1, 8, 8, 8 };
    // repeat = 4, 64 elements one repeat, 256 elements total
    // dstBlkStride, src0BlkStride, src1BlkStride = 1, no gap between blocks in one repeat
    // dstRepStride, src0RepStride, src1RepStride = 8, no gap between repeats
    Select(dstLocal, maskLocal, src0Local, src1Local, SELMODE::VSEL_CMPMASK_SPR, mask, repeat, repeatParams);
  • Select-tensor高维切分计算样例-mask逐bit模式(每轮repeat迭代中,maskLocal只有前64bit生效)
    uint64_t mask[2] = { UINT64_MAX, 0};
    int repeat = 4;
    BinaryRepeatParams repeatParams = { 1, 1, 1, 8, 8, 8 };
    // repeat = 4, 64 elements one repeat, 256 elements total
    // srcBlkStride, = 1, no gap between blocks in one repeat
    // dstRepStride, srcRepStride = 8, no gap between repeats
    Select(dstLocal, maskLocal, src0Local, src1Local, SELMODE::VSEL_CMPMASK_SPR, mask, repeat, repeatParams);

结果示例如下:

模式2示例:

输入数据(src0_gm): 
[-80.4933, 52.2499, -50.6124, -72.3737, -33.7107, -83.4001, 34.3954, 61.3188, 96.5484, 27.1321, -56.8153, 9.80549, 9.11199, -53.1848, -77.2548, -0.0681466, -69.5783, 6.53722, -22.5986, 37.6655, -25.4146, 89.232, 55.4716, 21.6069, 23.0464, 28.5975, -46.3033, 50.0312, -42.6339, 41.8752, -87.0426, 37.9717, 10.4336, 10.7653, -30.6943, -65.4774, 8.38653, -89.6462, 65.1115, 42.2134, -91.1666, -84.6927, -97.9312, 98.861, 19.8888, -64.0522, -27.1243, 72.7673, -9.9489, 94.274, -72.868, 43.1349, 84.1897, 87.0729, 87.2606, 34.5548, 87.7985, 4.84555, -10.2156, -93.7445, 71.8209, -63.4942, 45.6619, 93.4737, 79.6631, 66.8743, 18.1016, -27.7082, -67.0339, -11.9576, 52.0373, -11.5452, -73.8953, -63.5915, 99.4875, -46.2296, 75.453, -67.2079, 89.8868, -19.9666, 30.5359, 42.1465, -19.8105, 82.3653, -89.2164, -0.959167, -50.0723, -30.3058, 48.1772, -27.7686, 26.1484, 94.8462, -15.5663, -87.346, -84.2826, -58.8268, -42.4957, -23.7061, 67.0375, -83.5848, 64.167, 63.3315, -33.3809, 35.1264, 96.6872, 91.8399, 33.9888, 78.5923, -30.4885, 26.3331, -62.3014, -30.3431, 96.9554, 50.3262, 66.612, -71.5939, -97.0042, 71.4549, 42.3891, 71.308, 72.3209, 59.1608, -4.57918, -81.3639, -37.2619, 28.2445, 16.7995, -46.5868, -88.6893, 82.0504, -39.3391, -33.7141, -88.6628, -94.1271, -74.7738, -80.0798, -67.5639, -69.8237, -37.5219, 11.9601, -30.3912, -30.1169, 22.7411, -85.9541, 19.5141, -37.6203, -49.5693, 5.09318, 11.4884, 18.9713, 21.1073, -84.9266, 11.9436, -22.4703, -58.5243, -24.0218, -63.2767, -2.72752, -87.8947, -91.5162, 33.4207, -85.9841, 18.743, 48.9581, 69.3992, 2.42074, -75.0209, -53.2579, -45.4509, 66.6121, 51.6616, -48.073, 74.2754, -51.1623, -89.9345, 4.15238, -4.47531, 79.6587, -31.1646, 69.4103, -83.5936, -78.7341, 56.8626, 72.8834, -27.0248, -80.3328, 71.7272, -77.979, -76.6814, -14.9994, -94.5054, -75.2802, -96.4931, -17.6781, -5.50804, -83.4637, -56.8385, 51.5406, -60.527, -11.0762, -33.3166, -54.9609, -36.9426, -70.3942, 28.3439, 5.28754, -61.4775, 96.0657, -69.1967, 70.5489, 32.817, -53.5746, 49.2601, -88.5728, -1.94822, -3.16238, 19.083, -81.5139, 87.8383, 90.714, 4.75546, 31.9277, -4.1301, -0.160932, -31.2602, 56.4225, -72.1826, 91.5082, 68.2155, -81.7476, -14.0418, -79.4093, -30.7375, 38.8967, -16.5589, -69.4351, 48.6597, -8.44998, -74.4274, -20.3394, -59.2265, 19.407, 88.1542, -15.4888, 60.9066, 59.5144, -42.6935, 20.3518, 11.7192, -31.3635, 26.0055, -26.9334, 79.3798, 46.4724, 22.984, 67.5759]
输入数据(src1_gm):
[35.8789, 44.0334, 54.9997, 44.8567, -30.8579, -53.714, -59.8013, 71.1663, 46.3484, 8.56818, -59.4716, 6.07412, -39.0137, 64.5595, 17.0849, 45.2641, -63.2115, -98.7838, -52.3835, -65.9849, 50.4909, 69.9812, -22.3447, -32.3809, -97.8394, -45.4997, 63.5391, -69.3535, 43.3368, 98.8541, -77.2888, 1.02385, 20.4965, -26.7797, 98.2463, -78.9606, -62.4907, -13.5348, -49.5058, -4.06369, 77.2982, -32.8221, 84.6766, -62.3829, 58.8673, -75.8509, -95.3497, -79.7642, 67.1185, 34.4278, 34.5305, -76.1646, 53.1497, -12.3158, -42.9392, 59.2962, -4.12072, 47.1292, -17.0687, -78.0087, -59.4565, -98.9565, -54.0959, 56.5437, -74.3328, 77.2781, 52.7964, -0.932984, 70.8957, -68.1249, 85.895, 25.4119, 71.9202, -73.1287, 63.6916, 21.4303, 66.0614, 66.1438, -22.2332, 84.0665, -7.86752, -2.38648, 1.37756, -98.691, -35.847, -15.2647, -85.2363, -54.3978, -46.6612, 99.3826, -75.7728, -31.2539, 97.9558, 92.4507, 80.2871, -60.8802, -82.0434, -80.625, 19.6418, 51.0559, 35.3667, -56.306, -41.2088, 0.955906, -85.7743, 8.18112, 36.4615, -0.572343, -16.0821, 36.0277, -4.61647, 26.5385, 88.6082, 9.17454, 44.8951, -42.173, 51.5339, 5.93139, 93.7096, -68.8219, 68.2573, -67.325, -88.4579, -56.8873, -75.8117, -40.5576, -98.378, 32.3699, 64.6693, -73.1523, -57.8738, 63.1893, -40.4731, -1.00914, -80.6115, 44.2928, 76.6212, -29.3298, -58.1212, 83.3083, -20.6412, 26.8912, -82.1719, 26.8713, -56.5484, 35.4743, -8.59957, -12.4709, 27.8249, 76.6877, -27.5806, 63.2649, 66.1106, 15.8328, 9.19251, -79.6418, -9.31359, 63.7053, -8.37093, 55.6421, -99.0591, -64.1341, 91.4046, 26.7268, -92.1002, -34.7002, -6.41819, -18.15, 12.207, 48.6667, -39.4883, -21.0939, -50.3433, 58.2913, 7.64983, -82.6098, -89.6739, -25.9494, 82.4803, 20.8037, 21.483, -29.0788, 31.7695, 50.462, -83.7715, 63.4177, 52.7679, -90.2271, 16.1258, -61.4531, -61.7242, 25.0575, -97.8702, 26.9708, -23.039, -52.7595, -97.0177, -13.1399, -47.6936, -29.7551, 88.9603, -82.1242, -56.6307, 91.7884, 0.0381027, -49.0936, -43.5545, 47.3574, 97.1801, 43.4392, 22.7347, 12.6125, 63.7829, 22.3428, 53.4543, -91.4307, 45.6971, -92.1851, -81.4774, 35.8835, -33.043, -79.7464, 69.0971, -82.6252, -63.0042, -61.0205, -8.00347, -60.0369, 56.2894, -38.1932, 17.976, 5.82004, 4.41524, -52.2192, 93.1915, 21.1114, -29.3558, -18.5685, 20.7356, -4.71108, -0.947533, 73.0143, 62.5668, 96.1632, 41.4265, -89.503, -83.7747, -97.6047, -60.7304, 28.9736, -42.6681, 55.2584, 59.1584, -14.6596, -41.1826, 48.8083]
输入数据(sel_Gm): 
[60, 32, 10, 7, 42, 52, 38, 26, 19, 15, 18, 83, 41, 43, 91, 30, 45, 77, 80, 58, 34, 76, 44, 4, 64, 45, 48, 31, 30, 56, 43, 88]
输出数据(dst_gm): 
[35.8789,44.0334,-50.6124,-72.3737,-33.7107,-83.4001,-59.8013,71.1663,46.3484,8.56818,-59.4716,6.07412,-39.0137,-53.1848,17.0849,45.2641,-63.2115,6.53722,-52.3835,37.6655,50.4909,69.9812,-22.3447,-32.3809,23.0464,28.5975,-46.3033,-69.3535,43.3368,98.8541,-77.2888,1.02385,20.4965,10.7653,98.2463,-65.4774,-62.4907,-89.6462,-49.5058,-4.06369,77.2982,-32.8221,-97.9312,-62.3829,19.8888,-64.0522,-95.3497,-79.7642,67.1185,94.274,-72.868,-76.1646,53.1497,87.0729,-42.9392,59.2962,-4.12072,4.84555,-17.0687,-93.7445,71.8209,-98.9565,-54.0959,56.5437,79.6631,66.8743,52.7964,-0.932984,-67.0339,-68.1249,85.895,25.4119,-73.8953,-63.5915,99.4875,-46.2296,66.0614,66.1438,-22.2332,84.0665,-7.86752,42.1465,1.37756,-98.691,-89.2164,-15.2647,-85.2363,-54.3978,48.1772,-27.7686,-75.7728,-31.2539,-15.5663,92.4507,-84.2826,-60.8802,-42.4957,-80.625,19.6418,-83.5848,35.3667,63.3315,-41.2088,0.955906,96.6872,91.8399,36.4615,78.5923,-16.0821,26.3331,-4.61647,26.5385,96.9554,50.3262,44.8951,-71.5939,-97.0042,5.93139,42.3891,-68.8219,68.2573,59.1608,-4.57918,-81.3639,-37.2619,-40.5576,-98.378,32.3699,-88.6893,-73.1523,-39.3391,-33.7141,-40.4731,-94.1271,-80.6115,44.2928,-67.5639,-29.3298,-37.5219,11.9601,-20.6412,26.8912,22.7411,26.8713,-56.5484,35.4743,-8.59957,-12.4709,11.4884,76.6877,21.1073,63.2649,66.1106,-22.4703,9.19251,-24.0218,-63.2767,-2.72752,-8.37093,55.6421,-99.0591,-85.9841,91.4046,26.7268,-92.1002,2.42074,-6.41819,-18.15,12.207,48.6667,51.6616,-48.073,-50.3433,58.2913,-89.9345,-82.6098,-89.6739,-25.9494,-31.1646,69.4103,21.483,-78.7341,31.7695,50.462,-83.7715,63.4177,71.7272,-90.2271,16.1258,-61.4531,-61.7242,25.0575,-97.8702,26.9708,-23.039,-52.7595,-97.0177,-13.1399,-60.527,-29.7551,-33.3166,-82.1242,-36.9426,-70.3942,0.0381027,5.28754,-43.5545,47.3574,97.1801,43.4392,22.7347,12.6125,49.2601,-88.5728,53.4543,-91.4307,19.083,-81.5139,87.8383,90.714,4.75546,-79.7464,69.0971,-82.6252,-63.0042,56.4225,-72.1826,91.5082,68.2155,-38.1932,17.976,5.82004,4.41524,-52.2192,93.1915,-69.4351,48.6597,-8.44998,20.7356,-4.71108,-59.2265,19.407,62.5668,-15.4888,41.4265,59.5144,-83.7747,-97.6047,-60.7304,28.9736,-42.6681,-26.9334,79.3798,-14.6596,22.984,48.8083]

模式1示例:

输入数据(src0_gm): 
[-80.4933, 52.2499, -50.6124, -72.3737, -33.7107, -83.4001, 34.3954, 61.3188, 96.5484, 27.1321, -56.8153, 9.80549, 9.11199, -53.1848, -77.2548, -0.0681466, -69.5783, 6.53722, -22.5986, 37.6655, -25.4146, 89.232, 55.4716, 21.6069, 23.0464, 28.5975, -46.3033, 50.0312, -42.6339, 41.8752, -87.0426, 37.9717, 10.4336, 10.7653, -30.6943, -65.4774, 8.38653, -89.6462, 65.1115, 42.2134, -91.1666, -84.6927, -97.9312, 98.861, 19.8888, -64.0522, -27.1243, 72.7673, -9.9489, 94.274, -72.868, 43.1349, 84.1897, 87.0729, 87.2606, 34.5548, 87.7985, 4.84555, -10.2156, -93.7445, 71.8209, -63.4942, 45.6619, 93.4737, 79.6631, 66.8743, 18.1016, -27.7082, -67.0339, -11.9576, 52.0373, -11.5452, -73.8953, -63.5915, 99.4875, -46.2296, 75.453, -67.2079, 89.8868, -19.9666, 30.5359, 42.1465, -19.8105, 82.3653, -89.2164, -0.959167, -50.0723, -30.3058, 48.1772, -27.7686, 26.1484, 94.8462, -15.5663, -87.346, -84.2826, -58.8268, -42.4957, -23.7061, 67.0375, -83.5848, 64.167, 63.3315, -33.3809, 35.1264, 96.6872, 91.8399, 33.9888, 78.5923, -30.4885, 26.3331, -62.3014, -30.3431, 96.9554, 50.3262, 66.612, -71.5939, -97.0042, 71.4549, 42.3891, 71.308, 72.3209, 59.1608, -4.57918, -81.3639, -37.2619, 28.2445, 16.7995, -46.5868, -88.6893, 82.0504, -39.3391, -33.7141, -88.6628, -94.1271, -74.7738, -80.0798, -67.5639, -69.8237, -37.5219, 11.9601, -30.3912, -30.1169, 22.7411, -85.9541, 19.5141, -37.6203, -49.5693, 5.09318, 11.4884, 18.9713, 21.1073, -84.9266, 11.9436, -22.4703, -58.5243, -24.0218, -63.2767, -2.72752, -87.8947, -91.5162, 33.4207, -85.9841, 18.743, 48.9581, 69.3992, 2.42074, -75.0209, -53.2579, -45.4509, 66.6121, 51.6616, -48.073, 74.2754, -51.1623, -89.9345, 4.15238, -4.47531, 79.6587, -31.1646, 69.4103, -83.5936, -78.7341, 56.8626, 72.8834, -27.0248, -80.3328, 71.7272, -77.979, -76.6814, -14.9994, -94.5054, -75.2802, -96.4931, -17.6781, -5.50804, -83.4637, -56.8385, 51.5406, -60.527, -11.0762, -33.3166, -54.9609, -36.9426, -70.3942, 28.3439, 5.28754, -61.4775, 96.0657, -69.1967, 70.5489, 32.817, -53.5746, 49.2601, -88.5728, -1.94822, -3.16238, 19.083, -81.5139, 87.8383, 90.714, 4.75546, 31.9277, -4.1301, -0.160932, -31.2602, 56.4225, -72.1826, 91.5082, 68.2155, -81.7476, -14.0418, -79.4093, -30.7375, 38.8967, -16.5589, -69.4351, 48.6597, -8.44998, -74.4274, -20.3394, -59.2265, 19.407, 88.1542, -15.4888, 60.9066, 59.5144, -42.6935, 20.3518, 11.7192, -31.3635, 26.0055, -26.9334, 79.3798, 46.4724, 22.984, 67.5759]
输入数据(sel_Gm): 
[60, 32, 10, 7, 42, 52, 38, 26, 19, 15, 18, 83, 41, 43, 91, 30, 45, 77, 80, 58, 34, 76, 44, 4, 64, 45, 48, 31, 30, 56, 43, 88 ]
输出数据(dst_gm): 
[0,0,-50.6124,-72.3737,-33.7107,-83.4001,0,0,0,0,0,0,0,-53.1848,0,0,0,6.53722,0,37.6655,0,0,0,0,23.0464,28.5975,-46.3033,0,0,0,0,0,0,10.7653,0,-65.4774,0,-89.6462,0,0,0,0,-97.9312,0,19.8888,-64.0522,0,0,0,94.274,-72.868,0,0,87.0729,0,0,0,4.84555,0,-93.7445,71.8209,0,0,0,79.6631,66.8743,0,0,-67.0339,0,0,0,-73.8953,-63.5915,99.4875,-46.2296,0,0,0,0,0,42.1465,0,0,-89.2164,0,0,0,48.1772,-27.7686,0,0,-15.5663,0,-84.2826,0,-42.4957,0,0,-83.5848,0,63.3315,0,0,96.6872,91.8399,0,78.5923,0,26.3331,0,0,96.9554,50.3262,0,-71.5939,-97.0042,0,42.3891,0,0,59.1608,-4.57918,-81.3639,-37.2619,0,0,0,-88.6893,0,-39.3391,-33.7141,0,-94.1271,0,0,-67.5639,0,-37.5219,11.9601,0,0,22.7411,0,0,0,0,0,11.4884,0,21.1073,0,0,-22.4703,0,-24.0218,-63.2767,-2.72752,0,0,0,-85.9841,0,0,0,2.42074,0,0,0,0,51.6616,-48.073,0,0,-89.9345,0,0,0,-31.1646,69.4103,0,-78.7341,0,0,0,0,71.7272,0,0,0,0,0,0,0,0,0,0,0,-60.527,0,-33.3166,0,-36.9426,-70.3942,0,5.28754,0,0,0,0,0,0,49.2601,-88.5728,0,0,19.083,-81.5139,87.8383,90.714,4.75546,0,0,0,0,56.4225,-72.1826,91.5082,68.2155,0,0,0,0,0,0,-69.4351,48.6597,-8.44998,0,0,-59.2265,19.407,0,-15.4888,0,59.5144,0,0,0,0,0,-26.9334,79.3798,0,22.984,0]
模式0示例:
输入数据(src0_gm): 
[-80.4933, 52.2499, -50.6124, -72.3737, -33.7107, -83.4001, 34.3954, 61.3188, 96.5484, 27.1321, -56.8153, 9.80549, 9.11199, -53.1848, -77.2548, -0.0681466, -69.5783, 6.53722, -22.5986, 37.6655, -25.4146, 89.232, 55.4716, 21.6069, 23.0464, 28.5975, -46.3033, 50.0312, -42.6339, 41.8752, -87.0426, 37.9717, 10.4336, 10.7653, -30.6943, -65.4774, 8.38653, -89.6462, 65.1115, 42.2134, -91.1666, -84.6927, -97.9312, 98.861, 19.8888, -64.0522, -27.1243, 72.7673, -9.9489, 94.274, -72.868, 43.1349, 84.1897, 87.0729, 87.2606, 34.5548, 87.7985, 4.84555, -10.2156, -93.7445, 71.8209, -63.4942, 45.6619, 93.4737, 79.6631, 66.8743, 18.1016, -27.7082, -67.0339, -11.9576, 52.0373, -11.5452, -73.8953, -63.5915, 99.4875, -46.2296, 75.453, -67.2079, 89.8868, -19.9666, 30.5359, 42.1465, -19.8105, 82.3653, -89.2164, -0.959167, -50.0723, -30.3058, 48.1772, -27.7686, 26.1484, 94.8462, -15.5663, -87.346, -84.2826, -58.8268, -42.4957, -23.7061, 67.0375, -83.5848, 64.167, 63.3315, -33.3809, 35.1264, 96.6872, 91.8399, 33.9888, 78.5923, -30.4885, 26.3331, -62.3014, -30.3431, 96.9554, 50.3262, 66.612, -71.5939, -97.0042, 71.4549, 42.3891, 71.308, 72.3209, 59.1608, -4.57918, -81.3639, -37.2619, 28.2445, 16.7995, -46.5868, -88.6893, 82.0504, -39.3391, -33.7141, -88.6628, -94.1271, -74.7738, -80.0798, -67.5639, -69.8237, -37.5219, 11.9601, -30.3912, -30.1169, 22.7411, -85.9541, 19.5141, -37.6203, -49.5693, 5.09318, 11.4884, 18.9713, 21.1073, -84.9266, 11.9436, -22.4703, -58.5243, -24.0218, -63.2767, -2.72752, -87.8947, -91.5162, 33.4207, -85.9841, 18.743, 48.9581, 69.3992, 2.42074, -75.0209, -53.2579, -45.4509, 66.6121, 51.6616, -48.073, 74.2754, -51.1623, -89.9345, 4.15238, -4.47531, 79.6587, -31.1646, 69.4103, -83.5936, -78.7341, 56.8626, 72.8834, -27.0248, -80.3328, 71.7272, -77.979, -76.6814, -14.9994, -94.5054, -75.2802, -96.4931, -17.6781, -5.50804, -83.4637, -56.8385, 51.5406, -60.527, -11.0762, -33.3166, -54.9609, -36.9426, -70.3942, 28.3439, 5.28754, -61.4775, 96.0657, -69.1967, 70.5489, 32.817, -53.5746, 49.2601, -88.5728, -1.94822, -3.16238, 19.083, -81.5139, 87.8383, 90.714, 4.75546, 31.9277, -4.1301, -0.160932, -31.2602, 56.4225, -72.1826, 91.5082, 68.2155, -81.7476, -14.0418, -79.4093, -30.7375, 38.8967, -16.5589, -69.4351, 48.6597, -8.44998, -74.4274, -20.3394, -59.2265, 19.407, 88.1542, -15.4888, 60.9066, 59.5144, -42.6935, 20.3518, 11.7192, -31.3635, 26.0055, -26.9334, 79.3798, 46.4724, 22.984, 67.5759]
输入数据(src1_gm):
[35.8789, 44.0334, 54.9997, 44.8567, -30.8579, -53.714, -59.8013, 71.1663, 46.3484, 8.56818, -59.4716, 6.07412, -39.0137, 64.5595, 17.0849, 45.2641, -63.2115, -98.7838, -52.3835, -65.9849, 50.4909, 69.9812, -22.3447, -32.3809, -97.8394, -45.4997, 63.5391, -69.3535, 43.3368, 98.8541, -77.2888, 1.02385, 20.4965, -26.7797, 98.2463, -78.9606, -62.4907, -13.5348, -49.5058, -4.06369, 77.2982, -32.8221, 84.6766, -62.3829, 58.8673, -75.8509, -95.3497, -79.7642, 67.1185, 34.4278, 34.5305, -76.1646, 53.1497, -12.3158, -42.9392, 59.2962, -4.12072, 47.1292, -17.0687, -78.0087, -59.4565, -98.9565, -54.0959, 56.5437, -74.3328, 77.2781, 52.7964, -0.932984, 70.8957, -68.1249, 85.895, 25.4119, 71.9202, -73.1287, 63.6916, 21.4303, 66.0614, 66.1438, -22.2332, 84.0665, -7.86752, -2.38648, 1.37756, -98.691, -35.847, -15.2647, -85.2363, -54.3978, -46.6612, 99.3826, -75.7728, -31.2539, 97.9558, 92.4507, 80.2871, -60.8802, -82.0434, -80.625, 19.6418, 51.0559, 35.3667, -56.306, -41.2088, 0.955906, -85.7743, 8.18112, 36.4615, -0.572343, -16.0821, 36.0277, -4.61647, 26.5385, 88.6082, 9.17454, 44.8951, -42.173, 51.5339, 5.93139, 93.7096, -68.8219, 68.2573, -67.325, -88.4579, -56.8873, -75.8117, -40.5576, -98.378, 32.3699, 64.6693, -73.1523, -57.8738, 63.1893, -40.4731, -1.00914, -80.6115, 44.2928, 76.6212, -29.3298, -58.1212, 83.3083, -20.6412, 26.8912, -82.1719, 26.8713, -56.5484, 35.4743, -8.59957, -12.4709, 27.8249, 76.6877, -27.5806, 63.2649, 66.1106, 15.8328, 9.19251, -79.6418, -9.31359, 63.7053, -8.37093, 55.6421, -99.0591, -64.1341, 91.4046, 26.7268, -92.1002, -34.7002, -6.41819, -18.15, 12.207, 48.6667, -39.4883, -21.0939, -50.3433, 58.2913, 7.64983, -82.6098, -89.6739, -25.9494, 82.4803, 20.8037, 21.483, -29.0788, 31.7695, 50.462, -83.7715, 63.4177, 52.7679, -90.2271, 16.1258, -61.4531, -61.7242, 25.0575, -97.8702, 26.9708, -23.039, -52.7595, -97.0177, -13.1399, -47.6936, -29.7551, 88.9603, -82.1242, -56.6307, 91.7884, 0.0381027, -49.0936, -43.5545, 47.3574, 97.1801, 43.4392, 22.7347, 12.6125, 63.7829, 22.3428, 53.4543, -91.4307, 45.6971, -92.1851, -81.4774, 35.8835, -33.043, -79.7464, 69.0971, -82.6252, -63.0042, -61.0205, -8.00347, -60.0369, 56.2894, -38.1932, 17.976, 5.82004, 4.41524, -52.2192, 93.1915, 21.1114, -29.3558, -18.5685, 20.7356, -4.71108, -0.947533, 73.0143, 62.5668, 96.1632, 41.4265, -89.503, -83.7747, -97.6047, -60.7304, 28.9736, -42.6681, 55.2584, 59.1584, -14.6596, -41.1826, 48.8083]
输入数据(sel_Gm): 
[60, 32, 10, 7, 42, 52, 38, 26, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 19, 15, 18, 83, 41, 43, 91, 30, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 45, 77, 80, 58, 34, 76, 44, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 64, 45, 48, 31, 30, 56, 43, 88, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
输出数据(dst_gm): 
[35.8789, 44.0334, -50.6124, -72.3737, -33.7107, -83.4001, -59.8013, 71.1663, 46.3484, 8.56818, -59.4716, 6.07412, -39.0137, -53.1848, 17.0849, 45.2641, -63.2115, 6.53722, -52.3835, 37.6655, 50.4909, 69.9812, -22.3447, -32.3809, 23.0464, 28.5975, -46.3033, -69.3535, 43.3368, 98.8541, -77.2888, 1.02385, 20.4965, 10.7653, 98.2463, -65.4774, -62.4907, -89.6462, -49.5058, -4.06369, 77.2982, -32.8221, -97.9312, -62.3829, 19.8888, -64.0522, -95.3497, -79.7642, 67.1185, 94.274, -72.868, -76.1646, 53.1497, 87.0729, -42.9392, 59.2962, -4.12072, 4.84555, -17.0687, -93.7445, 71.8209, -98.9565, -54.0959, 56.5437, -74.3328, 77.2781, 18.1016, -27.7082, -67.0339, -11.9576, 85.895, 25.4119, 71.9202, -73.1287, 63.6916, 21.4303, 66.0614, -67.2079, -22.2332, 84.0665, -7.86752, 42.1465, 1.37756, 82.3653, -35.847, -15.2647, -85.2363, -54.3978, 48.1772, -27.7686, 26.1484, -31.2539, 97.9558, 92.4507, 80.2871, -60.8802, -82.0434, -23.7061, 19.6418, -83.5848, 35.3667, 63.3315, -41.2088, 0.955906, -85.7743, 8.18112, 33.9888, -0.572343, -30.4885, 26.3331, -4.61647, 26.5385, 88.6082, 50.3262, 66.612, -42.173, 51.5339, 71.4549, 93.7096, -68.8219, 68.2573, 59.1608, -88.4579, -81.3639, -37.2619, -40.5576, -98.378, 32.3699, 64.6693, -73.1523, -39.3391, -33.7141, -88.6628, -94.1271, -80.6115, 44.2928, 76.6212, -29.3298, -58.1212, 83.3083, -20.6412, -30.1169, -82.1719, 26.8713, -56.5484, -37.6203, -8.59957, 5.09318, 27.8249, 76.6877, -27.5806, 63.2649, 11.9436, -22.4703, -58.5243, -79.6418, -9.31359, 63.7053, -8.37093, 55.6421, -99.0591, -85.9841, 91.4046, 48.9581, -92.1002, 2.42074, -6.41819, -18.15, 12.207, 48.6667, 51.6616, -21.0939, 74.2754, -51.1623, 7.64983, -82.6098, -89.6739, 79.6587, -31.1646, 20.8037, 21.483, -78.7341, 31.7695, 50.462, -83.7715, -80.3328, 52.7679, -77.979, -76.6814, -61.4531, -61.7242, 25.0575, -97.8702, 26.9708, -5.50804, -83.4637, -56.8385, 51.5406, -47.6936, -29.7551, 88.9603, -82.1242, -56.6307, 91.7884, 0.0381027, 5.28754, -43.5545, 47.3574, 97.1801, 70.5489, 22.7347, -53.5746, 63.7829, 22.3428, 53.4543, -91.4307, 19.083, -81.5139, 87.8383, 35.8835, -33.043, -79.7464, 69.0971, -82.6252, -63.0042, 56.4225, -8.00347, 91.5082, 56.2894, -81.7476, 17.976, 5.82004, 4.41524, -52.2192, -16.5589, 21.1114, 48.6597, -8.44998, 20.7356, -4.71108, -0.947533, 19.407, 88.1542, 96.1632, 41.4265, 59.5144, -83.7747, -97.6047, -60.7304, -31.3635, -42.6681, -26.9334, 79.3798, -14.6596, -41.1826, 48.8083]

样例模板

#include "kernel_operator.h"
namespace AscendC {
class KernelSelect {
public:
    __aicore__ inline KernelSelect() {}
    __aicore__ inline void Init(__gm__ uint8_t* src0Gm, __gm__ uint8_t* src1Gm, __gm__ uint8_t* selGm,
        __gm__ uint8_t* dstGm)
    {
        src0Global.SetGlobalBuffer((__gm__ float*)src0Gm);
        src1Global.SetGlobalBuffer((__gm__ float*)src1Gm);
        selMaskGlobal.SetGlobalBuffer((__gm__ uint8_t*)selGm);
        dstGlobal.SetGlobalBuffer((__gm__ float*)dstGm);
        pipe.InitBuffer(inQueueSrc0, 1, dataSize * sizeof(float));
        pipe.InitBuffer(inQueueSrc1, 1, dataSize * sizeof(float));
        pipe.InitBuffer(inQueueSelMask, 1, selDataSize * sizeof(uint8_t));
        pipe.InitBuffer(outQueueDst, 1, dataSize * sizeof(float));
    }
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }

private:
    __aicore__ inline void CopyIn()
    {
        LocalTensor<float> src0Local = inQueueSrc0.AllocTensor<float>();
        LocalTensor<float> src1Local = inQueueSrc1.AllocTensor<float>();
        LocalTensor<uint8_t> selMaskLocal = inQueueSelMask.AllocTensor<uint8_t>();
        DataCopy(src0Local, src0Global, dataSize);
        DataCopy(src1Local, src1Global, dataSize);
        DataCopy(selMaskLocal, selMaskGlobal, selDataSize);
        inQueueSrc0.EnQue(src0Local);
        inQueueSrc1.EnQue(src1Local);
        inQueueSelMask.EnQue(selMaskLocal);
    }
    __aicore__ inline void Compute()
    {
        LocalTensor<float> src0Local = inQueueSrc0.DeQue<float>();
        LocalTensor<float> src1Local = inQueueSrc1.DeQue<float>();
        LocalTensor<uint8_t> maskLocal = inQueueSelMask.DeQue<uint8_t>();
        LocalTensor<float> dstLocal = outQueueDst.AllocTensor<float>();

        Select(dstLocal, maskLocal, src0Local, src1Local, SELMODE::VSEL_CMPMASK_SPR, dataSize);

        outQueueDst.EnQue<float>(dstLocal);
        inQueueSrc0.FreeTensor(src0Local);
        inQueueSrc1.FreeTensor(src1Local);
        inQueueSelMask.FreeTensor(maskLocal);
    }
    __aicore__ inline void CopyOut()
    {
        LocalTensor<float> dstLocal = outQueueDst.DeQue<float>();
        DataCopy(dstGlobal, dstLocal, dataSize);
        outQueueDst.FreeTensor(dstLocal);
    }

private:
    TPipe pipe;
    TQue<QuePosition::VECIN, 1> inQueueSrc0, inQueueSrc1, inQueueSelMask;
    TQue<QuePosition::VECOUT, 1> outQueueDst;
    GlobalTensor<float> src0Global, src1Global, dstGlobal;
    GlobalTensor<uint8_t> selMaskGlobal;
    uint32_t dataSize = 256;
    uint32_t oneSelectDataSize = 256/ sizeof(float);
    uint32_t selDataSize = dataSize / oneSelectDataSize * 32;
    (模式1和模式2时,uint32_t selDataSize = dataSize / 8;)
};
} // namespace AscendC
extern "C" __global__ __aicore__ void main_sel_demo(__gm__ uint8_t* src0Gm, __gm__ uint8_t* src1Gm,
    __gm__ uint8_t* selGm, __gm__ uint8_t* dstGm)
{
    AscendC::KerneSelect op;
    op.Init(src0Gm, src1Gm, selGm, dstGm);
    op.Process();
}
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词