gather_nd

功能说明

gather接口只能获取某一维度上的索引数据,而gather_nd接口可以获取多个维度上的索引数据,且针对同一维度的每组数据,可分别取不同的索引值。

详细功能介绍可参见调用示例

函数原型

gather_nd(params, indices, batch_dims=0)

参数说明

返回值

wrapped_tensor:执行gather_nd之后的Tensor,tvm.Tensor类型

约束说明

无。

支持的芯片型号

Atlas 200/300/500 推理产品

Atlas 训练系列产品

调用示例

输入Tensor params为:

params = [[1,2,3],
          [4,5,6],
          [7,8,9]]

输入Tensor indices为:

indices = [[1, 2],
           [2, 0],
           [0, 1]]

batch_dims=0

调用gather_nd接口,根据indices的坐标值,获取params数据切片, 代码示例如下所示:
from tbe import tvm
from tbe import dsl
params = tvm.placeholder((3, 3), dtype=dtype, name="params")
indices = tvm.placeholder((3, 2), dtype=dtype, name="indices")
set_valued_tensor = dsl.gather_nd(params, indices, 0)

输出结果如下所示:

gather_nd_tensor = [6, 7, 2]