下载
中文
注册

SortOperation

功能

后处理计算功能。实现输入tensor在最后一维上降序排列,并保留最大的num个元素,输出排序后的tensor及各元素对应的索引。

算子上下文

图1 SortOperation算子上下文

定义

struct SortParam {
    SVector<int32_t> num;
};

参数列表

成员名称

类型

默认值

描述

num

SVector<int32_t>

-

排序后保留的最大的元素的数量。

num是一个仅含有一个值的SVector,该值需大于0且小于等于输入x最后一维的大小。

输入

参数

维度

数据类型

格式

描述

x

[dim_0,dim_1,...,dim_n]

float16/bf16

ND

最后一维应至少有num个元素。

输出

参数

维度

数据类型

格式

描述

output

[dim_0,dim_1,...,num]

float16/bf16

ND

最后一维排序后,最大的num个元素。

indices

[dim_0,dim_1,...,num]

int32

ND

最大的num个元素对应的原索引。

规格约束

num是一个仅含有一个值的SVector,该值需大于0且小于等于输入x最后一维的大小。

接口调用示例

输入:

num = [1]
x = [[3.0, 4.0],
       [5.0, 6.0]]

输出:

output = [[4.0],
                 [6.0]]
indices = [[1.0],
                 [1.0]]