gather_nd(params, indices, batch_dims=0)
wrapped_tensor: the Tensor produced by gather_nd, of type tvm.Tensor.
None.
Atlas 200/300/500 inference products
Atlas training series products
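As a way to reason about what gather_nd computes, here is a minimal pure-Python sketch that operates on nested lists. It is not the TBE implementation. Its batch_dims handling assumes TensorFlow-style tf.gather_nd semantics, in which the first batch_dims axes of params and indices are shared batch axes and the innermost axis of indices holds coordinates into the remaining axes of params. The helper names gather_nd_ref and _index_into are hypothetical.

```python
def _index_into(params, index):
    """Follow one coordinate tuple into a nested list.

    If the tuple is shorter than the rank of params, the result is a slice.
    """
    out = params
    for i in index:
        out = out[i]
    return out


def gather_nd_ref(params, indices, batch_dims=0):
    """Hypothetical reference for gather_nd on nested lists.

    Assumes TensorFlow-style semantics for batch_dims (not confirmed for TBE).
    """
    if batch_dims > 0:
        # Pair each batch entry of params with the matching batch entry of indices.
        return [gather_nd_ref(p, ind, batch_dims - 1)
                for p, ind in zip(params, indices)]
    # An innermost list of ints is a single coordinate tuple.
    if not isinstance(indices[0], list):
        return _index_into(params, indices)
    # Otherwise, gather once for every coordinate tuple.
    return [gather_nd_ref(params, ind, 0) for ind in indices]


params = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
indices = [[1, 2], [2, 0], [0, 1]]
print(gather_nd_ref(params, indices))  # [6, 7, 2]
```

With batch_dims=1, each batch entry of indices picks from the corresponding batch entry of params only. For example, gather_nd_ref([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], [[[1]], [[0]]], 1) returns [[[3, 4]], [[5, 6]]].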
The input Tensor params is:
params = [[1,2,3], [4,5,6], [7,8,9]]
The input Tensor indices is:
indices = [[1, 2], [2, 0], [0, 1]]
batch_dims=0
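from tbe import tvm
from tbe import dsl

# dtype was undefined in the original snippet; int32 is assumed here for both inputs
params = tvm.placeholder((3, 3), dtype="int32", name="params")
indices = tvm.placeholder((3, 2), dtype="int32", name="indices")
# Gather with batch_dims=0: each row of indices is a (row, col) coordinate into params
gather_nd_tensor = dsl.gather_nd(params, indices, batch_dims=0)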
The output is as follows:
gather_nd_tensor = [6, 7, 2]
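The expected output can be cross-checked on the host with NumPy. This is a sketch that assumes NumPy is available and only reproduces the batch_dims=0 case. It does not invoke TBE. Moving the coordinate axis of indices to the front and passing it as a tuple turns the lookup into NumPy advanced indexing.

```python
import numpy as np

params = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
indices = np.array([[1, 2], [2, 0], [0, 1]], dtype=np.int32)

# np.moveaxis gives [[1, 2, 0], [2, 0, 1]], i.e. (rows, cols),
# so this evaluates params[[1, 2, 0], [2, 0, 1]].
gather_nd_tensor = params[tuple(np.moveaxis(indices, -1, 0))]
print(gather_nd_tensor)  # [6 7 2]
```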