#include "register/register.h" #include "graph/operator.h" #include "json.hpp" namespace domi { Status ParseParamByOpFunc(const ge::Operator& op_src, ge::Operator& op_dest) { //... } REGISTER_CUSTOM_OP("OpType") .FrameworkType(ONNX) .OriginOpType("OriginOpType") .ParseParamsByOperatorFn(ParseParamByOpFunc) //用来注册解析算子属性的函数 .ImplyType(ImplyType::TVM); // Ascend C算子实现类型设置为TVM }
Status ParseParamByOpFunc(const ge::Operator& op_src, ge::Operator& op_dest)
开发者需要在回调函数中实现属性的解析和映射,具体实现方式如下:
ONNX原始模型中,属性为repeated message类型,对于repeated message类型的参数,可使用GetAttr(const char *name, ge::AscendString &attr_value)接口获取其属性值,然后将AscendString类型的属性值转换为String类型,再将其转换为json格式进行属性字段的解析。
实现如下所示:
Status ParseParamLeakyReluAscend(const ge::Operator& op_src, ge::Operator& op_dest) { float negative_slope = 0.01f; string negative_slope_str; AscendString attrs_string; // 使用固定属性名称“attribute”获取ONNX算子中的属性,并赋值给AscendString类型对象 if (ge::GRAPH_SUCCESS == op_src.GetAttr("attribute", attrs_string)) { // 转换为json格式 json attrs = json::parse(attrs_string.GetString()); for (json attr : attrs["attribute"]) { if (attr["name"] == "alpha" && attr["type"] == kTypeFloat) { negative_slope_str = attr["f"]; // float type in json has accuracy loss, so we use string type to store it negative_slope = atof(negative_slope_str.c_str()); } } } op_dest.SetAttr("negative_slope", negative_slope); return SUCCESS; }