unsorted_segment_sum

此接口后续版本会废弃,新开发算子请不要使用此接口。

功能说明

使用数组segment_ids对tensor进行分段求和。假设输入为data,输出为output,则output[i] = sum(data[j...]),其中“j...”是一个数组,“j...”中的元素j满足:segment_ids[j] == i。

如果某个下标i在segment_ids中没有出现,则output[i] = init_value。比如下图中,1在segment_ids中没有出现,则output[1] = 0。

如果segment_ids中某个值为负数,则对应位置的data中的值则丢弃。比如下图中,segment_ids[3] = -1,则data[3]的值被丢弃,不参与计算。

您可以在“te/lang/cce/te_compute/segment_compute.py”查看接口定义。

约束说明

Atlas 200/300/500 推理产品:支持的数据类型有float16、float32、int32

Atlas 训练系列产品:支持的数据类型有float16、float32、int32

函数原型

te.lang.cce.unsorted_segment_sum(tensor, segment_ids, num_segments, init_value=0)

参数说明

返回值

res_tensor:表示计算后的tensor。

支持的芯片型号

Atlas 200/300/500 推理产品

Atlas 训练系列产品

调用示例

import tvm
import te.lang.cce
shape = (5,1024)
input_dtype = "float16"
data = tvm.placeholder(shape, name="data1", dtype=input_dtype)
segment_ids = [1,1,4,5,5]
num_segments = 6
res = te.lang.cce.unsorted_segment_sum(data, segment_ids, num_segments)
res.shape = (6,1024)
# res[0] = 0
# res[1] = data[0] + data[1]
# res[2] = 0
# res[3] = 0
# res[4] = data[2]
# res[5] = data[3] + data[4]