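Setting the Final Matching Rules
Setting the final matching rules consists of setting the basic scope matching rule and the parallel scope matching rule.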
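Introduction
A scope that satisfies the defined fusion rules is not necessarily the final fusion target. For example, you may still need to filter and decide among parallel scopes or nested scopes, as shown in Figure 1. In that case, implement LastMatchScopesAndOPs to set the final matching rule: it further filters the matched scopes and saves the qualifying ones to ScopesResult.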
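The nested-scope case can be illustrated with a small standalone sketch. It does not use the Ascend scope API. It assumes that each candidate is identified by a TensorFlow-style scope name ending in "/", and that a scope is nested in another when its name starts with the other's full name. The helpers IsNestedIn and KeepOutermost are hypothetical, and keeping only the outermost scopes is just one possible final rule.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical helper. Assumption: scope names end with '/', so a strictly
// longer name that starts with `outer` lies inside the `outer` scope.
bool IsNestedIn(const std::string &inner, const std::string &outer) {
  return inner.size() > outer.size() && inner.compare(0, outer.size(), outer) == 0;
}

// Hypothetical final rule: keep only candidates not nested in another candidate.
std::vector<std::string> KeepOutermost(const std::vector<std::string> &candidates) {
  std::vector<std::string> kept;
  for (const auto &name : candidates) {
    bool nested = false;
    for (const auto &other : candidates) {
      if (IsNestedIn(name, other)) {
        nested = true;
        break;
      }
    }
    if (!nested) {
      kept.push_back(name);
    }
  }
  return kept;
}
```

For example, given "net/a/", "net/a/b/" and "net/c/", the sketch keeps "net/a/" and "net/c/" and drops the nested "net/a/b/".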
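Basic Scope Matching Rule
In this example, the target scope to be fused is of type kScopeTypeDecodeBboxV2. No higher-level rule is needed: the pass simply finds each scope of that type, stores it in results, and returns. Each returned result is of class ScopesResult.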
Status DecodeBboxV2ScopeFusionPass::LastMatchScopesAndOPs(std::shared_ptr<ScopeGraph> &scope_graph,
                                                          std::vector<ScopesResult> &results) {
  OP_LOGI(kOpType, "LastMatchScopesAndOPs start.");
  if (scope_graph == nullptr) {
    OP_LOGE(kOpType, "Input params is nullptr.");
    return FAILED;
  }
  const ScopeTree *scope_tree = scope_graph->GetScopeTree();
  if (scope_tree == nullptr) {
    OP_LOGE(kOpType, "Scope tree is nullptr.");
    return FAILED;
  }
  // Walk every scope in the tree and keep those whose subtype is kScopeTypeDecodeBboxV2.
  const std::vector<Scope *> &scopes = scope_tree->GetAllScopes();
  for (auto &scope : scopes) {
    // Class ScopeTree guarantees scope is not empty.
    AscendString op_subtype;
    Status ret = scope->SubType(op_subtype);
    if (ret != SUCCESS) {
      return FAILED;
    }
    AscendString op_name;
    ret = scope->Name(op_name);
    if (ret != SUCCESS) {
      return FAILED;
    }
    if (op_subtype == kScopeTypeDecodeBboxV2) {
      OP_LOGI(kOpType, "DecodeBbox LastMatchScopesAndOPs match scope %s.", op_name.GetString());
      // Each matched scope becomes its own ScopesResult (one fusion target per scope).
      ScopesResult result;
      std::vector<Scope *> result_scopes;
      result_scopes.push_back(scope);
      result.SetScopes(result_scopes);
      results.push_back(result);
    }
  }
  // The pass succeeds only if at least one scope was matched.
  return (!(results.empty())) ? SUCCESS : FAILED;
}
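The filtering logic of the pass above can be tried outside the framework with a minimal sketch. FakeScope, ResultGroup and MatchBySubtype are hypothetical stand-ins for Scope, ScopesResult and the pass itself, and the subtype is modeled as a plain string. As in the pass, each matched scope gets its own result group, and the function reports success only when at least one scope matched.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for a scope: only the two attributes the pass reads.
struct FakeScope {
  std::string name;
  std::string subtype;
};

// Hypothetical stand-in for ScopesResult: the scopes fused together.
using ResultGroup = std::vector<const FakeScope *>;

// Appends one single-scope group per scope of subtype `target`.
// Returns true (success) only if `results` ends up non-empty.
bool MatchBySubtype(const std::vector<FakeScope> &scopes, const std::string &target,
                    std::vector<ResultGroup> &results) {
  for (const auto &scope : scopes) {
    if (scope.subtype == target) {
      results.push_back(ResultGroup{&scope});
    }
  }
  return !results.empty();
}
```

With two "DecodeBboxV2" scopes and one "Conv" scope, the sketch returns two single-scope groups; with a subtype that does not occur, it returns false, matching the FAILED branch of the pass.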
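Parallel Scope Matching Rule
You can also define a more complex parallel scope matching rule. In the following example, the pass first finds the scopes of type kScopeTypeBatchnorm and kScopeTypeMoments, and then fuses a batchnorm scope with a moments scope if the two are at the same layer in the network, that is, if their names share the same prefix before "batchnorm" and "moments".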
/**
 * @brief LastMatch for multiple scopes
 */
Status ScopeLayerNormPass::LastMatchScopesAndOPs(std::shared_ptr<ScopeGraph> &scope_graph,
                                                 std::vector<ScopesResult> &results) {
  if (scope_graph == nullptr) {
    OP_LOGE(kOpType, "Input params is nullptr.");
    return domi::PARAM_INVALID;
  }
  const ScopeTree *scope_tree = scope_graph->GetScopeTree();
  if (scope_tree == nullptr) {
    OP_LOGE(kOpType, "Scope tree is nullptr.");
    return domi::PARAM_INVALID;
  }
  // Step 1: collect the candidate batchnorm and moments scopes.
  const std::vector<Scope *> &scopes = scope_tree->GetAllScopes();
  std::vector<Scope *> fusion_scopes_bn;
  std::vector<Scope *> fusion_scopes_m;
  for (auto &scope : scopes) {
    // Class ScopeTree guarantees scope is not empty.
    AscendString op_subtype;
    Status ret = scope->SubType(op_subtype);
    if (ret != SUCCESS) {
      return FAILED;
    }
    if (op_subtype == kScopeTypeBatchnorm) {
      fusion_scopes_bn.push_back(scope);
    } else if (op_subtype == kScopeTypeMoments) {
      fusion_scopes_m.push_back(scope);
    }
  }
  // Step 2: pair batchnorm and moments scopes that belong to the same layernorm,
  // i.e. whose names share the same prefix before "batchnorm" / "moments".
  if (fusion_scopes_bn.size() == fusion_scopes_m.size()) {
    for (size_t i = 0; i < fusion_scopes_bn.size(); i++) {
      auto scope_bn = fusion_scopes_bn[i];
      for (size_t j = 0; j < fusion_scopes_m.size(); j++) {
        auto scope_m = fusion_scopes_m[j];
        AscendString scope_bn_name;
        Status ret = scope_bn->Name(scope_bn_name);
        if (ret != SUCCESS) {
          return FAILED;
        }
        AscendString scope_m_name;
        ret = scope_m->Name(scope_m_name);
        if (ret != SUCCESS) {
          return FAILED;
        }
        std::string scope_m_name_str;
        std::string scope_bn_name_str;
        if (scope_m_name.GetString() != nullptr) {
          scope_m_name_str = scope_m_name.GetString();
        }
        if (scope_bn_name.GetString() != nullptr) {
          scope_bn_name_str = scope_bn_name.GetString();
        }
        // std::string::find returns size_t; compare with std::string::npos, not -1.
        const size_t pos_bn = scope_bn_name_str.find("batchnorm");
        const size_t pos_m = scope_m_name_str.find("moments");
        const bool is_biggan_bn = scope_bn_name_str.find("resblock") != std::string::npos;
        const bool is_biggan_m = scope_m_name_str.find("resblock") != std::string::npos;
        // A "resblock" in either name (BigGAN, per the variable names) aborts the whole pass.
        if (is_biggan_bn || is_biggan_m) {
          return FAILED;
        }
        if (pos_bn != std::string::npos && pos_m != std::string::npos &&
            scope_bn_name_str.substr(0, pos_bn) == scope_m_name_str.substr(0, pos_m)) {
          // scope result
          ScopesResult result;
          std::vector<Scope *> result_scopes;
          result_scopes.push_back(scope_bn);
          result_scopes.push_back(scope_m);
          result.SetScopes(result_scopes);
          results.push_back(result);
          OP_LOGI(kOpType, "scope:%s, and scope:%s is connect.", scope_bn_name.GetString(),
                  scope_m_name.GetString());
          break;
        }
      }
    }
  }
  return (!(results.empty())) ? SUCCESS : FAILED;
}
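The pairing step can likewise be sketched without the Ascend API, using plain scope names. PrefixBefore, ContainsResblock and PairBatchnormWithMoments are hypothetical names. As a simplification, the sketch rejects any "resblock" name before pairing, whereas the pass above checks it pair by pair, so the two can differ in edge cases.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Stores the text before `keyword` in `prefix`; returns false if the keyword is absent.
bool PrefixBefore(const std::string &name, const std::string &keyword, std::string &prefix) {
  const std::size_t pos = name.find(keyword);
  if (pos == std::string::npos) {
    return false;
  }
  prefix = name.substr(0, pos);
  return true;
}

// True if any name contains "resblock".
bool ContainsResblock(const std::vector<std::string> &names) {
  for (const auto &n : names) {
    if (n.find("resblock") != std::string::npos) return true;
  }
  return false;
}

// Pairs each batchnorm name with the first moments name sharing the same prefix.
// Returns true only if at least one pair was found.
bool PairBatchnormWithMoments(const std::vector<std::string> &bn_names,
                              const std::vector<std::string> &m_names,
                              std::vector<std::pair<std::string, std::string>> &pairs) {
  // Simplification: reject up front if any candidate belongs to a "resblock".
  if (ContainsResblock(bn_names) || ContainsResblock(m_names)) {
    return false;
  }
  // As in the pass, only attempt pairing when both candidate lists have equal size.
  if (bn_names.size() == m_names.size()) {
    for (const auto &bn : bn_names) {
      std::string bn_prefix;
      if (!PrefixBefore(bn, "batchnorm", bn_prefix)) continue;
      for (const auto &m_name : m_names) {
        std::string m_prefix;
        if (PrefixBefore(m_name, "moments", m_prefix) && m_prefix == bn_prefix) {
          pairs.emplace_back(bn, m_name);
          break;  // one moments scope per batchnorm scope
        }
      }
    }
  }
  return !pairs.empty();
}
```

With batchnorm scopes "enc/ln1/batchnorm/" and "enc/ln2/batchnorm/" and the moments scopes listed in the opposite order, each batchnorm scope is still paired with the moments scope of its own layer, because matching uses the name prefix rather than list position.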