获取xlacompile.patch补丁文件
用户应用xlacompile.patch补丁并编译出xlacompile工具后,该工具可以将包含控制流的V1网络模型转换为函数式的V2网络模型。
将如下内容完整复制到文件中(须逐字保留,否则补丁无法应用),另存为xlacompile.patch,然后上传到Linux服务器的tensorflow-1.15.0源码根目录下:
--- WORKSPACE | 7 + tensorflow/compiler/aot/BUILD | 27 ++++ tensorflow/compiler/aot/xlacompile_main.cc | 170 +++++++++++++++++++++ tensorflow/compiler/tf2xla/tf2xla.cc | 6 + tensorflow/compiler/tf2xla/tf2xla.h | 4 + 5 files changed, 195 insertions(+) create mode 100644 tensorflow/compiler/aot/xlacompile_main.cc diff --git a/WORKSPACE b/WORKSPACE index 74ea14d..d2265f9 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -34,6 +34,13 @@ load( bazel_toolchains_repositories() +http_archive( + name = "io_bazel_rules_docker", + sha256 = "6287241e033d247e9da5ff705dd6ef526bac39ae82f3d17de1b69f8cb313f9cd", + strip_prefix = "rules_docker-0.14.3", + urls = ["https://github.com/bazelbuild/rules_docker/releases/download/v0.14.3/rules_docker-v0.14.3.tar.gz"], +) + load( "@io_bazel_rules_docker//repositories:repositories.bzl", container_repositories = "repositories", diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index f871115..b2620db 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -106,6 +106,33 @@ cc_library( ], ) +tf_cc_binary( + name = "xlacompile", + visibility = ["//visibility:public"], + deps = [":xlacompile_main"], +) + +cc_library( + name = "xlacompile_main", + srcs = ["xlacompile_main.cc"], + visibility = ["//visibility:public"], + deps = [ + ":tfcompile_lib", + "//tensorflow/compiler/tf2xla:tf2xla_proto", + "//tensorflow/compiler/tf2xla:tf2xla_util", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", + ], +) + # NOTE: Most end-to-end tests are in the "tests" subdirectory, to ensure that # tfcompile.bzl correctly handles usage from outside of the package that it is # defined in. 
diff --git a/tensorflow/compiler/aot/xlacompile_main.cc b/tensorflow/compiler/aot/xlacompile_main.cc new file mode 100644 index 0000000..bc795ef --- /dev/null +++ b/tensorflow/compiler/aot/xlacompile_main.cc @@ -0,0 +1,170 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <memory> +#include <string> +#include <utility> +#include <vector> +#include <map> + +#include "tensorflow/compiler/aot/flags.h" +#include "tensorflow/compiler/tf2xla/tf2xla.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/core/platform/init_main.h" + +namespace tensorflow { +namespace xlacompile { + +const char kUsageHeader[] = + "xlacompile performs ahead-of-time compilation of a TensorFlow graph,\n" + "resulting in an object file compiled for your target architecture, and a\n" + "header file that gives access to the functionality in the object file.\n" + "A typical invocation looks like this:\n" + "\n" + " $ xlacompile --graph=mygraph.pb --config=config.pbtxt --output=output.pbtxt\n" + "\n"; + +void AppendMainFlags(std::vector<Flag>* flag_list, tfcompile::MainFlags* flags) { + const std::vector<Flag> tmp = { + {"graph", &flags->graph, + "Input GraphDef file. 
If the file ends in '.pbtxt' it is expected to " + "be in the human-readable proto text format, otherwise it is expected " + "to be in the proto binary format."}, + {"config", &flags->config, + "Input file containing Config proto. If the file ends in '.pbtxt' it " + "is expected to be in the human-readable proto text format, otherwise " + "it is expected to be in the proto binary format."}, + {"output", &flags->out_session_module, + "Output session module proto. Will generate '.pb' and '.pbtxt' file."}, + }; + flag_list->insert(flag_list->end(), tmp.begin(), tmp.end()); +} + +Status ReadProtoFile(const string& fname, protobuf::Message* proto) { + if (absl::EndsWith(fname, ".pbtxt")) { + return ReadTextProto(Env::Default(), fname, proto); + } else { + return ReadBinaryProto(Env::Default(), fname, proto); + } +} + +Status Main(tfcompile::MainFlags& flags) { + // Process config. + tf2xla::Config config; + if (flags.config.empty()) { + return errors::InvalidArgument("Must specify --config"); + } + TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config)); + TF_RETURN_IF_ERROR(ValidateConfig(config)); + if (flags.dump_fetch_nodes) { + std::set<string> nodes; + for (const tf2xla::Fetch& fetch : config.fetch()) { + nodes.insert(fetch.id().node_name()); + } + std::cout << absl::StrJoin(nodes, ","); + return Status::OK(); + } + + // Read and initialize the graph. 
+ if (flags.graph.empty()) { + return errors::InvalidArgument("Must specify --graph"); + } + if (flags.out_session_module.empty()) { + return errors::InvalidArgument("Must specify --output"); + } + + string output_pb_bin = flags.out_session_module + ".pb"; + string output_pb_txt = flags.out_session_module + ".pbtxt"; + if (output_pb_bin == flags.config || output_pb_bin == flags.graph || + output_pb_txt == flags.config || output_pb_txt == flags.graph) { + return errors::InvalidArgument("Must different --config --graph --output"); + } + + GraphDef graph_def; + TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def)); + std::unique_ptr<Graph> graph; + TF_RETURN_IF_ERROR(ConvertGraphDefToXla(graph_def, config, graph)); + + std::map<string, string> arg_name_maps; + GraphDef new_graph_def; + graph->ToGraphDef(&new_graph_def); + // Delete _class attribute for: expects to be colocated with unknown node + for (int i = 0; i < new_graph_def.node_size(); ++i) { + NodeDef *node = new_graph_def.mutable_node(i); + node->mutable_attr()->erase("_class"); + if (node->op() == "_Retval") { + node->set_name(absl::StrCat(node->attr().at("index").i(), "_Retval")); + } + if (node->op() == "_Arg") { + const string name = node->name(); + node->set_name(absl::StrCat(node->attr().at("index").i(), "_Arg")); + arg_name_maps[name] = node->name(); + } + } + + for (int i = 0; i < new_graph_def.node_size() && !arg_name_maps.empty(); ++i) { + NodeDef *node = new_graph_def.mutable_node(i); + for (int j = 0; j < node->input_size(); ++j) { + auto it = arg_name_maps.find(node->input(j)); + if (it != arg_name_maps.end()) { + *node->mutable_input(j) = it->second; + } + } + } + + TF_RETURN_IF_ERROR(WriteBinaryProto(Env::Default(), output_pb_bin, new_graph_def)); + std::cerr << "Successfully convert: " << output_pb_bin << "\n"; + TF_RETURN_IF_ERROR(WriteTextProto(Env::Default(), output_pb_txt, new_graph_def)); + std::cerr << "Successfully convert: " << output_pb_txt << "\n"; + return Status::OK(); +} + +} 
// end namespace xlacompile +} // end namespace tensorflow + +int main(int argc, char** argv) { + tensorflow::tfcompile::MainFlags flags; + flags.target_triple = "x86_64-pc-linux"; + flags.out_function_object = "out_model.o"; + flags.out_metadata_object = "out_helper.o"; + flags.out_header = "out.h"; + flags.entry_point = "entry"; + + std::vector<tensorflow::Flag> flag_list; + tensorflow::xlacompile::AppendMainFlags(&flag_list, &flags); + + tensorflow::string usage = tensorflow::xlacompile::kUsageHeader; + usage += tensorflow::Flags::Usage(argv[0], flag_list); + if (argc > 1 && absl::string_view(argv[1]) == "--help") { + std::cerr << usage << "\n"; + return 0; + } + bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + QCHECK(parsed_flags_ok) << "\n" << usage; + + tensorflow::port::InitMain(usage.c_str(), &argc, &argv); + QCHECK(argc == 1) << "\nERROR: This command does not take any arguments " + "other than flags\n\n" + << usage; + tensorflow::Status status = tensorflow::xlacompile::Main(flags); + if (status.code() == tensorflow::error::INVALID_ARGUMENT) { + std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n" + << usage; + return 1; + } else { + TF_QCHECK_OK(status); + } + return 0; +} diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 3c2b256..3872776 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -410,4 +410,10 @@ Status ConvertGraphDefToXla(const GraphDef& graph_def, return Status::OK(); } +Status ConvertGraphDefToXla(const GraphDef &graph_def, + const tf2xla::Config &config, + std::unique_ptr<Graph> &graph) { + return InitGraph(graph_def, config, &graph); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla.h b/tensorflow/compiler/tf2xla/tf2xla.h index 432a12a..969500c 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.h +++ b/tensorflow/compiler/tf2xla/tf2xla.h @@ -20,6 +20,7 @@ limitations under the 
License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/graph.h" namespace tensorflow { @@ -34,6 +35,9 @@ Status ConvertGraphDefToXla(const GraphDef& graph_def, const tf2xla::Config& config, xla::Client* client, xla::XlaComputation* computation); +Status ConvertGraphDefToXla(const GraphDef &graph_def, + const tf2xla::Config &config, + std::unique_ptr<Graph> &graph); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_ -- 1.8.3.1
父主题: FAQ