vec_reduce_max

功能说明

在所有的输入数据中找出最大值及最大值对应的索引位置。注意:如果有多个最大值,返回哪个最大值,请看注意事项。

函数原型

vec_reduce_max(mask, dst, src, work_tensor, repeat_times, src_rep_stride, cal_index=False)

参数说明

表1 参数说明

参数名称

输入/输出

含义

mask

输入

请参考表1中mask参数描述。

dst

输入

目的操作数,tensor起始element,起始地址要求4Byte对齐。

Tensor的scope为Unified Buffer

src

输入

源操作数,tensor起始element,起始地址对齐要求请见通用约束

Tensor的scope为Unified Buffer

work_tensor

输入

指令执行期间存储中间结果,用于内部计算所需操作空间,需特别注意空间大小,参见各指令注意事项。

repeat_times

输入

重复迭代次数。Scalar(int32)、立即数(int)、Expr(int32)。注意,推荐使用立即数,性能比较高。

src_rep_stride

输入

相邻迭代间,源操作数相同block地址步长。支持的数据类型为:Scalar(int16/int32/int64/uint16/uint32/uint64)、立即数(int)、Expr(int16/int32/int64/uint16/uint32/uint64)。

cal_index

输入

指定是否获取最值的索引,仅支持bool类型,默认值为False,取值:

  • True:同时获取最值和最值索引。
  • False:不获取索引,只获取最值。

dst、src和work_tensor的数据类型需保持一致。

Atlas 200/300/500 推理产品,dst、src和work_tensor支持的数据类型为:Tensor(float16)

Atlas 训练系列产品,dst、src和work_tensor支持的数据类型为:Tensor(float16)

返回值

支持的芯片型号

Atlas 200/300/500 推理产品

Atlas 训练系列产品

注意事项

调用示例

  1. 【举例一】

    src, work_tensor, dst均为float16的tensor,src的shape为(65, 128),vec_reduce_max/vec_reduce_min的repeat_times为65。

    接口调用示例为:

    tik_instance.vec_reduce_max(128, dst, src, work_tensor, 65, 8, cal_index=True)

    此时work_tensor的空间计算过程为:

    elements_per_block = 16 (elements)
    elements_per_repeat = 128 (elements)
    it1_output_count = 2*65 = 130 (elements)
    
    def ceil_div(a_value, b_value):
        return (a_value + b_value - 1) // b_value
    
    it2_align_start = ceil_div(130, 16)*16 = 144 (elements)
    it2_output_count = ceil_div(130, 128)*2 = 4 (elements)
    it3_align_start = ceil_div(4, 16)*16 = 16 (elements)
    it3_output_count = ceil_div(4, 128)*2 = 2 (elements)

    三轮即可拿到最终的最值以及下标,需要的空间work_tensor为:it2_align_start + it3_align_start + it3_output_count = 144 + 16 + 2 = 162 (elements)

  2. 【举例二】

    src, work_tensor, dst均为float16的tensor,src的shape为(65, 128),vec_reduce_max/vec_reduce_min的repeat_times为scalar,值65。对于repeat_times为scalar或包含scalar的情况,需要做四轮计算。

    接口调用示例为:

    scalar = tik_instance.Scalar(init_value=65, dtype=”int32”)
    tik_instance.vec_reduce_max(128, dst, src, work_tensor, scalar, 8, cal_index=True)

    此时work_tensor的空间计算过程:

    elements_per_block = 16 (elements)
    elements_per_repeat = 128 (elements)
    it1_output_count = 2*65 = 130 (elements)
    
    def ceil_div(a_value, b_value):
        return (a_value + b_value - 1) // b_value
    
    it2_align_start = ceil_div(130, 16)*16 = 144 (elements)
    it2_output_count = ceil_div(130, 128)*2 = 4 (elements)
    it3_align_start = ceil_div(4, 16)*16 = 16 (elements)
    it3_output_count = ceil_div(4, 128)*2 = 2 (elements)
    it4_align_start = ceil_div(2, 16)*16 = 16 (elements)
    it4_output_count = ceil_div(2, 128)*2 = 2(elements)

    对于repeat_times为scalar或包含scalar的情况,虽然第三轮就能拿到结果,但是由于在Python编译时无法获取scalar的值,因此还是跑了四轮,需要的空间work_tensor为:it2_align_start + it3_align_start + it4_align_start + it4_output_count = 144 + 16 + 16 + 2 = 178 (elements)

  3. 【举例三】

    src, work_tensor, dst均为float32的tensor,src的shape为(65, 64),vec_reduce_max/vec_reduce_min的repeat_times为65。

    接口调用示例为:

    tik_instance.vec_reduce_max(64, dst, src, work_tensor, 65, 8, cal_index=True)

    此时work_tensor的空间计算过程为:

    elements_per_block = 8 (elements)
    elements_per_repeat = 64 (elements)
    it1_output_count = 2*65 = 130 (elements)
    
    def ceil_div(a_value, b_value):
        return (a_value + b_value - 1) // b_value
    
    it2_align_start = ceil_div(130, 8)*8 = 136 (elements)
    it2_output_count = ceil_div(130, 64)*2 = 6 (elements)
    it3_align_start = ceil_div(6, 8)*8 = 8 (elements)
    it3_output_count = ceil_div(6, 64)*2 = 2 (elements)

    此时三轮即可拿到最终的最值以及下标,需要的空间work_tensor为:it2_align_start + it3_align_start + it3_output_count = 136 + 8 + 2 = 146 (elements)

  4. 【举例四】

    src, work_tensor, dst均为float32的tensor,src的shape为(65, 64),vec_reduce_max/vec_reduce_min的repeat_times为scalar,值65。对于repeat_times为scalar或包含scalar的情况,需要做四轮计算。

    接口调用示例为:

    scalar = tik_instance.Scalar(init_value=65, dtype=”int32”)
    tik_instance.vec_reduce_max(64, dst, src, work_tensor, scalar, 8, cal_index=True)

    此时work_tensor的空间计算过程为:

    elements_per_block = 8 (elements)
    elements_per_repeat = 64 (elements)
    it1_output_count = 2*65 = 130 (elements)
    
    def ceil_div(a_value, b_value):
        return (a_value + b_value - 1) // b_value
    
    it2_align_start = ceil_div(130, 8)*8 = 136 (elements)
    it2_output_count = ceil_div(130, 64)*2 = 6 (elements)
    it3_align_start = ceil_div(6, 8)*8 = 8 (elements)
    it3_output_count = ceil_div(6, 64)*2 = 2 (elements)
    it4_align_start = ceil_div(2, 8)*8 = 8 (elements)
    it4_output_count = ceil_div(2, 64)*2 = 2(elements)

    对于repeat_times为scalar或包含scalar的情况,虽然第三轮就能拿到结果,但是由于在Python编译时无法获取scalar的值,因此还是跑了四轮,需要的空间work_tensor为:it2_align_start + it3_align_start + it4_align_start + it4_output_count = 136 + 8 + 8 + 2 = 154 (elements)

  5. 完整示例
    from tbe import tik
    tik_instance = tik.Tik()
    src_gm = tik_instance.Tensor("float16", (256,), name="src_gm", scope=tik.scope_gm)
    dst_gm = tik_instance.Tensor("float16", (16,), name="dst_gm", scope=tik.scope_gm)
    src_ub = tik_instance.Tensor("float16", (256,), name="src_ub", scope=tik.scope_ubuf)
    dst_ub = tik_instance.Tensor("float16", (16,), name="dst_ub", scope=tik.scope_ubuf)
    work_tensor_ub = tik_instance.Tensor("float16", (18,), tik.scope_ubuf, "work_tensor_ub")
    # 将用户输入数据从gm搬运到ub
    tik_instance.data_move(src_ub, src_gm, 0, 1, 16, 0, 0)
    # 给dst_ubuf赋初始值0,这样输出结果更加直观
    tik_instance.vec_dup(16, dst_ub, 0, 1, 1)
    tik_instance.vec_reduce_max(128, dst_ub, src_ub, work_tensor_ub, 2, 8, cal_index=True)
    # 将计算结果从ub搬运到gm
    tik_instance.data_move(dst_gm, dst_ub, 0, 1, 1, 0, 0)
    
    tik_instance.BuildCCE(kernel_name="vec_reduce_max", inputs=[src_gm], outputs=[dst_gm])

    结果示例:

    输入数据(src_gm):
    [-3.326   -6.883    3.607   -0.969   -0.179    2.254   -3.957    3.242
      6.133   -3.559    3.656   -9.88     2.19     4.707   -7.027   -3.598
     -3.264    4.44     6.04    -6.35     0.525   -6.492    0.341   -4.477
      1.375    6.484   -7.957   -1.243   -9.586   -2.871   -6.688    2.088
      5.      -1.808   -5.62     9.47     1.311    2.69     8.58     9.3
      5.754   -6.25     4.516   -6.6     -0.331   -8.586    4.844    9.81
      7.695   -0.332   -7.137   -2.79     2.66     5.316    8.72     1.954
      5.043   -7.816    1.207    2.508   -5.06    -1.697    8.5     -6.637
     -0.647   -1.211   -3.229   -3.074    7.89     5.043   -3.059   -0.7544
      9.484   -2.809   -7.145   -1.051    9.45     7.688    6.695   -2.318
     -0.3562  -0.674    1.736    2.994   -2.018   -2.605   -7.113    6.09
     -1.766    6.574   -4.47     7.367   -7.93     6.88     7.83     6.527
      5.816   -3.135    6.195   -6.734   -8.85     1.705   -5.023    5.992
      6.062   -3.342    8.03    -0.748    0.9883   3.191    2.75     8.39
      9.17    -5.887    1.378   -8.77    -9.05    -3.11    -7.203    9.79
      9.64     3.945    9.32     7.812    7.066    0.664    5.234   -4.61
     -3.559   -7.73     1.441    5.434    8.23     4.785   -1.231    8.03
      0.293   -0.1658  -5.48    -3.293    8.89    -7.926   -9.66     1.597
      0.5396   9.25    -6.74     7.086   -0.954    8.96     2.318   -2.395
     -9.19    -6.176   -4.297   -7.812   -1.787   -5.39     6.5      9.055
     -0.9556   2.4      2.092    7.35     0.7017   1.548   -2.637   -5.145
     -2.938    5.617   -3.451    7.5     -5.426   -7.62     7.535   -9.14
     -8.7     -3.436    2.283   -6.18     2.836    5.707   -1.356    8.664
      1.625   -3.717    1.478   -6.67    -4.023    2.652    4.805   -8.25
      2.63    -1.394   -3.227    1.595    7.49     7.574   -3.053   -1.841
     -7.06     0.4524  -5.71     5.37     8.72     8.51     4.836   -5.05
     -7.043    5.188   -5.332    5.62    -0.6465   5.773    8.53     7.793
     -4.215    7.47    -2.451    8.18     5.543   -7.367    7.105   -0.10364
      4.465    0.3362   0.9287   2.447   -9.87     7.844    2.084    4.527
      7.582   -3.217   -5.695   -6.375    0.627    2.24     6.625   -9.55
     -5.613    7.055    9.48    -6.613    5.49     5.066    4.117    9.516
     -4.594   -0.781    2.102    9.94     6.49    -7.82     0.11975  3.146  ]
    
    输出数据(dst_gm):
    [9.938e+00 1.496e-05 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
     0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
     0.000e+00 0.000e+00]