修改Graph
如果用户想要直接优化图的结构,比如将某些特定子图替换成一个大算子,以减少计算步骤、外存访问、调度时间等,或者在某些算子中间添加一个算子,此时可以通过本节内容直接将图直接修改成期望的结构。
功能介绍
以在算子A和算子B之间添加算子C为例,修改Graph,涉及的主要接口为:
开发示例
- 包含的头文件。
1 2 3 4
#include "graph.h" #include "ascend_string.h" #include "ge_ir_build.h" #include "gnode.h"
- (可选步骤)修改图之前,可以先调用aclgrphDumpGraph把Graph dump到本地,查看Graph信息。
需要注意的是,aclgrphDumpGraph接口必须在SetInputs接口和SetOutputs接口之后调用,例如:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
string op_name = "tc_ge_openpass_0001"; Graph graph(op_name); auto data = op::Data("data").set_attr_index(0); TensorDesc data_desc2(ge::Shape({3, 3, 3, 3}), FORMAT_NHWC, DT_FLOAT); data.update_input_desc_x(data_desc2); data.update_output_desc_y(data_desc2); auto matrixinverse = op::MatrixInverse("MatrixInverse").set_input_x(data); auto square1 = op::Square("square1").set_input_x(matrixinverse); std::vector<Operator> inputs{data}; std::vector<Operator> outputs{data,square1}; graph.SetInputs(inputs).SetOutputs(outputs); std::map<std::string, std::string> init_options = { {ge::ir_option::SOC_VERSION,"xxx"} }; auto ret = aclgrphBuildInitialize(init_options); EXPECT_EQ(ret, GRAPH_SUCCESS); std::cout << "BuildInitialize before infershape Success." << std::endl; size_t filesize =24; const char* file = "tc_ge_openpass_0001_dump"; ret = ge::aclgrphDumpGraph(graph,file,filesize); if(ret != GRAPH_SUCCESS) { std::cout<<"dump graph faied."<<std::endl; } ACL_LOG("AclgrphDumpGraph,size[%d]",filesize); ret = AclgrphBuildModel(graph,op_name); if(ret != GRAPH_SUCCESS) { std::cout<<"AclgrphBuildModel faied."<<std::endl; }
- 在算子A和算子B之间增加算子C,比如在const和add算子之间插入abs。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
GNode src_node; GNode dst_node; std::vector<GNode> nodes = graph.GetAllNodes(); for(auto &node : nodes) { ge::AscendString name; node.GetName(name); std::string node_name(name.GetString()); if(node_name == CONST) { src_node = node;} else if(node_name == ADD) { dst_node = node;} } graph.RemoveEdge(src_node, 0, dst_node, 0); auto abs = op::Abs("input3_abs"); GNode node_abs = graph.AddNodeByOp(abs); TensorDesc output_tensor_desc; src_node.GetOutputDesc(0, output_tensor_desc); abs.UpdateInputDesc(0, output_tensor_desc); abs.UpdateOutputDesc(0, output_tensor_desc); graph.AddDataEdge(src_node, 0, node_abs, 0); graph.AddDataEdge(node_abs, 0, dst_node, 0);
- 调用GetAllNodes找到const算子和add算子。
- 调用RemoveEdge删除const算子和add算子的连边(数据边或控制边)。
- 参考使用算子原型衍生接口定义算子,创建Operator类算子abs(也可以调用OperatorFactory::CreateOperator创建算子)。
- 调用AddNodeByOp创建GNode类算子abs。
创建完算子后,可以根据需要更新该算子的input和output TensorDesc,一般根据源节点的Output TensorDesc更新算子abs的Input TensorDesc和Output TensorDesc。如果不更新,系统会设置默认值,在模型编译时对Tensor Shape,type进行推导。
- 调用AddDataEdge添加const算子和abs算子,abs算子和add算子之间的连边。如果有控制边,再调用AddControlEdge添加控制边。
如果在A与B插入多个算子,比如,A->C->D->B,参考以上步骤,分别执行操作A->C, C->D, D->B。
- 删除算子A和算子B之间的C算子,比如删除算子const和add之间的abs。
1 2
graph.RemoveNode(node_abs); graph.AddDataEdge(src_node, 0, dst_node, 0);
- 此外,如果需要查询GNode的信息,可以参考GNode类提供的方法。