文档
注册

设置融合结果

我们通过GenerateFusionResult设置融合结果, 包含融合算子的名字、类型、输入、 输出、描述。最终融合结果保存在fusion_rlt返回,返回结果的类型为FusionScopesResult类

图1 融合结果示意图
  1. 通过InsertInputs设置融合算子输入,例如:
      fusion_rlt->InsertInputs("transpose", {0, kFusionDisableIndex});
    • 第一个参数表示融合算子的输入,即scope内部小算子name(除去scope名称的部分)。
    • 第二个参数表示输入index的映射,是一个vector类型。vector的index表示scope内部小算子的输入index,具体值表示融合算子的输入index,如果融合算子没有使用这个index,则使用占位符kFusionDisableIndex表示。
      表1 示例说明

      序号

      代码示例

      示例1

      fusion_rlt->InsertInputs("transpose", {0, kFusionDisableIndex});

      表示将transpose的第0个输入作为融合算子的第0个输入, transpose的第1个输入不使用,使用占位符kFusionDisableIndex表示。

      示例2

      fusion_rlt->InsertInputs("transpose", {1, kFusionDisableIndex});

      表示将transpose的第0个输入作为融合算子的第1个输入, transpose的第1个输入不使用,使用占位符kFusionDisableIndex表示。

      示例3

      fusion_rlt->InsertInputs("transpose", {kFusionDisableIndex, 0});

      表示将transpose的第1个输入作为融合算子的第0个输入, transpose的第0个输入不使用,使用占位符kFusionDisableIndex表示。

  2. 通过InsertOutputs设置融合算子输出,使用注意点和设置融合算子输入类似。例如:
      // 设置融合算子输出, 将transpose_1的第0个输出作为融合算子的输出
      fusion_rlt->InsertOutputs("transpose_1", {0});
  3. 通过SetType设置融合算子的结果类型,例如:
      // 设置融合算子的类型
      fusion_rlt->SetType(kScopeType);

    需要注意的是,此处传入的类型需要和融合算子插件注册的OriginOpType保持一致:

    REGISTER_CUSTOM_OP("DecodeBboxV2")
        .FrameworkType(TENSORFLOW)          // 原始框架为Tensorflow
        .OriginOpType("DecodeBboxV2FusionOp")   // 算子在原始框架中的类型,和GenerateFusionResult的SetType的内容保持一致
        .FusionParseParamsFn(DecodeBboxV2ParseParams)  // 用来注册解析融合算子属性的函数
        .ImplyType(ImplyType::TVM);         // 指定算子的实现方式,ImplyType::TVM表示该算子是TBE算子

    如果识别出scope不满足条件,则不融合,可以设置type为kScopeInvalidType然后返回。

      if (scopes.size() != 1) {
        fusion_rlt->SetType(kScopeInvalidType);
        return;
      }
  4. 通过SetName设置融合算子名称。需要注意的是,为保持融合算子名称全局唯一,建议尽量不要自行命名,可以根据scope的名称设置,例如:
      // 设置融合算子的名称
      AscendString scope_name;
      Status ret = scopes[0]->Name(scope_name);
      std::string scope_name_str;
      if (scope_name.GetString() != nullptr) {
          scope_name_str = scope_name.GetString();
      }
      fusion_rlt->SetName(scope_name_str .substr(0, scope_name.length() - 1).c_str());
  5. 通过SetDescription设置融合算子描述,例如:
      // 设置融合算子的描述
      fusion_rlt->SetDescription("");
完整代码示例为:
void CustomScopeDecodeBboxV2Pass::GenerateFusionResult(const std::vector<Scope *> &scopes, FusionScopesResult *fusion_rlt) {
  if (fusion_rlt == nullptr) {
    return;
  }
  if (scopes.size() != 1) {
    fusion_rlt->SetType(kScopeInvalidType);  // 如果识别出scope不满足条件,可以设置type为kScopeInvalidType然后返回
    return;
  }

  // 设置融合算子输入,将transpose的第0个输入作为融合算子的第0个输入, transpose的第一个输入不使用
  fusion_rlt->InsertInputs("transpose", {0, kFusionDisableIndex});
  // 设置融合算子输入,将get_center_coordinates_and_sizes/transpose的第0个输入作为融合算子的第1个输入,get_center_coordinates_and_sizes/transpose的第一个输入不使用
  fusion_rlt->InsertInputs("get_center_coordinates_and_sizes/transpose", {1, kFusionDisableIndex});
  // 设置融合算子输出, 将transpose_1的第0个输出作为融合算子的输出
  fusion_rlt->InsertOutputs("transpose_1", {0});

  // 设置融合算子类型
  fusion_rlt->SetType(kScopeType);
  // 设置融合算子名称
  AscendString scope_name;
  Status ret = scopes[0]->Name(scope_name);
  if (ret != SUCCESS) {
      return;
  }
  std::string scope_name_str;
  if (scope_name.GetString() != nullptr) {
      scope_name_str = scope_name.GetString();
  }
  fusion_rlt->SetName(scope_name_str .substr(0, scope_name.length() - 1).c_str());
  // 设置融合算子描述
  fusion_rlt->SetDescription("");
  OP_LOGI(kOpType, "Set fusion result successfully.");
  return;
}
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词