下载
中文
注册

定义简单计算算子(SoftmaxV2)

下面以一个简单的SoftmaxV2为例,介绍如何进行算子定义。

SoftmaxV2算子原型定义:

1
2
3
4
5
6
REG_OP(SoftmaxV2)
    .INPUT(x, TensorType({ DT_DOUBLE, DT_FLOAT16, DT_BF16, DT_FLOAT }))
    .OUTPUT(y, TensorType({ DT_DOUBLE, DT_FLOAT16, DT_BF16, DT_FLOAT }))
    .ATTR(axes, ListInt, {-1})
    .ATTR(half_to_float, Bool, false)
    .OP_END_FACTORY_REG(SoftmaxV2)

从SoftmaxV2算子原型可以看到,SoftmaxV2算子有一个必选输入,输入名称为x。创建SoftmaxV2算子实例:

1
2
auto softmax = op::SoftmaxV2("Softmax")     //创建算子实例,传算子名称(例如Softmax)作为入参
    .set_input_x(matmul2);                //设置算子输入为matmul2