适配插件开发(Caffe框架)
简介
您可以参考本章节进行算子适配插件的开发,将基于第三方框架的算子映射成适配昇腾AI处理器的算子,将算子信息注册到Graph Engine(简称:GE)中。基于Caffe框架的网络运行时,首先会加载并调用GE中的插件信息,将原始框架网络中的算子进行解析并映射成适配昇腾AI处理器中的算子。
下文我们将适配昇腾AI处理器的算子称为CANN算子。
Caffe框架算子到CANN算子的一对多、多对多、多对一映射暂不支持。
原理介绍
算子插件的实现包含CANN算子类型的注册、原始框架中算子类型的注册以及原始框架中算子属性到CANN算子属性的映射,算子的映射通过Parser模块完成。插件在整网络运行场景下的实现流程如图1所示。
- 首先GE接收到第三方框架的原始网络模型,并进行初始化,网络模型的拓扑图我们简称为图。
- GE从Register注册模块中加载算子插件生成的.so文件,在CANN软件安装后文件存储路径的“opp/built-in/framework/”路径中。
- 读取算子插件.so中的算子相关信息,并将其注册到算子插件的map文件中(所有算子插件的相关信息都会以map的形式存储到一个文件中)。
- GE向Parser模块发送调用Parser方法的请求。
- Parser模块根据算子类型(OpType)从算子插件的map文件中取出对应的Parser函数,并返回实现函数ParseParamsByOperatorFn给Parser模块,Parser模块根据实现函数将第三方网络算子中的属性映射到CANN算子的属性,即算子原型中的属性定义,从而完成第三方网络中算子到CANN算子的映射。
- 后续会进行图准备、图拆分及图优化等一系列操作,最终生成适配昇腾AI处理器的网络模型。
插件实现
GE提供REGISTER_CUSTOM_OP宏,按照指定的算子名称完成算子的注册。
#include "register/register.h" #include "graph/operator.h" namespace domi { REGISTER_CUSTOM_OP("OpType") .FrameworkType(CAFFE) .OriginOpType("OriginOpType") .ParseParamsByOperatorFn(ParseParamByOpFunc) // 用来注册解析算子属性的函数 .ImplyType(ImplyType::TVM); // TBE算子:ImplyType::TVM;AI CPU算子:ImplyType::AI_CPU }
- 在代码实现文件顶部使用预编译命令“#include”将插件实现函数相关的头文件包含到插件实现文件中。
register.h存储在CANN软件安装后文件存储路径的“include/register/”目录下,包含该头文件,可使用算子注册相关类,调用算子注册相关的接口。
operator.h(可选),存储在CANN软件安装后文件存储路径的“include/graph/”目录下,包含该头文件,可以使用Operator类相关接口,获取算子输入输出及属性等算子信息。
- REGISTER_CUSTOM_OP:注册自定义算子,OpType为注册到GE中的算子类型,需要与算子原型注册中的OpType保持一致。
- FrameworkType:CAFFE代表原始框架为Caffe。
- OriginOpType:算子在原始框架中的类型。
- ParseParamsByOperatorFn(ParseParamByOpFunc):用来注册解析算子属性的函数,需要用户自定义实现回调函数ParseParamByOpFunc。
回调函数ParseParamByOpFunc的声明如下所示:
Status ParseParamByOpFunc(const ge::Operator& op_src, ge::Operator& op_dest)
- ParseParamByOpFunc:函数名称,用户自定义,需要保持唯一。
- op_src:Caffe框架定义的Operator类对象,包含Caffe模型中自定义的算子属性信息,定义来源于Caffe模型的proto文件,若用户自定义的算子在caffe.proto文件中未定义,则需要参考配套版本开源Sample中“cplusplus\level1_single_api\4_op_dev\1_custom_op”目录下README的“算子工程编译”章节增加算子定义,GetAttr接口会读取此文件中算子定义进行解析。
- op_dest:CANN算子数据结构,保存算子信息,Operator类的详细描述请参见Operator类。
ParseParamByOpFunc函数的实现如下所示:
用户首先需要调用Operator类的GetAttr接口获取op_src对象的属性值,然后调用SetAttr接口将获取到的属性值赋给op_dest对象。
GetAttr接口支持不同类型的属性值,例如:
- 对于caffe.proto定义中的int64类型参数,可使用GetAttr(const char *name, int64_t &attr_value)接口获取其属性值。
- 对于caffe.proto定义中的enum类型参数,可使用GetAttr(const char *name, int32_t &attr_value)接口获取其属性值。
- 对于caffe.proto定义中的repeated float类型参数,可使用GetAttr(const char *name, std::vector<float> &attr_value)接口获取其属性值。
- 对于caffe.proto定义中的repeated message类型参数,可使用GetAttr(const char *name, ge::AscendString &attr_value)接口获取其属性值,由于repeated message类型的attr_value在图引擎处理过程中是按照json格式存储的,所以需要将AscendString类型的属性值转换为String类型,再将String类型转换为json格式进行解析。
例如:caffe.proto中定义如下所示:
message BiasParameter { repeated BiasStruct bias_struct = 1; } message BiasStruct { optional uint32 offset = 1; repeated uint32 width = 2; };
使用GetAttr("bias_struct", attr_value)获取attr_value(AscendString类型), 再将其转化为attr_value(string类型) , 最后再将string类型转化为json格式去访问其中的字段。
假设prototxt中算子参数如下:
layer { name: "bias" top: "out" bottom: "data" type: "Bias" bias_param { bias_struct { offset : 2 width: 8 width: 10 } bias_struct { offset : 1 width: 20 } } }
则转换为json格式后的数据如下:
{ "bias_struct": [ { "offset": 2, "width": [ 8, 10 ] } { "offset": 1, "width": [ 20 ] } ] }
开发者可自由选择将string转为json的方式,下面以使用json.hpp为例进行说明,开发者可将json.hpp放在工程可以找到的任意路径下,使用时包含头文件即可。
ParseParamByOpFunc函数中访问json字段代码如下:
if (ge::GRAPH_SUCCESS == op_src.GetAttr("bias_struct", bias_struct_val)) { std::string bias_struct_str = bias_struct_val.GetString(); // convert to json bias_struc_json = nlohmann::json::parse(bias_struct); for (int i = 0; i < bias_struc_json["bias_struct"].size(); i++) { nlohmann::json bias_value = bias_struc_json["bias_struct"][i]; uint32_t value = bias_value["offset"].get<uint32_t>(); for (int idx = 0; idx < bias_value["width"].size(); idx++) { uint32_t value = bias_value["width"][idx].get<uint32_t>(); } } }
- 当前版本GetAttr与SetAttr接口不支持对custom.proto文件中数据类型为double和uint64的字段进行解析。
- 使用ATC工具执行模型转换时,对属性的获取情况不会进行强校验。所以进行算子适配插件实现时,若用户调用GetAttr失败,建议根据算子实际情况增加相应的处理逻辑,例如,针对必选属性,可返回失败,针对可选属性,可设置默认值。
- 不支持Convolution和InnerProduct算子的权重解析,所以当前版本不支持用户重定义Convolution和InnerProduct算子。
- ImplyType:指定算子的实现方式。ImplyType::TVM表示该算子是TBE算子;ImplyType::AI_CPU表示该算子是AI CPU算子。