Scope融合算子适配插件实现
本节介绍如何实现融合算子适配插件,将基于原始框架的小算子映射成适配昇腾AI处理器的融合算子,并将算子信息注册到GE中。
在融合算子适配插件实现文件(例如decode_bbox_v2_scope_fussion_plugin.cc)中,完成相应功能实现。
Scope融合算子Parser注册流程复用普通算子的注册逻辑,通过REGISTER_CUSTOM_OP宏,按照指定的算子名称完成算子的注册。
REGISTER_CUSTOM_OP("DecodeBboxV2") .FrameworkType(TENSORFLOW) // 原始框架为Tensorflow .OriginOpType("DecodeBboxV2FusionOp") // 算子在原始框架中的类型,和GenerateFusionResult的SetType的内容保持一致 .FusionParseParamsFn(DecodeBboxV2ParseParams) // 用来注册解析融合算子属性的函数 .ImplyType(ImplyType::TVM); // 指定算子的实现方式,ImplyType::TVM表示该算子是TBE算子

Scope融合算子Parser注册和普通算子的差异点在于,注册Parser函数的接口由ParseParamsByOperatorFn变为FusionParseParamsFn(Overload)。原因是两个Parser函数入参不同:
- 普通算子注册函数ParseParamsByOperatorFn的回调函数原型:
using ParseParamByOpFunc = std::function<domi::Status(const ge::Operator &, ge::Operator &)>;
输入为框架定义的原始算子Operator类对象。
- Scope融合算子注册函数FusionParseParamsFn(Overload)的回调函数原型:
using FusionParseParamByOpFunc = std::function<domi::Status(const std::vector<ge::Operator> &, ge::Operator &)>;
输入为scope内部所有算子的NodeDef;输出为融合算子数据结构,保存融合算子信息。
用户自定义并实现回调函数,完成原始模型中小算子属性到融合算子的属性映射,将结果填到Operator类中。
Status FusionParseParamByOpFunc(const std::vector<ge::Operator> &op_src, ge::Operator &op_dest);
以下是FusionParseParamsFn的实现代码示例,目标是从原始模型中的小算子找到融合算子scales属性信息。
Status ParseFloatFromConstNode(const ge::Operator *node, float &value) { if (node == nullptr) { return FAILED; } ge::Tensor tensor; auto ret = node->GetAttr("value", tensor); if (ret != ge::GRAPH_SUCCESS) { AscendString op_name; ret = node->GetName(op_name); if (ret != ge::GRAPH_SUCCESS) { return FAILED; } OP_LOGE(op_name.GetString(), "Failed to get value from %s", op_name.GetString()); return FAILED; } uint8_t *data_addr = tensor.GetData(); value = *(reinterpret_cast<float *>(data_addr)); return SUCCESS; } // 用户自定义实现回调函数 Status DecodeBboxV2ParseParams(const std::vector<ge::Operator> &inside_nodes, ge::Operator &op_dest) { std::map<std::string, std::string> scales_const_name_map; std::map<string, const ge::Operator *> node_map; for (const auto &node : inside_nodes) { ge::AscendString op_type; ge::graphStatus ret = node.GetOpType(op_type); if (ret != ge::GRAPH_SUCCESS) { return FAILED; } ge::AscendString op_name; ret = node.GetName(op_name); string str_op_name; if (op_name.GetString() != nullptr) { str_op_name = op_name.GetString(); } if (op_type == kBoxesDiv) { if (node.GetInputsSize() < kRealDivInputSize) { OP_LOGE(op_name.GetString(), "Input size of %s is invalid, which is %zu.", kBoxesDiv, node.GetInputsSize()); return FAILED; } ge::AscendString input_unpack_name0; ret = node.GetInputDesc(0).GetName(input_unpack_name0); string str_input_unpack_name0; if (input_unpack_name0.GetString() != nullptr) { str_input_unpack_name0 = input_unpack_name0.GetString(); } ge::AscendString input_unpack_name1; ret = node.GetInputDesc(1).GetName(input_unpack_name1); string str_input_unpack_name1; if (input_unpack_name1.GetString() != nullptr) { str_input_unpack_name1 = input_unpack_name1.GetString(); } if (str_input_unpack_name0.find(kBoxesUnpack) != string::npos) { scales_const_name_map.insert({str_op_name, str_input_unpack_name1 }); } } node_map[str_op_name] = &node; } std::vector<float> scales_list = {1.0, 1.0, 1.0, 1.0}; if (scales_const_name_map.size() != kScaleSize) { ge::AscendString op_name; ge::graphStatus ret = op_dest.GetName(op_name); if (ret != ge::GRAPH_SUCCESS) { return FAILED; } OP_LOGI(op_name.GetString(), "Boxes doesn't need scale."); } else { size_t i = 0; for (const auto &name_pair : scales_const_name_map) { float scale_value = 1.0; auto ret = ParseFloatFromConstNode(node_map[name_pair.second], scale_value); if (ret != SUCCESS) { return ret; } scales_list[i++] = scale_value; } } op_dest.SetAttr("scales", scales_list); return SUCCESS; }
父主题: 融合规则开发