下载
中文
注册

设置最终匹配规则

设置最终匹配规则包括设置基本Scope匹配规则和Scope并列匹配规则。

介绍

符合融合规则定义的Scope并不一定是最终的融合目标,例如还需要对并列Scope、嵌套Scope等进行筛选和判断,如图1所示。此时可通过LastMatchScopesAndOPs设置最终匹配规则,对已匹配到的Scope做进一步筛选,并将符合条件的Scope保存到ScopesResult中。

图1 Scope融合支持的场景

基本Scope匹配规则

本例中,期望融合的目标Scope的类型为kScopeTypeDecodeBboxV2,无需设置更高阶的规则,直接找到该类型的Scope并保存到results中返回即可。results中每个元素的类型为ScopesResult类。

    /**
     * @brief LastMatch for a single scope type: every scope whose subtype is
     *        kScopeTypeDecodeBboxV2 becomes its own fusion result.
     *
     * @param scope_graph graph whose scope tree is searched
     * @param results     receives one ScopesResult (holding exactly that scope)
     *                    per matching scope; existing entries are kept
     * @return FAILED if the graph or its scope tree is null, if a scope query
     *         fails, or if this call matched no scope; SUCCESS otherwise.
     */
    Status DecodeBboxV2ScopeFusionPass::LastMatchScopesAndOPs(std::shared_ptr<ScopeGraph> &scope_graph,
                                                              std::vector<ScopesResult> &results) {
        OP_LOGI(kOpType, "LastMatchScopesAndOPs start.");
        if (scope_graph == nullptr) {
            OP_LOGE(kOpType, "Input params is nullptr.");
            return FAILED;
        }
        const ScopeTree *scope_tree = scope_graph->GetScopeTree();
        if (scope_tree == nullptr) {
            OP_LOGE(kOpType, "Scope tree is nullptr.");
            return FAILED;
        }
        // Remember the incoming size so the return value reflects matches made by
        // this call only, even if the caller passed a non-empty vector.
        const size_t initial_count = results.size();
        const std::vector<Scope *> &scopes = scope_tree->GetAllScopes();

        for (Scope *scope : scopes) {
            // Class ScopeTree guarantees scope is not empty.
            AscendString op_subtype;
            if (scope->SubType(op_subtype) != SUCCESS) {
                OP_LOGE(kOpType, "Failed to get scope subtype.");
                return FAILED;
            }
            if (op_subtype == kScopeTypeDecodeBboxV2) {
                // The name is only used for logging, so it is queried for matching
                // scopes only; a failure on an unrelated scope must not abort the pass.
                AscendString op_name;
                if (scope->Name(op_name) != SUCCESS) {
                    OP_LOGE(kOpType, "Failed to get scope name.");
                    return FAILED;
                }
                const char *name = op_name.GetString();
                OP_LOGI(kOpType, "DecodeBbox LastMatchScopesAndOPs match scope %s.", (name != nullptr) ? name : "");
                ScopesResult result;
                std::vector<Scope *> result_scopes{scope};
                result.SetScopes(result_scopes);
                results.push_back(result);
            }
        }
        return (results.size() > initial_count) ? SUCCESS : FAILED;
    }

Scope并列匹配规则

用户还可以定义更复杂的Scope并列匹配规则。例如在下面的示例中,首先找到类型为kScopeTypeBatchnorm和kScopeTypeMoments的Scope,然后判断这两个Scope是否位于网络中的同一层(即两者名称中分别位于"batchnorm"和"moments"之前的前缀相同),如果是则进行融合。

/**
 * @brief LastMatch for multiple scopes: pairs each batchnorm scope with a
 *        moments scope in the same layer so the pair can be fused together.
 *
 * Two scopes are considered to be in the same layer when the part of the
 * batchnorm scope name before "batchnorm" equals the part of the moments
 * scope name before "moments". Pairing is only attempted when the number of
 * batchnorm and moments scopes is equal, and each moments scope is used in at
 * most one pair.
 *
 * @param scope_graph graph whose scope tree is searched
 * @param results     receives one ScopesResult {batchnorm scope, moments scope}
 *                    per matched pair; existing entries are kept
 * @return domi::PARAM_INVALID if the graph or its scope tree is null;
 *         FAILED if a scope query fails, if any candidate scope name contains
 *         "resblock", or if this call matched no pair; SUCCESS otherwise.
 */
Status ScopeLayerNormPass::LastMatchScopesAndOPs(std::shared_ptr<ScopeGraph>& scope_graph,
                                                 std::vector<ScopesResult>& results) {
  if (scope_graph == nullptr) {
    OP_LOGE(kOpType, "Input params is nullptr.");
    return domi::PARAM_INVALID;
  }
  const ScopeTree* scope_tree = scope_graph->GetScopeTree();
  if (scope_tree == nullptr) {
    OP_LOGE(kOpType, "Scope tree is nullptr.");
    return domi::PARAM_INVALID;
  }
  // Remember the incoming size so the return value reflects matches made by
  // this call only, even if the caller passed a non-empty vector.
  const size_t initial_count = results.size();

  const std::vector<Scope*>& scopes = scope_tree->GetAllScopes();
  std::vector<Scope*> fusion_scopes_bn;
  std::vector<Scope*> fusion_scopes_m;
  for (Scope* scope : scopes) {
    // Class ScopeTree guarantees scope is not empty.
    AscendString op_subtype;
    if (scope->SubType(op_subtype) != SUCCESS) {
      return FAILED;
    }
    if (op_subtype == kScopeTypeBatchnorm) {
      fusion_scopes_bn.push_back(scope);
    } else if (op_subtype == kScopeTypeMoments) {
      fusion_scopes_m.push_back(scope);
    }
  }

  // the two scope batchnorm and moments in the same layernorm: only a
  // one-to-one pairing is attempted.
  if (fusion_scopes_bn.size() != fusion_scopes_m.size()) {
    return (results.size() > initial_count) ? SUCCESS : FAILED;
  }

  // Reads a scope's name into a std::string; a null C string becomes "".
  auto get_name = [](Scope* scope, std::string& name) -> bool {
    AscendString ascend_name;
    if (scope->Name(ascend_name) != SUCCESS) {
      return false;
    }
    const char* raw = ascend_name.GetString();
    name = (raw != nullptr) ? raw : "";
    return true;
  };

  // Query every name once up front instead of once per (bn, m) pair, and apply
  // the "resblock" (BigGAN) rejection before any result is appended so that
  // results is never left partially filled on failure.
  const size_t count = fusion_scopes_bn.size();
  std::vector<std::string> bn_names(count);
  std::vector<std::string> m_names(count);
  for (size_t i = 0; i < count; ++i) {
    if (!get_name(fusion_scopes_bn[i], bn_names[i]) || !get_name(fusion_scopes_m[i], m_names[i])) {
      return FAILED;
    }
    if (bn_names[i].find("resblock") != std::string::npos ||
        m_names[i].find("resblock") != std::string::npos) {
      return FAILED;
    }
  }

  // Tracks moments scopes already paired, so no scope is fused twice.
  std::vector<char> m_used(count, 0);
  for (size_t i = 0; i < count; ++i) {
    const std::string& bn_name = bn_names[i];
    const size_t pos_bn = bn_name.find("batchnorm");
    if (pos_bn == std::string::npos) {
      continue;
    }
    for (size_t j = 0; j < count; ++j) {
      if (m_used[j] != 0) {
        continue;
      }
      const std::string& m_name = m_names[j];
      const size_t pos_m = m_name.find("moments");
      if (pos_m == std::string::npos) {
        continue;
      }
      // Same layer <=> identical name prefix before "batchnorm" / "moments".
      if (bn_name.compare(0, pos_bn, m_name, 0, pos_m) == 0) {
        ScopesResult result;
        std::vector<Scope*> result_scopes{fusion_scopes_bn[i], fusion_scopes_m[j]};
        result.SetScopes(result_scopes);
        results.push_back(result);
        m_used[j] = 1;
        OP_LOGI(kOpType, "scope:%s, and scope:%s is connect.", bn_name.c_str(), m_name.c_str());
        break;
      }
    }
  }
  return (results.size() > initial_count) ? SUCCESS : FAILED;
}