下载
中文
注册

定义动态多输入算子(AddN)

某些算子的输入个数不固定,为动态多输入算子,例如AddN,下面介绍如何定义这类算子。

AddN算子原型定义:

1
2
3
4
5
REG_OP(AddN)
    .DYNAMIC_INPUT(x, TensorType::NumberType(), DT_VARIANT)
    .OUTPUT(y, TensorType::NumberType(), DT_VARIANT)
    .REQUIRED_ATTR(N, Int)
    .OP_END_FACTORY_REG(AddN)

通过AddN算子原型定义可以看到,该算子为动态多输入算子,我们通过“create_dynamic_input_输入名称”创建动态输入,通过“set_dynamic_input_输入名称”设置动态输入。

1
2
3
4
5
6
auto data = op::Data().set_attr_index(0);
auto addn = op::AddN("addn")
 .create_dynamic_input_x(2)    //创建动态输入x,包括2个输入,并且把这两个输入所为算子最后的输入
 .set_dynamic_input_x(0,data)     //设置第1个输入,0表示输入索引,默认从0开始,data表示输入value
 .set_dynamic_input_x(1,data)     //设置第2个输入,1表示输入索引,默认从0开始,data表示输入value
 .set_attr_N(2);                  //设置属性N的值为2,表示该算子有2个输入

也可以通过“create_dynamic_input_byindex_输入名称”设置动态输入,但是和“create_dynamic_input_输入名称”不能同时使用,两者的区别是:“create_dynamic_input_输入名称”默认把创建的动态输入作为算子最后的输入,而“create_dynamic_input_byindex_输入名称”可以指定动态输入的索引位置,例如:

1
2
3
4
5
6
auto concatv2 = op::ConcatV2("concatv2")
 .create_dynamic_input_byindex_x(2,0)    //创建动态输入x,包括2个输入,并且把这两个输入插入到索引0和索引1的位置,0表示动态输入索引的起始位置
 .set_dynamic_input_x(0,data1)     //设置第1个输入,0表示输入索引,默认从0开始,data1表示输入value
 .set_dynamic_input_x(1,data2)     //设置第2个输入,1表示输入索引,默认从0开始,data2表示输入value
 .set_input_concat_dim(data3)     //设置第3个输入,data3表示输入value
 .set_attr_N(2);                  //设置属性N的值为2,表示该算子有2个输入