设置最终匹配规则
设置最终匹配规则包括设置基本Scope匹配规则和Scope并列匹配规则。
介绍
符合定义融合规则要求的Scope并不一定是最终的融合目标,比如需要进行并列Scope、嵌套Scope等的筛选和判断,如图1所示,此时通过LastMatchScopesAndOPs设置最终匹配规则,对匹配到的Scope进行进一步筛选,将符合的Scope保存到ScopesResult中。
基本Scope匹配规则
本例中,期望融合的目标Scope的类型为kScopeTypeDecodeBboxV2,不需要设置更高阶的规则,直接将匹配到的Scope保存到results中返回即可,results中每个元素的类型为ScopesResult类。
/**
 * @brief Final match rule: keep every scope whose subtype is kScopeTypeDecodeBboxV2.
 *
 * Each matching scope is wrapped in its own ScopesResult and appended to
 * `results`.
 *
 * @param scope_graph Graph whose scope tree is scanned; must not be null.
 * @param results     Output list that receives one ScopesResult per match.
 * @return FAILED if the graph or its scope tree is null, if querying a
 *         scope's subtype or name fails, or if `results` is empty when the
 *         scan finishes; SUCCESS otherwise.
 */
Status DecodeBboxV2ScopeFusionPass::LastMatchScopesAndOPs(shared_ptr<ScopeGraph> &scope_graph,
                                                          std::vector<ScopesResult> &results) {
  OP_LOGI(kOpType, "LastMatchScopesAndOPs start.");
  if (scope_graph == nullptr) {
    OP_LOGE(kOpType, "Input params is nullptr.");
    return FAILED;
  }

  const ScopeTree *tree = scope_graph->GetScopeTree();
  if (tree == nullptr) {
    OP_LOGE(kOpType, "Scope tree is nullptr.");
    return FAILED;
  }

  const std::vector<Scope *> &all_scopes = tree->GetAllScopes();
  for (Scope *candidate : all_scopes) {
    // ScopeTree guarantees that the scopes it hands out are non-null.
    AscendString subtype;
    if (candidate->SubType(subtype) != SUCCESS) {
      return FAILED;
    }

    // The name is fetched for every scope, not only for matches, so a name
    // lookup failure on any scope aborts the whole pass.
    AscendString scope_name;
    if (candidate->Name(scope_name) != SUCCESS) {
      return FAILED;
    }

    if (!(subtype == kScopeTypeDecodeBboxV2)) {
      continue;
    }

    OP_LOGI(kOpType, "DecodeBbox LastMatchScopesAndOPs match scope %s.", scope_name.GetString());
    std::vector<Scope *> matched_scopes(1, candidate);
    ScopesResult matched;
    matched.SetScopes(matched_scopes);
    results.push_back(matched);
  }

  if (results.empty()) {
    return FAILED;
  }
  return SUCCESS;
}
Scope并列匹配规则
用户还可以定义更复杂的Scope并列匹配规则。例如下面的示例中,首先找到类型为kScopeTypeBatchnorm和kScopeTypeMoments的Scope,然后判断两个Scope是否位于网络中的同一层,若是则将二者作为一组进行融合。
/**
 * @brief Final match rule that pairs batchnorm scopes with moments scopes.
 *
 * Scopes of subtype kScopeTypeBatchnorm and kScopeTypeMoments are collected
 * from the scope tree. Pairing is attempted only when both kinds occur the
 * same number of times. A batchnorm scope and a moments scope are paired when
 * the text preceding "batchnorm" in the first name equals the text preceding
 * "moments" in the second name. Each pair is stored as one ScopesResult
 * (batchnorm first, moments second); a batchnorm scope is paired with at most
 * one moments scope (the first that matches).
 *
 * @param scope_graph Graph whose scope tree is scanned; must not be null.
 * @param results     Output list that receives one ScopesResult per pair.
 * @return domi::PARAM_INVALID if the graph or its scope tree is null;
 *         FAILED if querying a subtype or name fails, if any examined scope
 *         name contains "resblock", or if `results` is empty at the end;
 *         SUCCESS otherwise.
 */
Status ScopeLayerNormPass::LastMatchScopesAndOPs(std::shared_ptr<ScopeGraph>& scope_graph,
                                                 std::vector<ScopesResult>& results) {
  if (scope_graph == nullptr) {
    OP_LOGE(kOpType, "Input params is nullptr.");
    return domi::PARAM_INVALID;
  }
  const ScopeTree* scope_tree = scope_graph->GetScopeTree();
  if (scope_tree == nullptr) {
    OP_LOGE(kOpType, "Scope tree is nullptr.");
    return domi::PARAM_INVALID;
  }

  // Bucket the candidate scopes by subtype.
  const std::vector<Scope*>& scopes = scope_tree->GetAllScopes();
  std::vector<Scope*> fusion_scopes_bn;
  std::vector<Scope*> fusion_scopes_m;
  for (auto& scope : scopes) {
    // Class ScopeTree guarantees scope is not empty.
    AscendString op_subtype;
    Status ret = scope->SubType(op_subtype);
    if (ret != SUCCESS) {
      return FAILED;
    }
    if (op_subtype == kScopeTypeBatchnorm) {
      fusion_scopes_bn.push_back(scope);
    } else if (op_subtype == kScopeTypeMoments) {
      fusion_scopes_m.push_back(scope);
    }
  }

  // Pairing is only attempted when both kinds of scope occur equally often.
  if (fusion_scopes_bn.size() == fusion_scopes_m.size()) {
    for (Scope* scope_bn : fusion_scopes_bn) {
      // Everything that depends only on the batchnorm scope is computed once
      // per outer iteration. The inner loop is never empty here (both vectors
      // have the same size), so these checks fire exactly when they would
      // have fired inside the inner loop.
      AscendString scope_bn_name;
      if (scope_bn->Name(scope_bn_name) != SUCCESS) {
        return FAILED;
      }
      std::string scope_bn_name_str;
      if (scope_bn_name.GetString() != nullptr) {
        scope_bn_name_str = scope_bn_name.GetString();
      }
      // Any scope name containing "resblock" aborts the whole pass.
      if (scope_bn_name_str.find("resblock") != std::string::npos) {
        return FAILED;
      }
      const std::size_t pos_bn = scope_bn_name_str.find("batchnorm");

      for (Scope* scope_m : fusion_scopes_m) {
        AscendString scope_m_name;
        if (scope_m->Name(scope_m_name) != SUCCESS) {
          return FAILED;
        }
        std::string scope_m_name_str;
        if (scope_m_name.GetString() != nullptr) {
          scope_m_name_str = scope_m_name.GetString();
        }
        if (scope_m_name_str.find("resblock") != std::string::npos) {
          return FAILED;
        }
        const std::size_t pos_m = scope_m_name_str.find("moments");
        if (pos_bn == std::string::npos || pos_m == std::string::npos) {
          continue;
        }
        // Same prefix before "batchnorm" / "moments": compared in place, without
        // building temporary substrings.
        if (scope_bn_name_str.compare(0, pos_bn, scope_m_name_str, 0, pos_m) != 0) {
          continue;
        }

        ScopesResult result;
        std::vector<Scope*> result_scopes;
        result_scopes.push_back(scope_bn);
        result_scopes.push_back(scope_m);
        result.SetScopes(result_scopes);
        results.push_back(result);
        OP_LOGI(kOpType, "scope:%s, and scope:%s is connect.", scope_bn_name.GetString(), scope_m_name.GetString());
        // Only the first matching moments scope is paired with this batchnorm scope.
        break;
      }
    }
  }
  return (!(results.empty())) ? SUCCESS : FAILED;
}
父主题: Scope融合规则实现
