下载
中文
注册

inplace_sub

功能说明

inplace_sub实现lhs指定行和rhs的减法计算。

例如:

res = lhs
res[ids,:] -= rhs
return res

函数原型

inplace_sub(lhs, inplace_ids, rhs)

参数说明

  • lhs:输入左tensor。
  • inplace_ids:为int、list类型,数值大于等于0且不能大于lhs第一维的整型。其长度必须和输入rhs的第一维的长度相同。
  • rhs:输入右tensor或scalar,除了第一维外,其他维度必须与lhs相同。当inplace_ids为int型单个数时,rhs比lhs差一个维度。例如:lhs为(10,1024),inplace_ids为[5],rhs为(1,1024);lhs为(10,1024),inplace_ids为5,rhs为(1024,)。
  • 支持的数据类型:

    Atlas 200/300/500 推理产品:支持的数据类型有float16、float32、int32

    Atlas 训练系列产品:支持的数据类型有float16、float32、int32

    Atlas 推理系列产品:支持的数据类型有float16、float32、int32

    Atlas 200/500 A2推理产品:支持的数据类型有float16、float32、int32

    Atlas A2训练系列产品/Atlas 800I A2推理产品:支持的数据类型有float16、float32、int32

返回值:

res_tensor:表示计算后的tensor。

约束说明

  • 此接口暂不支持与其他TBE DSL计算接口混合使用。
  • rhs第一维最大为7934,超过7934无法处理。
  • 当rhs第一维较大时(如5000+),可能会因操作系统OS栈溢出导致core dump,可以使用ulimit -s指令把栈空间增大,例如由8192增加到81920。

支持的型号

Atlas 200/300/500 推理产品

Atlas 训练系列产品

Atlas 推理系列产品

Atlas 200/500 A2推理产品

Atlas A2训练系列产品/Atlas 800I A2推理产品

调用示例

from tbe import tvm
from tbe import dsl
input_dtype = "float16"
dataA = tvm.placeholder((6,1024), name="dataA", dtype=input_dtype)
dataB = tvm.placeholder((5,1024), name="dataB", dtype=input_dtype)
inplace_ids = [1,1,4,2,2]
res = dsl.inplace_sub(dataA, inplace_ids, dataB)
res.shape = (6,1024)
# res[0] = dataA[0]
# res[1] = dataA[1] - dataB[0] - dataB[1]
# res[2] = dataA[2] - dataB[3] - dataB[4]
# res[3] = dataA[3]
# res[4] = dataA[4] - dataB[2]
# res[5] = dataA[5]