定义动态多输出算子(Split)
某些算子的输出个数不固定,为动态多输出算子,例如Split,下面介绍如何定义这类算子。
Split算子原型定义:
1 2 3 4 5 6 7 8 9 10 11 12 |
REG_OP(Split) .INPUT(split_dim, TensorType({DT_INT32})) .INPUT(x, TensorType({DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_QINT16, DT_QINT32, DT_QINT8, DT_QUINT16, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8, DT_BF16, DT_BOOL})) .DYNAMIC_OUTPUT(y, TensorType({DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_QINT16, DT_QINT32, DT_QINT8, DT_QUINT16, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8, DT_BF16, DT_BOOL})) .REQUIRED_ATTR(num_split, Int) .OP_END_FACTORY_REG(Split) |
Split算子原型定义可以看到,该算子为动态多输出算子,我们通过“create_dynamic_output_输出名称”创建动态输出。
1 2 3 4 5 6 7 8 9 10 11 12 13 |
auto split = op::Split("split") .set_input_x(data) //如何构造Data算子,请参考 .set_input_split_dim(const) //如何构造Const算子,请参考 .set_attr_num_split(2) .create_dynamic_output_y(2); //创建split算子的动态输出y,包括2个输出 auto addn = op::AddN("addn") .create_dynamic_input_x(1) //创建动态输入x,包括1个输入 .set_dynamic_input_x(0, split, "y0") //设置addn算子的第1个输入,split表示输入算子,“y0”表示split算子的输出名称, y0为第一个输出 .set_attr_N(1); //设置属性N的值为1,表示该算子有1个输入 auto softplus = op::Softplus("softplus") .set_input_x(split, "y1"); //设置softplus算子的输入,split表示输入算子,y1表示split算子的第2个输出 |
父主题: 各类算子表达