获取Tensor部分数据

功能说明

通过Tensor数组下标,获得Tensor内部分数据,形成新的Tensor。

函数原型

__getitem__(index_in)

参数说明

表1 参数说明

参数名称

输入/输出

含义

index_in

输入

Tensor下标,包括如下类型:

  • 立即数(int)、Scalar(int、uint)、Expr(int、uint)类型。例如:dataA_UB[100]。
  • Slice类型。切片的start、stop可以为立即数、Scalar、Expr类型,切片的step当前只能取立即数1。例如:dataA_UB[100:]、dataA_UB[100:200:1]。
  • Tuple类型。多维表示,维度和Tensor维度一致,每一维可以是立即数、Scalar、Expr、Slice类型,维度和维度间以逗号隔开。例如dataA_UB[0,0,0:100]。

支持的芯片型号

Atlas 200/300/500 推理产品

Atlas 训练系列产品

注意事项

返回值

新的Tensor。

调用示例

from tbe import tik
tik_instance = tik.Tik()
A = tik_instance.Tensor("int64", (2, 4), name="A", scope=tik.scope_gm)
tik_instance.tikdb.debug_print("A")
data_1 = A[0]
tik_instance.tikdb.debug_print("data_1")
data_2 = A[0:]
tik_instance.tikdb.debug_print("data_2")
data_3 = A[0, 1]
tik_instance.tikdb.debug_print("data_3")
data_4 = A[0:, 1]
tik_instance.tikdb.debug_print("data_4")
# data_5 = A[0][1]
# tik_instance.tikdb.debug_print("data_5")
data_6 = A[0:][1]
tik_instance.tikdb.debug_print("data_6")
tik_instance.BuildCCE(kernel_name="getitem_sample", inputs=[A], outputs=[])

# 输入A的数据为:
[[0, 1, 2, 3], [4, 5, 6, 7]]
# A的打印结果为:
[[0, 1, 2, 3], [4, 5, 6, 7]]
# data_1表示A展平后的第0个元素,打印结果为:
[0]
# data_2表示A展平后从0开始的元素,打印结果为:
[0, 1, 2, 3, 4, 5, 6, 7]
# data_3表示A的第0行第1列元素,打印结果为:
[[1]]
# data_4表示A的第1列元素,打印结果为:
[[1], [5]]
# data_5表示A[0]的第1个元素,由于A[0]只有一个元素,因此会报错。
# data_6表示A[0:]的第1个元素,打印结果为:
[1]