设置融合结果
通过GenerateFusionResult设置融合结果, 包含融合算子的名字、类型、输入、 输出、描述。最终融合结果保存在fusion_rlt返回,返回结果的类型为FusionScopesResult类。
图1 融合结果示意图

- 通过InsertInputs设置融合算子输入,例如:
fusion_rlt->InsertInputs("transpose", {0, kFusionDisableIndex});
- 第一个参数表示融合算子的输入,即scope内部小算子name(除去scope名称的部分)。
- 第二个参数表示输入index的映射,是一个vector类型。vector的index表示scope内部小算子的输入index,具体值表示融合算子的输入index,如果融合算子没有使用这个index,则使用占位符kFusionDisableIndex表示。
表1 示例说明 序号
代码示例
示例1
fusion_rlt->InsertInputs("transpose", {0, kFusionDisableIndex});
表示将transpose的第0个输入作为融合算子的第0个输入, transpose的第1个输入不使用,使用占位符kFusionDisableIndex表示。
示例2
fusion_rlt->InsertInputs("transpose", {1, kFusionDisableIndex});
表示将transpose的第0个输入作为融合算子的第1个输入, transpose的第1个输入不使用,使用占位符kFusionDisableIndex表示。
示例3
fusion_rlt->InsertInputs("transpose", {kFusionDisableIndex, 0});
表示将transpose的第1个输入作为融合算子的第0个输入, transpose的第0个输入不使用,使用占位符kFusionDisableIndex表示。
- 通过InsertOutputs设置融合算子输出,使用注意点和设置融合算子输入类似。例如:
// 设置融合算子输出, 将transpose_1的第0个输出作为融合算子的输出 fusion_rlt->InsertOutputs("transpose_1", {0});
- 通过SetType设置融合算子的结果类型,例如:
// 设置融合算子的类型 fusion_rlt->SetType(kScopeType);
需要注意的是,此处传入的类型需要和融合算子插件注册的OriginOpType保持一致:
REGISTER_CUSTOM_OP("DecodeBboxV2") .FrameworkType(TENSORFLOW) // 原始框架为Tensorflow .OriginOpType("DecodeBboxV2FusionOp") // 算子在原始框架中的类型,和GenerateFusionResult的SetType的内容保持一致 .FusionParseParamsFn(DecodeBboxV2ParseParams) // 用来注册解析融合算子属性的函数 .ImplyType(ImplyType::TVM); // 指定算子的实现方式,ImplyType::TVM表示该算子是TBE算子
如果识别出scope不满足条件,则不融合,可以设置type为kScopeInvalidType然后返回。
if (scopes.size() != 1) { fusion_rlt->SetType(kScopeInvalidType); return; }
- 通过SetName设置融合算子名称。需要注意的是,为保持融合算子名称全局唯一,建议尽量不要自行命名,可以根据scope的名称设置,例如:
// 设置融合算子的名称 AscendString scope_name; Status ret = scopes[0]->Name(scope_name); std::string scope_name_str; if (scope_name.GetString() != nullptr) { scope_name_str = scope_name.GetString(); } fusion_rlt->SetName(scope_name_str .substr(0, scope_name.length() - 1).c_str());
- 通过SetDescription设置融合算子描述,例如:
// 设置融合算子的描述 fusion_rlt->SetDescription("");
完整代码示例为:
void CustomScopeDecodeBboxV2Pass::GenerateFusionResult(const std::vector<Scope *> &scopes, FusionScopesResult *fusion_rlt) { if (fusion_rlt == nullptr) { return; } if (scopes.size() != 1) { fusion_rlt->SetType(kScopeInvalidType); // 如果识别出scope不满足条件,可以设置type为kScopeInvalidType然后返回 return; } // 设置融合算子输入,将transpose的第0个输入作为融合算子的第0个输入, transpose的第一个输入不使用 fusion_rlt->InsertInputs("transpose", {0, kFusionDisableIndex}); // 设置融合算子输入,将get_center_coordinates_and_sizes/transpose的第0个输入作为融合算子的第1个输入,get_center_coordinates_and_sizes/transpose的第一个输入不使用 fusion_rlt->InsertInputs("get_center_coordinates_and_sizes/transpose", {1, kFusionDisableIndex}); // 设置融合算子输出, 将transpose_1的第0个输出作为融合算子的输出 fusion_rlt->InsertOutputs("transpose_1", {0}); // 设置融合算子类型 fusion_rlt->SetType(kScopeType); // 设置融合算子名称 AscendString scope_name; Status ret = scopes[0]->Name(scope_name); if (ret != SUCCESS) { return; } std::string scope_name_str; if (scope_name.GetString() != nullptr) { scope_name_str = scope_name.GetString(); } fusion_rlt->SetName(scope_name_str .substr(0, scope_name.length() - 1).c_str()); // 设置融合算子描述 fusion_rlt->SetDescription(""); OP_LOGI(kOpType, "Set fusion result successfully."); return; }
父主题: Scope融合规则实现