获取输入Tensor的特定维度中指定索引的切片数据。
调用此接口,则代表需要在params的第0维中,分别取索引为2和索引为0的数,结果如下所示:
gather_tensor = [[7,8,9], [1,2,3]]
gather(params, indices, axis=None, batch_dims=0)
支持的数据类型:float16, float32, int8, uint8, int16, uint16, int32, uint32, int64, uint64, bool。
batch_dims的取值范围为[-i, i],其中i是indices的维数。
数据类型为int,且需要满足:batch_dims <=axis。
wrapped_tensor:执行gather之后的Tensor,tvm.Tensor类型
若配置了batch_dims,即batch_dims≠0,则params与indices的batch轴的维度大小要保持一致。例如,若batch_dims=1,表示要在第一个轴上做批处理,则params与indices的轴0的维度大小要相同。
Atlas 200/300/500 推理产品
Atlas 训练系列产品
params = [[1,2,3], [4,5,6], [7,8,9]]
输入Tensor indices为:
indices = [2,0]
输入axis = 1, batch_dims=0
from tbe import tvm from tbe import dsl params = tvm.placeholder((3,3), dtype=dtype, name="params") indices = tvm.placeholder((2,), dtype=dtype, name="indices") set_valued_tensor = dsl.gather(params,indices,0, 0)
以上代码的功能为获取params 0轴上index为2和0的数据切片,输出结果如下所示:
gather_tensor = [[3, 1], [6, 4], [9, 7]]
params = [[1,2,3], [4,5,6], [7,8,9]]
输入Tensor indices为:
indices = [[1], [2], [0]]
输入axis = 1, batch_dims=1
from tbe import tvm from tbe import dsl params = tvm.placeholder((3, 3), dtype=dtype, name="params") indices = tvm.placeholder((3, 1), dtype=dtype, name="indices") set_valued_tensor = dsl.gather(params, indices, 1, 1)
因为batch_dims=1,则代表在第一个轴上(即轴0)进行批处理。在轴0上,将params[i]和indices[i]进行一一对应的gather处理,根据indices的值,在axis轴获取params数据切片, 其中i为batch轴的的坐标。
输出结果如下所示:
gather_tensor = [[2,], [6,], [7,]]