下载
中文
注册

定义动态多输出算子(Split)

某些算子的输出个数不固定,为动态多输出算子,例如Split,下面介绍如何定义这类算子。

Split算子原型定义:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
REG_OP(Split)
    .INPUT(split_dim, TensorType({DT_INT32}))
    .INPUT(x, TensorType({DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT,  DT_FLOAT16, DT_INT16,
                          DT_INT32,      DT_INT64,     DT_INT8,   DT_QINT16, DT_QINT32,  DT_QINT8,
                          DT_QUINT16,    DT_QUINT8,    DT_UINT16, DT_UINT32, DT_UINT64,  DT_UINT8,
                          DT_BF16,       DT_BOOL}))
    .DYNAMIC_OUTPUT(y, TensorType({DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT,  DT_FLOAT16, DT_INT16,
                                   DT_INT32,      DT_INT64,     DT_INT8,   DT_QINT16, DT_QINT32,  DT_QINT8,
                                   DT_QUINT16,    DT_QUINT8,    DT_UINT16, DT_UINT32, DT_UINT64,  DT_UINT8,
                                   DT_BF16,       DT_BOOL}))
    .REQUIRED_ATTR(num_split, Int)
    .OP_END_FACTORY_REG(Split)

Split算子原型定义可以看到,该算子为动态多输出算子,我们通过“create_dynamic_output_输出名称”创建动态输出。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
auto split = op::Split("split")
  .set_input_x(data)                         //如何构造Data算子,请参考
  .set_input_split_dim(const)                //如何构造Const算子,请参考
  .set_attr_num_split(2)
  .create_dynamic_output_y(2);           //创建split算子的动态输出y,包括2个输出

auto addn = op::AddN("addn")
  .create_dynamic_input_x(1)                 //创建动态输入x,包括1个输入
  .set_dynamic_input_x(0, split, "y0")       //设置addn算子的第1个输入,split表示输入算子,“y0”表示split算子的输出名称, y0为第一个输出
  .set_attr_N(1);                            //设置属性N的值为1,表示该算子有1个输入

auto softplus = op::Softplus("softplus")
  .set_input_x(split, "y1");          //设置softplus算子的输入,split表示输入算子,y1表示split算子的第2个输出