下载
中文
注册

获取xlacompile.patch补丁文件

用户安装完xlacompile.patch补丁,编译成xlacompile工具后,该工具可以将有控制流的V1网络模型转成函数类的V2网络模型。

将如下代码复制到文件中,并另存为xlacompile.patch,然后上传到Linux服务器tensorflow-1.15.0路径下:

---
 WORKSPACE                                      |   7 +
 tensorflow/compiler/aot/BUILD                  |  27 ++++
 tensorflow/compiler/aot/xlacompile_main.cc     | 170 +++++++++++++++++++++
 tensorflow/compiler/tf2xla/tf2xla.cc           |   6 +
 tensorflow/compiler/tf2xla/tf2xla.h            |   4 +
 5 files changed, 195 insertions(+)
 create mode 100644 tensorflow/compiler/aot/xlacompile_main.cc

diff --git a/WORKSPACE b/WORKSPACE
index 74ea14d..d2265f9 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -34,6 +34,13 @@ load(

 bazel_toolchains_repositories()

+http_archive(
+    name = "io_bazel_rules_docker",
+    sha256 = "6287241e033d247e9da5ff705dd6ef526bac39ae82f3d17de1b69f8cb313f9cd",
+    strip_prefix = "rules_docker-0.14.3",
+    urls = ["https://github.com/bazelbuild/rules_docker/releases/download/v0.14.3/rules_docker-v0.14.3.tar.gz"],
+)
+
 load(
     "@io_bazel_rules_docker//repositories:repositories.bzl",
     container_repositories = "repositories",

diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
index f871115..b2620db 100644
--- a/tensorflow/compiler/aot/BUILD
+++ b/tensorflow/compiler/aot/BUILD
@@ -106,6 +106,33 @@ cc_library(
     ],
 )

+tf_cc_binary(
+    name = "xlacompile",
+    visibility = ["//visibility:public"],
+    deps = [":xlacompile_main"],
+)
+
+cc_library(
+    name = "xlacompile_main",
+    srcs = ["xlacompile_main.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":tfcompile_lib",
+        "//tensorflow/compiler/tf2xla:tf2xla_proto",
+        "//tensorflow/compiler/tf2xla:tf2xla_util",
+        "//tensorflow/compiler/xla:debug_options_flags",
+        "//tensorflow/compiler/xla/service:compiler",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:graph",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
 # NOTE: Most end-to-end tests are in the "tests" subdirectory, to ensure that
 # tfcompile.bzl correctly handles usage from outside of the package that it is
 # defined in.

diff --git a/tensorflow/compiler/aot/xlacompile_main.cc b/tensorflow/compiler/aot/xlacompile_main.cc
new file mode 100644
index 0000000..bc795ef
--- /dev/null
+++ b/tensorflow/compiler/aot/xlacompile_main.cc
@@ -0,0 +1,170 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include <map>
+
+#include "tensorflow/compiler/aot/flags.h"
+#include "tensorflow/compiler/tf2xla/tf2xla.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
+#include "tensorflow/core/platform/init_main.h"
+
+namespace tensorflow {
+namespace xlacompile {
+
+const char kUsageHeader[] =
+    "xlacompile performs ahead-of-time compilation of a TensorFlow graph,\n"
+    "resulting in an object file compiled for your target architecture, and a\n"
+    "header file that gives access to the functionality in the object file.\n"
+    "A typical invocation looks like this:\n"
+    "\n"
+    "   $ xlacompile --graph=mygraph.pb --config=config.pbtxt --output=output.pbtxt\n"
+    "\n";
+
+void AppendMainFlags(std::vector<Flag>* flag_list, tfcompile::MainFlags* flags) {
+  const std::vector<Flag> tmp = {
+      {"graph", &flags->graph,
+       "Input GraphDef file. If the file ends in '.pbtxt' it is expected to "
+       "be in the human-readable proto text format, otherwise it is expected "
+       "to be in the proto binary format."},
+      {"config", &flags->config,
+       "Input file containing Config proto. If the file ends in '.pbtxt' it "
+       "is expected to be in the human-readable proto text format, otherwise "
+       "it is expected to be in the proto binary format."},
+      {"output", &flags->out_session_module,
+       "Output session module proto. Will generate '.pb' and '.pbtxt' file."},
+  };
+  flag_list->insert(flag_list->end(), tmp.begin(), tmp.end());
+}
+
+Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
+  if (absl::EndsWith(fname, ".pbtxt")) {
+    return ReadTextProto(Env::Default(), fname, proto);
+  } else {
+    return ReadBinaryProto(Env::Default(), fname, proto);
+  }
+}
+
+Status Main(tfcompile::MainFlags& flags) {
+  // Process config.
+  tf2xla::Config config;
+  if (flags.config.empty()) {
+    return errors::InvalidArgument("Must specify --config");
+  }
+  TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
+  TF_RETURN_IF_ERROR(ValidateConfig(config));
+  if (flags.dump_fetch_nodes) {
+    std::set<string> nodes;
+    for (const tf2xla::Fetch& fetch : config.fetch()) {
+      nodes.insert(fetch.id().node_name());
+    }
+    std::cout << absl::StrJoin(nodes, ",");
+    return Status::OK();
+  }
+
+  // Read and initialize the graph.
+  if (flags.graph.empty()) {
+    return errors::InvalidArgument("Must specify --graph");
+  }
+  if (flags.out_session_module.empty()) {
+    return errors::InvalidArgument("Must specify --output");
+  }
+
+  string output_pb_bin = flags.out_session_module + ".pb";
+  string output_pb_txt = flags.out_session_module + ".pbtxt";
+  if (output_pb_bin == flags.config || output_pb_bin == flags.graph ||
+      output_pb_txt == flags.config || output_pb_txt == flags.graph) {
+    return errors::InvalidArgument("Must different --config --graph --output");
+  }
+
+  GraphDef graph_def;
+  TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
+  std::unique_ptr<Graph> graph;
+  TF_RETURN_IF_ERROR(ConvertGraphDefToXla(graph_def, config, graph));
+
+  std::map<string, string> arg_name_maps;
+  GraphDef new_graph_def;
+  graph->ToGraphDef(&new_graph_def);
+  // Delete _class attribute for: expects to be colocated with unknown node
+  for (int i = 0; i < new_graph_def.node_size(); ++i) {
+    NodeDef *node = new_graph_def.mutable_node(i);
+    node->mutable_attr()->erase("_class");
+    if (node->op() == "_Retval") {
+      node->set_name(absl::StrCat(node->attr().at("index").i(), "_Retval"));
+    }
+    if (node->op() == "_Arg") {
+      const string name = node->name();
+      node->set_name(absl::StrCat(node->attr().at("index").i(), "_Arg"));
+      arg_name_maps[name] = node->name();
+    }
+  }
+
+  for (int i = 0; i < new_graph_def.node_size() && !arg_name_maps.empty(); ++i) {
+    NodeDef *node = new_graph_def.mutable_node(i);
+    for (int j = 0; j < node->input_size(); ++j) {
+      auto it = arg_name_maps.find(node->input(j));
+      if (it != arg_name_maps.end()) {
+        *node->mutable_input(j) = it->second;
+      }
+    }
+  }
+
+  TF_RETURN_IF_ERROR(WriteBinaryProto(Env::Default(), output_pb_bin, new_graph_def));
+  std::cerr << "Successfully convert: " << output_pb_bin << "\n";
+  TF_RETURN_IF_ERROR(WriteTextProto(Env::Default(), output_pb_txt, new_graph_def));
+  std::cerr << "Successfully convert: " << output_pb_txt << "\n";
+  return Status::OK();
+}
+
+}  // end namespace xlacompile
+}  // end namespace tensorflow
+
+int main(int argc, char** argv) {
+  tensorflow::tfcompile::MainFlags flags;
+  flags.target_triple = "x86_64-pc-linux";
+  flags.out_function_object = "out_model.o";
+  flags.out_metadata_object = "out_helper.o";
+  flags.out_header = "out.h";
+  flags.entry_point = "entry";
+
+  std::vector<tensorflow::Flag> flag_list;
+  tensorflow::xlacompile::AppendMainFlags(&flag_list, &flags);
+
+  tensorflow::string usage = tensorflow::xlacompile::kUsageHeader;
+  usage += tensorflow::Flags::Usage(argv[0], flag_list);
+  if (argc > 1 && absl::string_view(argv[1]) == "--help") {
+    std::cerr << usage << "\n";
+    return 0;
+  }
+  bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
+  QCHECK(parsed_flags_ok) << "\n" << usage;
+
+  tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
+  QCHECK(argc == 1) << "\nERROR: This command does not take any arguments "
+                       "other than flags\n\n"
+                    << usage;
+  tensorflow::Status status = tensorflow::xlacompile::Main(flags);
+  if (status.code() == tensorflow::error::INVALID_ARGUMENT) {
+    std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n"
+              << usage;
+    return 1;
+  } else {
+    TF_QCHECK_OK(status);
+  }
+  return 0;
+}

diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc
index 3c2b256..3872776 100644
--- a/tensorflow/compiler/tf2xla/tf2xla.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla.cc
@@ -410,4 +410,10 @@ Status ConvertGraphDefToXla(const GraphDef& graph_def,
   return Status::OK();
 }

+Status ConvertGraphDefToXla(const GraphDef &graph_def,
+                            const tf2xla::Config &config,
+                            std::unique_ptr<Graph> &graph) {
+  return InitGraph(graph_def, config, &graph);
+}
+
 }  // namespace tensorflow

diff --git a/tensorflow/compiler/tf2xla/tf2xla.h b/tensorflow/compiler/tf2xla/tf2xla.h
index 432a12a..969500c 100644
--- a/tensorflow/compiler/tf2xla/tf2xla.h
+++ b/tensorflow/compiler/tf2xla/tf2xla.h
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/client.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/graph/graph.h"

 namespace tensorflow {

@@ -34,6 +35,9 @@ Status ConvertGraphDefToXla(const GraphDef& graph_def,
                             const tf2xla::Config& config, xla::Client* client,
                             xla::XlaComputation* computation);

+Status ConvertGraphDefToXla(const GraphDef &graph_def,
+                            const tf2xla::Config &config,
+                            std::unique_ptr<Graph> &graph);
 }  // namespace tensorflow

 #endif  // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_

--
1.8.3.1