下载
中文
注册

Scope融合算子适配插件实现

本节介绍如何实现融合算子适配插件,将基于原始框架的小算子映射成适配昇腾AI处理器的融合算子,并将算子信息注册到GE中。

在融合算子适配插件实现文件(例如decode_bbox_v2_scope_fussion_plugin.cc)中,完成相应功能实现。

Scope融合算子Parser注册流程复用普通算子的注册逻辑,通过REGISTER_CUSTOM_OP宏,按照指定的算子名称完成算子的注册。
REGISTER_CUSTOM_OP("DecodeBboxV2")
    .FrameworkType(TENSORFLOW)              // 原始框架为Tensorflow
    .OriginOpType("DecodeBboxV2FusionOp")   // 算子在原始框架中的类型,和GenerateFusionResult的SetType的内容保持一致
    .FusionParseParamsFn(DecodeBboxV2ParseParams)  // 用来注册解析融合算子属性的函数
    .ImplyType(ImplyType::TVM);             // 指定算子的实现方式,ImplyType::TVM表示该算子是TBE算子

REGISTER_CUSTOM_OP宏、ParseParamsByOperatorFn、FusionParseParamsFn(Overload)等接口和其他更多介绍请参考基础数据结构和接口中的"ge命名空间>OpRegistrationData类"章节,本节仅介绍Scope融合算子Parser注册和普通算子注册的差异点。

Scope融合算子Parser注册和普通算子的差异点在于,注册Parser函数的接口由ParseParamsByOperatorFn变为FusionParseParamsFn(Overload)。原因是两个Parser函数入参不同:

  • 普通算子注册函数ParseParamsByOperatorFn的回调函数原型:
    using ParseParamByOpFunc = std::function<domi::Status(const ge::Operator &, ge::Operator &)>;

    输入为框架定义的原始算子Operator类对象。

  • Scope融合算子注册函数FusionParseParamsFn(Overload)的回调函数原型:
    using FusionParseParamByOpFunc = std::function<domi::Status(const std::vector<ge::Operator> &, ge::Operator &)>;

    输入为scope内部所有算子的NodeDef;输出为融合算子数据结构,保存融合算子信息。

用户自定义并实现回调函数,完成原始模型中小算子属性到融合算子的属性映射,将结果填到Operator类中。

Status FusionParseParamByOpFunc(const std::vector<ge::Operator> &op_src,  ge::Operator &op_dest);

以下是FusionParseParamsFn的实现代码示例,目标是从原始模型中的小算子找到融合算子scales属性信息。

Status ParseFloatFromConstNode(const ge::Operator *node, float &value) {
  if (node == nullptr) {
    return FAILED;
  }
  ge::Tensor tensor;
  auto ret = node->GetAttr("value", tensor);
  if (ret != ge::GRAPH_SUCCESS) {
    AscendString op_name;
    ret = node->GetName(op_name);
    if (ret != ge::GRAPH_SUCCESS) {
      return FAILED;
    }
    OP_LOGE(op_name.GetString(), "Failed to get value from %s", op_name.GetString());
    return FAILED;
  }
  uint8_t *data_addr = tensor.GetData();
  value = *(reinterpret_cast<float *>(data_addr));
  return SUCCESS;
}

// 用户自定义实现回调函数
Status DecodeBboxV2ParseParams(const std::vector<ge::Operator> &inside_nodes, ge::Operator &op_dest) {
  std::map<std::string, std::string> scales_const_name_map;
  std::map<string, const ge::Operator *> node_map;
  for (const auto &node : inside_nodes) {
    ge::AscendString op_type;
    ge::graphStatus ret = node.GetOpType(op_type);
    if (ret != ge::GRAPH_SUCCESS) {
      return FAILED;
    }
    ge::AscendString op_name;
    ret = node.GetName(op_name);
    string str_op_name;
    if (op_name.GetString() != nullptr) {
      str_op_name = op_name.GetString();
    }
    if (op_type == kBoxesDiv) {
      if (node.GetInputsSize() < kRealDivInputSize) {
        OP_LOGE(op_name.GetString(), "Input size of %s is invalid, which is %zu.", kBoxesDiv, node.GetInputsSize());
        return FAILED;
      }
      ge::AscendString input_unpack_name0;
      ret = node.GetInputDesc(0).GetName(input_unpack_name0);
      string str_input_unpack_name0;
      if (input_unpack_name0.GetString() != nullptr) {
        str_input_unpack_name0 = input_unpack_name0.GetString();
      }
      ge::AscendString input_unpack_name1;
      ret = node.GetInputDesc(1).GetName(input_unpack_name1);
      string str_input_unpack_name1;
      if (input_unpack_name1.GetString() != nullptr) {
        str_input_unpack_name1 = input_unpack_name1.GetString();
      }
      if (str_input_unpack_name0.find(kBoxesUnpack) != string::npos) {
        scales_const_name_map.insert({str_op_name, str_input_unpack_name1 });
      }
    }
    node_map[str_op_name] = &node;
  }

  std::vector<float> scales_list = {1.0, 1.0, 1.0, 1.0};
  if (scales_const_name_map.size() != kScaleSize) {
    ge::AscendString op_name;
    ge::graphStatus ret = op_dest.GetName(op_name);
    if (ret != ge::GRAPH_SUCCESS) {
      return FAILED;
    }
    OP_LOGI(op_name.GetString(), "Boxes doesn't need scale.");
  } else {
    size_t i = 0;
    for (const auto &name_pair : scales_const_name_map) {
      float scale_value = 1.0;
      auto ret = ParseFloatFromConstNode(node_map[name_pair.second], scale_value);
      if (ret != SUCCESS) {
        return ret;
      }
      scales_list[i++] = scale_value;
    }
  }
  op_dest.SetAttr("scales", scales_list);
  return SUCCESS;
}