下载
中文
注册

FusedMulAddRelu

功能说明

按元素将src0Local和dstLocal相乘并再加上src1Local,将结果和0作比较,取较大值,最终结果存放进dstLocal中。计算公式如下,其中PAR表示矢量计算单元一个迭代能够处理的元素个数:

函数原型

  • tensor前n个数据计算

    template <typename T> __aicore__ inline void FusedMulAddRelu(const LocalTensor<T>& dstLocal, const LocalTensor<T>& src0Local, const LocalTensor<T>& src1Local, const int32_t& calCount);

  • tensor高维切分计算
    • mask逐bit模式

      template <typename T, bool isSetMask = true> __aicore__ inline void FusedMulAddRelu(const LocalTensor<T>& dstLocal, const LocalTensor<T>& src0Local, const LocalTensor<T>& src1Local, uint64_t mask[2], const uint8_t repeatTimes, const BinaryRepeatParams& repeatParams)

    • mask连续模式

      template <typename T, bool isSetMask = true> __aicore__ inline void FusedMulAddRelu(const LocalTensor<T>& dstLocal, const LocalTensor<T>& src0Local, const LocalTensor<T>& src1Local, uint64_t mask, const uint8_t repeatTimes, const BinaryRepeatParams& repeatParams);

参数说明

表1 模板参数说明

参数名

描述

T

操作数数据类型。

isSetMask

是否在接口内部设置mask。

  • true,表示在接口内部设置mask。
  • false,表示在接口外部设置mask,开发者需要使用SetVectorMask接口设置mask值。这种模式下,本接口入参中的mask值必须设置为MASK_PLACEHOLDER。
表2 参数说明

参数名

输入/输出

描述

dstLocal

输出

目的操作数。

类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:half/float

Atlas推理系列产品AI Core,支持的数据类型为:half/float/

src0Localsrc1Local

输入

源操作数。

类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

两个源操作数的数据类型需要与目的操作数保持一致。

Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:half/float

Atlas推理系列产品AI Core,支持的数据类型为:half/float/

calCount

输入

输入数据元素个数。

mask

输入

mask用于控制每次迭代内参与计算的元素。该参数的详细讲解请参考mask参数

  • 连续模式:表示前面连续的多少个元素参与计算。数据类型为uint64。取值范围和操作数的数据类型有关,数据类型不同,每次迭代内能够处理的元素个数最大值不同。当操作数为16位时,mask∈[1, 128];当操作数为32位时,mask∈[1, 64]。
  • 逐bit模式:可以按位控制哪些元素参与计算,bit位的值为1表示参与计算,0表示不参与。参数类型为长度为2的uint64_t类型数组。

    例如,mask=[8, 0],8=0b1000,表示仅第4个元素参与计算。

    参数取值范围和操作数的数据类型有关,数据类型不同,每次迭代内能够处理的元素个数最大值不同。当操作数为16位时,mask[0]、mask[1]∈[0, 264-1];当dst/src为32位时,mask[1]为0,mask[0]∈[0, 264-1]

repeatTimes

输入

重复迭代次数。矢量计算单元,每次读取连续的256 Bytes数据进行计算,为完成对输入数据的处理,必须通过多次迭代(repeat)才能完成所有数据的读取与计算。repeatTimes表示迭代的次数。该参数的详细讲解请参考基础API通用说明

repeatParams

输入

控制操作数地址步长的数据结构。结构体内包含操作数相邻迭代间相同datablock的地址步长,操作数同一迭代内不同datablock的地址步长等参数。

该数据结构的定义请参考BinaryRepeatParams

相邻迭代间相同datablock的地址步长参数的详细说明请参考Repeat stride(相邻迭代间相同datablock的地址步长);同一迭代内不同datablock的地址步长参数请参考Block stride(同一迭代内不同datablock的地址步长)

返回值

支持的型号

Atlas A2训练系列产品/Atlas 800I A2推理产品

Atlas推理系列产品AI Core

约束说明

  • 使用高维切分计算接口时,节省地址空间,开发者可以定义一个Tensor,供源操作数与目的操作数同时使用(即地址重叠),相关约束如下:
    • 单次迭代内,要求源操作数和目的操作数之间100%重叠,不支持部分重叠。
    • 多次迭代间,第N次目的操作数是第N+1次源操作数的情况下,是不支持地址重叠的,因为第N+1次依赖第N次的结果。

调用示例

本样例中只展示Compute流程中的部分代码。如果您需要运行样例代码,请将该代码段拷贝并替换双目指令样例模板更多样例中的Compute函数即可。

  • 高维切分计算接口样例-mask连续模式
    uint64_t mask = 128;
    // repeatTimes = 2, 一次迭代计算128个数, 共计算256个数
    // dstBlkStride, src0BlkStride, src1BlkStride = 1, 单次迭代内数据连续读取和写入
    // dstRepStride, src0RepStride, src1RepStride = 8, 相邻迭代间数据连续读取和写入
    FusedMulAddRelu(dstLocal, src0Local, src1Local, mask, 2, { 1, 1, 1, 8, 8, 8 });
  • 高维切分计算接口样例-mask逐bit模式
    uint64_t mask[2] = { UINT64_MAX, UINT64_MAX };
    // repeatTimes = 2, 一次迭代计算128个数, 共计算256个数
    // dstBlkStride, src0BlkStride, src1BlkStride = 1, 单次迭代内数据连续读取和写入
    // dstRepStride, src0RepStride, src1RepStride = 8, 相邻迭代间数据连续读取和写入
    FusedMulAddRelu(dstLocal, src0Local, src1Local, mask, 2, { 1, 1, 1, 8, 8, 8 });
  • tensor前n个数据计算样例
    FusedMulAddRelu(dstLocal, src0Local, src1Local, 256);
结果示例如下:
输入数据(src0Local): 
[ 77.25   59.84  -39.5    19.3   -45.88   11.81  -60.75  -99.6   -12.164
  25.75  100.    -68.    -48.12   -4.504 -78.5   -53.4    71.56  -33.5
 -78.75   81.6    61.25   -6.414 -31.77  -50.38   84.9   -75.56   89.06
 -50.7    93.     99.75    8.1    58.34  -47.3    40.78  -98.94   41.34
  98.5   -21.11   30.7    39.3   -82.06  -34.88  -51.44  -91.7   -33.12
  71.2   -66.3   -60.2    52.84  -75.9   -18.31  -85.3    92.5    62.22
  92.2    87.44  -77.56   26.64   31.47  -60.97   85.8   -48.47   74.2
  42.62  -49.8    77.56  -15.     47.5   -26.66  -97.1    -3.28   99.25
 -95.4    72.    -16.97  -50.12  -68.5    75.1    63.2   -34.72   79.56
  97.1   -60.47  -99.3   -59.5   -54.2   -23.62  -10.74   74.8   -18.3
   3.81  -72.2    15.21   31.55    9.5    22.08  -88.     77.5     7.3
  68.44  -40.7   -27.53   39.03   -6.84   43.28    4.81   26.75   -8.59
  95.44  -40.22   98.94   74.4   -84.8    -8.72   42.44   63.1    14.14
 -93.25   22.77  -39.1   -18.75  -26.22  -56.62   -8.43   25.14   43.66
  40.     63.44   87.94  -94.75   36.44   24.17  -91.2    33.66  -67.1
  52.6    81.     32.88   69.25   23.27   30.23   82.56   49.3   -59.78
  81.2    28.94   45.03  -49.22  -99.2    -6.19    1.262 -94.3    83.6
 -17.7   -61.     47.4   -50.62    1.013  32.97   87.7    -5.93  -13.36
  20.45  -18.45   -9.2    33.03  -50.7   -89.1   -93.4    10.3   -60.72
  44.84   12.664  39.1   -81.9   -37.1     5.215 -91.9    85.44   76.9
 -70.5   -39.22   59.8   -11.78    4.875 -98.     83.94  -31.03  -58.84
 -38.56   39.44   97.    -53.     97.75  -69.25   54.4    27.03  -42.66
  24.94  -56.97   -7.105  71.3    15.28  -39.84   79.7    51.28   76.25
   7.203  77.56  -15.26  -74.06   59.38   64.5   -63.3   -34.38  -24.98
 -54.03  -23.53   -9.875  46.12   34.94  -67.9    96.25   65.3   -88.06
  66.5    65.6   -76.75   77.2    29.73   63.28   41.25  -20.84   34.6
  37.78  -72.44   58.25    9.4   -50.56   29.28   28.2     4.438 -10.35
  79.06   62.75  -11.51  -15.16   -4.055 -10.016   4.887  94.1   -50.25
  15.055  18.45    8.78   28.25 ]
输入数据(src1Local): 
[-20.95    -1.547  -46.25    27.14    43.16    21.44    65.6     78.25
 -72.9     25.39     2.133   76.3     17.56   -96.9    -41.16   -10.734
  -2.305   74.5    -42.2    -63.44   -67.5    -11.62   -78.44    52.66
  18.8     45.16   -35.94    82.44   -53.9     10.54   -86.9    -55.9
   7.22    85.3     14.32   -57.53    52.9    -92.4     45.      31.64
  40.56    42.06    70.25    -1.671  -24.27    74.9     34.56    96.44
  19.27    -9.76    96.1    -85.8     26.34    83.1    -66.8     63.03
  29.69    39.22   -55.      53.2    -68.75     0.5854  57.9     -8.53
 -39.62    36.06    66.     -34.56    50.4     59.16   -66.1     54.06
  60.1    -30.06    88.3    -72.06   -22.34   -52.34    39.4    -59.28
  10.266  -54.7     59.44    14.49   -86.94    98.7     10.25   -75.75
 -70.06     6.758  -43.66    17.42   -81.3     23.22    33.5     50.72
  -5.547   44.78    71.2     28.7     92.06   -68.3    -89.     -69.2
  28.95   -63.72    49.53    15.29    23.06    89.44    87.4    -87.3
  91.44    80.3     22.9     -1.456  -69.94   -99.3     40.47    11.984
 -28.7    -32.22   -84.     -17.67    70.3    -79.     -83.6     42.56
 -67.     -46.53   -84.4     -0.7954 -15.34   -71.7    -81.25    24.62
 -63.7     81.8     85.44    73.2    -99.1     15.29   -85.6     78.5
 -54.12    47.03    46.38   -31.78    58.84   -21.9     83.6     60.88
 -53.4    -52.3    -47.3     20.2     -8.055  -67.3    -52.16   -20.8
  11.1     36.8    -22.06   -57.53   -42.34   -59.66    -0.651  -34.25
 -19.33   -80.5    -33.62    28.31   -89.7    -41.5     -8.9    -58.44
 -54.16   -87.4    -63.66     6.664   -7.562   95.06    85.3     82.25
  27.75   -25.39    88.06   -94.56    86.8     68.     -26.05     8.14
 -42.94   -35.      29.12   -64.9     38.7     29.42    15.31   -28.3
 -56.47    79.44   -96.8    -64.8     81.5    -69.3    -23.36   -95.75
 -56.62    77.75   -23.1     67.94   -29.03   -14.69     0.5996 -17.05
 -73.75    95.6     79.25   -61.7      5.363   88.1    -15.086   58.66
 -20.7     90.9     42.7    -51.22     0.3909  77.1    -69.3     70.75
  89.7     95.1     56.1    -72.4    -99.9    -11.64     1.984  -86.6
 -22.22   -24.44    12.5    -46.66    77.7     46.28   -48.3    -19.03
 -16.61    64.44    -4.758  -20.9     80.2     67.      12.9    -71.75  ]
输入数据(dstLocal): 
[-69.5     -8.8    -38.3     19.9    -70.      71.75   -23.17   -13.48
  70.     -31.16    29.6     -2.914   47.6     83.5    -33.75    44.25
 -87.9      3.54    46.1     53.8     -4.164  -83.56   -35.6     58.47
 -43.56   -38.03    63.72     4.9    -25.56    73.7     34.7     96.3
 -58.34   -56.78   -39.44    82.5     34.4    -30.34    -4.72   -97.75
   2.406  -83.1     17.06    41.34   -45.84   -35.6     19.58    27.45
   4.258  -95.25    -7.418  -10.34    99.06   -84.3     -2.508  -14.805
  44.72    27.14    64.44   -16.69   -48.28    73.94    72.2     71.75
 -72.7    -43.72   -65.3    -17.75    16.02    35.     -57.6     87.75
  91.6    -44.53   -82.     -54.44   -23.61   -71.5     86.75   -83.5
  39.16   -60.2     51.22   -63.2     61.8     53.9    -13.016   78.4
 -55.1    -59.94   -88.25   -12.35    83.4     94.44   -38.28   -59.38
 -98.06    54.94   -11.664  -22.66   -94.3     65.06    89.56   -74.44
  17.98    29.78   -98.06    95.6     -4.26   -65.5      2.662  -76.9
  26.03   -22.84    40.6    -78.6     58.56   -25.34   -73.9     -2.59
 -86.9     52.75    69.7     48.66    -8.95   -57.84   -87.3    -43.84
 -26.36    72.3      0.1203  57.88    96.56    43.6     94.44    76.1
 -58.47    82.1    -46.03    82.2    -67.94   -80.75    91.4    -46.4
  77.56   -39.25    92.25   -83.8    -17.25   -42.72    -3.533   18.4
 -17.6     80.94    64.      72.6     53.66   -19.17   -34.88    73.94
 -80.1    -65.8     -8.6     62.     -64.9    -31.45   -25.12   -98.1
  75.75    -6.44    66.75     4.11   -62.22   -74.56    17.61    85.4
  -0.5825 -39.28    -0.2615 -88.5     67.1    -78.25    -7.7    -82.75
  98.1     62.12    73.9    -98.8    -78.6    -88.44    49.97    98.6
 -14.484  -18.44   -15.57    87.8    -46.56   -71.75    61.7     16.42
 -71.56    91.4    -35.03   -61.78   -10.39   -44.1     96.44    92.7
   5.785  -59.75   -57.22   -63.1    -71.6     87.06     3.691  -69.44
 -15.84   -73.56     4.08   -27.5     41.7     88.6     21.4    -73.3
  91.9     29.28   100.     -19.3    -52.88    28.23   -42.72    68.25
  60.9    -48.4    -31.86    74.25   -40.8    -21.61    78.3    -90.56
 -56.47    -9.51    80.8    -32.28    11.18    34.34    65.9     98.8
  24.83    82.9    -92.6    -67.06    51.9     -8.68   -55.03   -75.    ]
输出数据(dstLocal): 
[   0.      0.   1467.    411.5  3256.    869.   1474.   1421.      0.
    0.   2962.    274.5     0.      0.   2608.      0.      0.      0.
    0.   4328.      0.    524.5  1053.      0.      0.   2920.   5640.
    0.      0.   7364.    194.1  5564.   2768.      0.   3916.   3352.
 3442.    548.      0.      0.      0.   2940.      0.      0.   1495.
    0.      0.      0.    244.2  7220.    232.    796.5  9184.      0.
    0.      0.      0.    762.   1973.   1071.      0.      0.   5412.
 3050.   3580.      0.   1046.      0.      0.      0.    122.75 8768.
    0.      0.   1479.   2656.   1595.      0.   5520.   2840.   3126.
    0.      0.   6292.      0.      0.    317.8     0.      0.   1104.
    0.    909.   1187.   3004.      0.      0.   8624.   4300.      0.
    0.   3930.      0.   3408.    440.    807.5    79.5     0.      0.
    0.   2724.    351.      0.      0.    279.5  1746.      0.    758.
 2264.      0.    113.25 1600.      0.      0.      0.      0.      0.
    0.      0.      0.      0.      0.   1398.      0.   1395.      0.
 4028.      0.   2782.      0.   1985.      0.      0.   4420.   2852.
 6240.      0.   4204.   4092.   1770.    242.6    79.2     0.      0.
    0.      0.   3462.      0.      0.      0.   6464.    486.    916.
    0.      0.    554.5     0.   1273.   8712.      0.      0.      0.
  212.5     0.      0.      0.      0.      0.   3520.      0.      0.
    0.   3164.      0.   1057.    506.      0.   6288.   2972.   4716.
 3478.   1945.   9576.    724.5     0.   1107.   4712.      0.   3090.
 1553.      0.    452.   6596.      0.   2398.      0.      0.   7328.
  572.    392.    989.   4212.      0.      0.      0.      0.   1718.
  782.   1827.     38.97    0.   1461.      0.   2045.      0.      0.
 2038.   6608.   1430.      0.    916.5     0.   2886.      0.      0.
    0.      0.      0.      0.      0.      0.      0.      0.      0.
    0.    779.      0.      0.      0.      0.    469.5     0.   3350.
  861.5     0.      0.      0.  ]