定义动态多输入算子(AddN)
某些算子的输入个数不固定,为动态多输入算子,例如AddN,下面介绍如何定义这类算子。
AddN算子原型定义:
1 2 3 4 5 | REG_OP(AddN) .DYNAMIC_INPUT(x, TensorType::NumberType(), DT_VARIANT) .OUTPUT(y, TensorType::NumberType(), DT_VARIANT) .REQUIRED_ATTR(N, Int) .OP_END_FACTORY_REG(AddN) |
通过AddN算子原型定义可以看到,该算子为动态多输入算子,我们通过“create_dynamic_input_输入名称”创建动态输入,通过“set_dynamic_input_输入名称”设置动态输入。
1 2 3 4 5 6 | auto data = op::Data().set_attr_index(0); auto addn = op::AddN("addn") .create_dynamic_input_x(2) //创建动态输入x,包括2个输入,并且把这两个输入所为算子最后的输入 .set_dynamic_input_x(0,data) //设置第1个输入,0表示输入索引,默认从0开始,data表示输入value .set_dynamic_input_x(1,data) //设置第2个输入,1表示输入索引,默认从0开始,data表示输入value .set_attr_N(2); //设置属性N的值为2,表示该算子有2个输入 |
也可以通过“create_dynamic_input_byindex_输入名称”设置动态输入,但是和“create_dynamic_input_输入名称”不能同时使用,两者的区别是:“create_dynamic_input_输入名称”默认把创建的动态输入作为算子最后的输入,而“create_dynamic_input_byindex_输入名称”可以指定动态输入的索引位置,例如:
1 2 3 4 5 6 | auto concatv2 = op::ConcatV2("concatv2") .create_dynamic_input_byindex_x(2,0) //创建动态输入x,包括2个输入,并且把这两个输入插入到索引0和索引1的位置,0表示动态输入索引的起始位置 .set_dynamic_input_x(0,data1) //设置第1个输入,0表示输入索引,默认从0开始,data1表示输入value .set_dynamic_input_x(1,data2) //设置第2个输入,1表示输入索引,默认从0开始,data2表示输入value .set_input_concat_dim(data3) //设置第3个输入,data3表示输入value .set_attr_N(2); //设置属性N的值为2,表示该算子有2个输入 |
父主题: 各类算子表达