Scope融合算子适配插件实现
本节介绍如何实现融合算子适配插件,将基于原始框架的小算子映射成适配昇腾AI处理器的融合算子,并将算子信息注册到GE中。
在融合算子适配插件实现文件(例如decode_bbox_v2_scope_fussion_plugin.cc)中,完成相应功能实现。
Scope融合算子Parser注册流程复用普通算子的注册逻辑,通过REGISTER_CUSTOM_OP宏,按照指定的算子名称完成算子的注册。
REGISTER_CUSTOM_OP("DecodeBboxV2")
.FrameworkType(TENSORFLOW) // 原始框架为Tensorflow
.OriginOpType("DecodeBboxV2FusionOp") // 算子在原始框架中的类型,和GenerateFusionResult的SetType的内容保持一致
.FusionParseParamsFn(DecodeBboxV2ParseParams) // 用来注册解析融合算子属性的函数
.ImplyType(ImplyType::TVM); // 指定算子的实现方式,ImplyType::TVM表示该算子是TBE算子
Scope融合算子Parser注册和普通算子的差异点在于,注册Parser函数的接口由ParseParamsByOperatorFn变为FusionParseParamsFn(Overload)。原因是两个Parser函数入参不同:
- 普通算子注册函数ParseParamsByOperatorFn的回调函数原型:
using ParseParamByOpFunc = std::function<domi::Status(const ge::Operator &, ge::Operator &)>;
输入为框架定义的原始算子Operator类对象。
- Scope融合算子注册函数FusionParseParamsFn(Overload)的回调函数原型:
using FusionParseParamByOpFunc = std::function<domi::Status(const std::vector<ge::Operator> &, ge::Operator &)>;
输入为scope内部所有算子的NodeDef;输出为融合算子数据结构,保存融合算子信息。
用户自定义并实现回调函数,完成原始模型中小算子属性到融合算子的属性映射,将结果填到Operator类中。
Status FusionParseParamByOpFunc(const std::vector<ge::Operator> &op_src, ge::Operator &op_dest);
以下是FusionParseParamsFn的实现代码示例,目标是从原始模型中的小算子找到融合算子scales属性信息。
Status ParseFloatFromConstNode(const ge::Operator *node, float &value) {
if (node == nullptr) {
return FAILED;
}
ge::Tensor tensor;
auto ret = node->GetAttr("value", tensor);
if (ret != ge::GRAPH_SUCCESS) {
AscendString op_name;
ret = node->GetName(op_name);
if (ret != ge::GRAPH_SUCCESS) {
return FAILED;
}
OP_LOGE(op_name.GetString(), "Failed to get value from %s", op_name.GetString());
return FAILED;
}
uint8_t *data_addr = tensor.GetData();
value = *(reinterpret_cast<float *>(data_addr));
return SUCCESS;
}
// 用户自定义实现回调函数
Status DecodeBboxV2ParseParams(const std::vector<ge::Operator> &inside_nodes, ge::Operator &op_dest) {
std::map<std::string, std::string> scales_const_name_map;
std::map<string, const ge::Operator *> node_map;
for (const auto &node : inside_nodes) {
ge::AscendString op_type;
ge::graphStatus ret = node.GetOpType(op_type);
if (ret != ge::GRAPH_SUCCESS) {
return FAILED;
}
ge::AscendString op_name;
ret = node.GetName(op_name);
string str_op_name;
if (op_name.GetString() != nullptr) {
str_op_name = op_name.GetString();
}
if (op_type == kBoxesDiv) {
if (node.GetInputsSize() < kRealDivInputSize) {
OP_LOGE(op_name.GetString(), "Input size of %s is invalid, which is %zu.", kBoxesDiv, node.GetInputsSize());
return FAILED;
}
ge::AscendString input_unpack_name0;
ret = node.GetInputDesc(0).GetName(input_unpack_name0);
string str_input_unpack_name0;
if (input_unpack_name0.GetString() != nullptr) {
str_input_unpack_name0 = input_unpack_name0.GetString();
}
ge::AscendString input_unpack_name1;
ret = node.GetInputDesc(1).GetName(input_unpack_name1);
string str_input_unpack_name1;
if (input_unpack_name1.GetString() != nullptr) {
str_input_unpack_name1 = input_unpack_name1.GetString();
}
if (str_input_unpack_name0.find(kBoxesUnpack) != string::npos) {
scales_const_name_map.insert({str_op_name, str_input_unpack_name1 });
}
}
node_map[str_op_name] = &node;
}
std::vector<float> scales_list = {1.0, 1.0, 1.0, 1.0};
if (scales_const_name_map.size() != kScaleSize) {
ge::AscendString op_name;
ge::graphStatus ret = op_dest.GetName(op_name);
if (ret != ge::GRAPH_SUCCESS) {
return FAILED;
}
OP_LOGI(op_name.GetString(), "Boxes doesn't need scale.");
} else {
size_t i = 0;
for (const auto &name_pair : scales_const_name_map) {
float scale_value = 1.0;
auto ret = ParseFloatFromConstNode(node_map[name_pair.second], scale_value);
if (ret != SUCCESS) {
return ret;
}
scales_list[i++] = scale_value;
}
}
op_dest.SetAttr("scales", scales_list);
return SUCCESS;
}
父主题: 融合规则开发