获取Tensor部分数据
功能说明
通过Tensor数组下标,获得Tensor内部分数据,形成新的Tensor。
函数原型
__getitem__(index_in)
参数说明
参数名称 |
输入/输出 |
含义 |
---|---|---|
index_in |
输入 |
Tensor下标,包括如下类型:
|
支持的型号
Atlas 200/300/500 推理产品
Atlas 训练系列产品
Atlas推理系列产品AI Core
Atlas推理系列产品Vector Core
Atlas A2训练系列产品/Atlas 800I A2推理产品
Atlas 200/500 A2推理产品
注意事项
- index_in为立即数、Scalar、Expr、Slice类型时,若原Tensor是多维的,先将其reshape成一维,返回的Tensor为一维。
- dataA_UB[100]与dataA_UB[100:101:1]含义相同。
- 立即数类型的index_in要在[-dim, dim)之间,负数下标表示取倒数第abs(index_in)个元素,dim为Tensor元素个数。
- Scalar、Expr类型的index_in要在[0, dim)之间,dim为Tensor元素个数。
- Slice类型的index_in:当start、stop为立即数时要在[-dim, dim)之间,当start、stop为Scalar、Expr时要在[0, dim)之间,
- 用户在进行slice操作时,不同scope对齐要求不同,slice的start根据不同scope需要满足如下的要求:
- UB:32Byte对齐
- GM:无对齐要求
- 当Tensor的shape中包含变量时,若先改变该变量的值,再对Tensor进行切片,可能会导致切片得到的新Tensor所需的空间大于max_mem_size而报错。例如data_A的shape为[i]=[64], max_mem_size为128。data_B由data_A切片而来,两者的内存指向同一块,但data_B的shape变成了[i]=[128],所需的空间超出了max_mem_size,因而debug会报错“The space required exceeds max_mem_size”
from tbe import tik tik_instance = tik.Tik() i = tik_instance.Scalar(dtype="int32", name="i",init_value=64) data_A = tik_instance.Tensor("int16", (i, ), name="data_A", scope=tik.scope_gm, max_mem_size=128) i.set_as(128) data_B = data_A[:]
返回值
新的Tensor。
调用示例
from tbe import tik tik_instance = tik.Tik() A = tik_instance.Tensor("int64", (2, 4), name="A", scope=tik.scope_gm) tik_instance.tikdb.debug_print("A") data_1 = A[0] tik_instance.tikdb.debug_print("data_1") data_2 = A[0:] tik_instance.tikdb.debug_print("data_2") data_3 = A[0, 1] tik_instance.tikdb.debug_print("data_3") data_4 = A[0:, 1] tik_instance.tikdb.debug_print("data_4") # data_5 = A[0][1] # tik_instance.tikdb.debug_print("data_5") data_6 = A[0:][1] tik_instance.tikdb.debug_print("data_6") tik_instance.BuildCCE(kernel_name="getitem_sample", inputs=[A], outputs=[]) # 输入A的数据为: [[0, 1, 2, 3], [4, 5, 6, 7]] # A的打印结果为: [[0, 1, 2, 3], [4, 5, 6, 7]] # data_1表示A展平后的第0个元素,打印结果为: [0] # data_2表示A展平后从0开始的元素,打印结果为: [0, 1, 2, 3, 4, 5, 6, 7] # data_3表示A的第0行第1列元素,打印结果为: [[1]] # data_4表示A的第1列元素,打印结果为: [[1], [5]] # data_5表示A[0]的第1个元素,由于A[0]只有一个元素,因此会报错。 # data_6表示A[0:]的第1个元素,打印结果为: [1]
父主题: Tensor管理