inplace_update
功能说明
inplace_update实现将tensor lhs的指定行用rhs来替换计算。
例如:
res = lhs res[ids,:] = rhs return res
函数原型
inplace_update(lhs, inplace_ids, rhs)
参数说明
- lhs:输入左tensor。
- inplace_ids:为int、list类型,数值大于等于0且不能大于lhs第一维的整型。其长度必须和输入rhs的第一维的长度相同。
- rhs:输入右tensor或scalar,除了第一维外,其他维度必须与lhs相同。当inplace_ids为int型单个数时,rhs比lhs差一个维度。例如:lhs为(10,1024),inplace_ids为[5],rhs为(1,1024);lhs为(10,1024),inplace_ids为5,rhs为(1024,)。
- 支持的数据类型:
Atlas 200/300/500 推理产品:支持的数据类型有float16、float32、int32
Atlas 训练系列产品:支持的数据类型有float16、float32、int32
Atlas 推理系列产品:支持的数据类型有float16、float32、int32
Atlas 200/500 A2推理产品:支持的数据类型有float16、float32、int32
Atlas A2训练系列产品/Atlas 800I A2推理产品:支持的数据类型有float16、float32、int32
返回值
res_tensor:表示计算后的tensor。
约束说明
- 此接口暂不支持与其他TBE DSL计算接口混合使用。
- rhs第一维最大为7934,超过7934无法处理。
- 当rhs第一维较大时(如5000+),可能会因操作系统OS栈溢出导致core dump,可以使用ulimit -s指令把栈空间增大,例如由8192增加到81920。
支持的型号
Atlas 200/300/500 推理产品
Atlas 训练系列产品
Atlas 推理系列产品
Atlas 200/500 A2推理产品
Atlas A2训练系列产品/Atlas 800I A2推理产品
调用示例
from tbe import tvm from tbe import dsl input_dtype = "float16" dataA = tvm.placeholder((6,1024), name="dataA", dtype=input_dtype) dataB = tvm.placeholder((5,1024), name="dataB", dtype=input_dtype) inplace_ids = [1,1,4,2,2] res = dsl.inplace_update(dataA, inplace_ids, dataB) res.shape = (6,1024) # res[0] = dataA[0] # res[1] = dataB[1] # res[2] = dataB[4] # res[3] = dataA[3] # res[4] = dataB[2] # res[5] = dataA[5]
父主题: Math计算接口