Follow
函数功能
用于指定当前输入/输出的datatype/format/shape信息与之前定义过的某个输入一致。
函数原型
OpParamDef &Follow(const char *paramName);
OpParamDef &Follow(const char *paramName, FollowType ftype);
参数说明
参数 |
输入/输出 |
说明 |
---|---|---|
paramName |
输入 |
之前定义过的输入名。 |
ftype |
输入 |
ftype类型为枚举类FollowType,表示Follow的模式,取值如下:
|
返回值说明
OpParamDef算子定义,OpParamDef请参考OpParamDef。
约束说明
- 所有Follow的数据源头仅支持为Input。
- 针对shape推导,仅支持输出参数Follow输入参数,不支持输入参数Follow另一个输入参数。
- 可以支持链式Follow,例如C Follow B,B Follow A,但此时Follow的模式不可中途变更(ftype需要保持一致)。
- 使用Follow接口通常比InferShape函数逻辑更加简单,能用Follow表达的逻辑,建议使用Follow接口,则无需再编写注册InferShape函数。
- InferShape推导函数和使用Follow接口去Follow shape不能混用,即不支持部分输出采用Infershape推导、部分输出采用Follow推导的情况。若用户同时使用了InferShape函数和Follow接口,以用户的InferShape函数为准,需要保证在InferShape函数中能够推导出所有的输出shape。
- datatype/format同时支持Follow输入的参数类型为DataTypeList/FormatList ,调用Follow后,当前输入/输出的datatype/format与paramName组合后的datatype/format一致。
调用示例
输出“y1”Follow输入“x1”场景,此时“y1”的datatype、format以及shape都将会和“x1”保持一致。
this->Input("x1") .ParamType(REQUIRED) .DataType({ge::DT_FLOAT, ge::DT_FLOAT}) .Format({ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("x2") .ParamType(REQUIRED) .DataType({ge::DT_FLOAT, ge::DT_FLOAT}) .Format({ge::FORMAT_ND, ge::FORMAT_ND}); this->Output("y1") .ParamType(REQUIRED) .Follow("x1") .OutputShapeDependOnCompute();
父主题: OpParamDef