下载
中文
注册

gather

功能说明

获取输入Tensor的特定维度中指定索引的切片数据。

例如:

  • 输入Tensor params为:params = [[1,2,3],[4,5,6],[7,8,9]]
  • 输入Tensor indices为:indices = [2,0]
  • 输入axis = 0

调用此接口,则代表需要在params的第0维中,分别取索引为2和索引为0的数,结果如下所示:

gather_tensor = [[7,8,9], [1,2,3]]

函数原型

gather(params, indices, axis=None, batch_dims=0, impl_mode="support_out_of_bound_index")

参数说明

  • params:待切片的数据。输入Tensor,tvm.tensor类型。

    支持的数据类型:float16, float32, int8, uint8, int16, uint16, int32, uint32, int64, uint64, bool。

  • indices:待取出数据的位置索引。输入Tensor,tvm.tensor类型,支持的数据类型:int32, int64。
  • axis:指定切片数据所在的维度。取值范围:[-p, p -1],且要大于等于batch_dims,其中p是params的维数,int类型。
  • batch_dims:可选参数,代表批处理的维度数。表示可以从每轮批处理的元素中分别取出满足要求的切片数据。例如,如果batch_dims=1,则代表在params的第一个轴上有一个外循环indices,详细功能可参见•示例:batch_dims为1

    batch_dims的取值范围为[-i, i],其中i是indices的维数。

    数据类型为int,且需要满足:batch_dims <=axis。

  • impl_mode:可选参数,算子处理时选择的模式,默认值为"support_out_of_bound_index",亦可设置为None。"support_out_of_bound_index"模式下,当索引值超范围(正常索引的范围为[-params.shape[axis], params.shape[axis]))时,对应的输出置零。其他模式下,当索引值超范围时,会报AI Core Error。

返回值

gather_tensor:执行gather之后的Tensor,tvm.tensor类型

约束说明

  • 若配置了batch_dims,即batch_dims≠0,则params与indices的batch轴的维度大小要保持一致。例如,若batch_dims=1,表示要在第一个轴上做批处理,则params与indices的轴0的维度大小要相同。
  • 此接口暂不支持与其他TBE DSL计算接口混合使用。

支持的型号

Atlas 训练系列产品

Atlas 推理系列产品

Atlas 200/500 A2推理产品

Atlas A2训练系列产品/Atlas 800I A2推理产品

调用示例

  • 示例:batch_dims为0

    输入Tensor params为:

    params = [[1,2,3],
              [4,5,6],
              [7,8,9]]

    输入Tensor indices为:

    indices = [2,0]

    输入axis = 1, batch_dims=0

    调用gather接口,根据indices tensor的值,在axis轴获取params数据切片,代码示例如下所示:
    from tbe import tvm
    from tbe import dsl
    params = tvm.placeholder((3,3), dtype=dtype, name="params")
    indices = tvm.placeholder((2,), dtype=dtype, name="indices")
    set_valued_tensor = dsl.gather(params,indices,0, 0)

    以上代码的功能为获取params 0轴上index为2和0的数据切片,输出结果如下所示:

    gather_tensor = [[3, 1],
                     [6, 4],
                     [9, 7]]
  • 示例:batch_dims为1

    输入Tensor params为:

    params = [[1,2,3],
              [4,5,6],
              [7,8,9]]

    输入Tensor indices为:

    indices = [[1],
               [2],
               [0]]

    输入axis = 1, batch_dims=1

    调用gather接口,代码示例如下所示:
    from tbe import tvm
    from tbe import dsl
    params = tvm.placeholder((3, 3), dtype=dtype, name="params")
    indices = tvm.placeholder((3, 1), dtype=dtype, name="indices")
    set_valued_tensor = dsl.gather(params, indices, 1, 1)

    因为batch_dims=1,则代表在第一个轴上(即轴0)进行批处理。在轴0上,将params[i]和indices[i]进行一一对应的gather处理,根据indices的值,在axis轴获取params数据切片, 其中i为batch轴的坐标。

    输出结果如下所示:

    gather_tensor = [[2,],
                     [6,],
                     [7,]]