下载
中文
注册

slice

功能说明

从输入tensor中根据begin与end获取切片数据。

例如:

  • 输入tensor为: [[1,2,3],[4,5,6],[7,8,9]]
  • begin为: [0,1]
  • end为: [2,2]

调用slice接口,则代表需要在tensor的第0维中截取索引范围为[0, 2)的数据,第1维中截取索引范围为[1, 2)的数据,结果如下所示:

slice_tensor = [[2], [5]]

函数原型

slice(tensor, begin, end, stride=None)

参数说明

  • tensor:待切片的数据,tvm.tensor类型。

    支持的数据类型:float16, float32, int8, uint8, int16, uint16, int32, uint32, int64, uint64, bool。

  • begin:每个维度取值的索引起点,List类型,切片时,包含begin指向的值。

    支持的数据类型:int32, int64。

  • end:每个维度取值的索引终点,List类型,切片时,不包含end指向的值。

    支持的数据类型:int32, int64。

  • stride:预留参数,当前版本暂不支持,默认值为None,保持默认即可。

返回值

执行slice操作之后的Tensor,tvm.tensor类型

约束说明

  • begin与end的长度需要和输入tensor的shape维度保持一致。
  • begin与end的数据类型需保持一致,例如,都是int32或者都是int64,不支持混合使用。
  • begin、end与输入tensor的shape的需要满足如下条件:

    0<=begin[0]<end[0]<=shape[0]

  • 此接口暂不支持与其他TBE DSL计算接口混合使用。

支持的型号

Atlas 训练系列产品

Atlas 推理系列产品

Atlas 200/500 A2推理产品

Atlas A2训练系列产品/Atlas 800I A2推理产品

调用示例

  • 示例:输入tensor为:
    tensor = [[1,2,3,4,5],
              [6,7,8,9,10],
              [11,12,13,14,15],
              [16,17,18,19,20]]

    输入begin为:

    begin = [0,1]

    输入end为:

    end = [3,5]
    调用slice接口,根据begin和end获取数据切片,代码示例如下所示:
    from tbe import tvm
    from tbe import dsl
    tensor = tvm.placeholder((4, 5), dtype=dtype, name="params")
    begin = [0, 1]
    end = [3, 5]
    slice_tensor = dsl.slice(tensor,begin,end)

    以上代码的功能为获取tensor的数据切片,其中第0维,获取索引为0,1,2的数据;第1维,获取索引为1,2,3,4的数据。

    输出结果如下所示:

    slice_tensor = [[2,3,4,5],
                     [7,8,9,10],
                     [12,13,14,15]]