定义简单计算算子(SoftmaxV2)
下面以一个简单的SoftmaxV2为例,介绍如何进行算子定义。
SoftmaxV2算子原型定义:
1 2 3 4 5 6 | REG_OP(SoftmaxV2) .INPUT(x, TensorType({ DT_DOUBLE, DT_FLOAT16, DT_BF16, DT_FLOAT })) .OUTPUT(y, TensorType({ DT_DOUBLE, DT_FLOAT16, DT_BF16, DT_FLOAT })) .ATTR(axes, ListInt, {-1}) .ATTR(half_to_float, Bool, false) .OP_END_FACTORY_REG(SoftmaxV2) |
从SoftmaxV2算子原型可以看到,SoftmaxV2算子有一个必选输入,输入名称为x。创建SoftmaxV2算子实例:
1 2 | auto softmax = op::SoftmaxV2("Softmax") //创建算子实例,传算子名称(例如Softmax)作为入参 .set_input_x(matmul2); //设置算子输入为matmul2 |
父主题: 各类算子表达