mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Dump the computation's SessionModule as part of the tf_compile rule.
PiperOrigin-RevId: 172946149
This commit is contained in:
parent
ebcae4a5e3
commit
8ff33271ea
|
|
@ -97,11 +97,11 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
|
||||||
TF_RETURN_IF_ERROR(ConvertGraphDefToXla(graph_def, config, client,
|
TF_RETURN_IF_ERROR(ConvertGraphDefToXla(graph_def, config, client,
|
||||||
&computation,
|
&computation,
|
||||||
&compile_result->has_context_arg));
|
&compile_result->has_context_arg));
|
||||||
if (!flags.debug_dir.empty()) {
|
if (!flags.out_session_module.empty()) {
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::SessionModule> module,
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::SessionModule> module,
|
||||||
computation.Snapshot());
|
computation.Snapshot());
|
||||||
string file = io::JoinPath(flags.debug_dir, "tfcompile_xla_module.pb");
|
TF_RETURN_IF_ERROR(
|
||||||
TF_RETURN_IF_ERROR(WriteBinaryProto(Env::Default(), file, *module));
|
WriteBinaryProto(Env::Default(), flags.out_session_module, *module));
|
||||||
}
|
}
|
||||||
xla::cpu::CpuAotCompilationOptions aot_opts(
|
xla::cpu::CpuAotCompilationOptions aot_opts(
|
||||||
flags.target_triple, flags.target_cpu, flags.target_features,
|
flags.target_triple, flags.target_cpu, flags.target_features,
|
||||||
|
|
|
||||||
|
|
@ -33,9 +33,6 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
|
||||||
"fetch nodes will be dumped to stdout in a comma-separated list. "
|
"fetch nodes will be dumped to stdout in a comma-separated list. "
|
||||||
"Typically used to format arguments for other tools, e.g. "
|
"Typically used to format arguments for other tools, e.g. "
|
||||||
"freeze_graph."},
|
"freeze_graph."},
|
||||||
{"debug_dir", &flags->debug_dir,
|
|
||||||
"Specifies a directory to dump debugging information, including "
|
|
||||||
"rewritten graphs and the XLA HLO module."},
|
|
||||||
// Flags controlling the XLA ahead-of-time compilation, that correspond to
|
// Flags controlling the XLA ahead-of-time compilation, that correspond to
|
||||||
// the fields of xla::cpu::CpuAotCompilationOptions.
|
// the fields of xla::cpu::CpuAotCompilationOptions.
|
||||||
//
|
//
|
||||||
|
|
@ -64,6 +61,8 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
|
||||||
"namespaces are given, within the global namespace."},
|
"namespaces are given, within the global namespace."},
|
||||||
{"out_object", &flags->out_object, "Output object file name."},
|
{"out_object", &flags->out_object, "Output object file name."},
|
||||||
{"out_header", &flags->out_header, "Output header file name."},
|
{"out_header", &flags->out_header, "Output header file name."},
|
||||||
|
{"out_session_module", &flags->out_session_module,
|
||||||
|
"Output session module proto."},
|
||||||
{"gen_name_to_index", &flags->gen_name_to_index,
|
{"gen_name_to_index", &flags->gen_name_to_index,
|
||||||
"Generate name-to-index data for Lookup{Arg,Result}Index methods."},
|
"Generate name-to-index data for Lookup{Arg,Result}Index methods."},
|
||||||
{"gen_program_shape", &flags->gen_program_shape,
|
{"gen_program_shape", &flags->gen_program_shape,
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,6 @@ struct MainFlags {
|
||||||
string graph;
|
string graph;
|
||||||
string config;
|
string config;
|
||||||
bool dump_fetch_nodes = false;
|
bool dump_fetch_nodes = false;
|
||||||
string debug_dir;
|
|
||||||
string target_triple;
|
string target_triple;
|
||||||
string target_cpu;
|
string target_cpu;
|
||||||
string target_features;
|
string target_features;
|
||||||
|
|
@ -37,6 +36,7 @@ struct MainFlags {
|
||||||
string cpp_class;
|
string cpp_class;
|
||||||
string out_object;
|
string out_object;
|
||||||
string out_header;
|
string out_header;
|
||||||
|
string out_session_module;
|
||||||
|
|
||||||
// C++ codegen options
|
// C++ codegen options
|
||||||
bool gen_name_to_index = false;
|
bool gen_name_to_index = false;
|
||||||
|
|
|
||||||
|
|
@ -129,6 +129,7 @@ def tf_library(name, graph, config,
|
||||||
# Rule that runs tfcompile to produce the header and object file.
|
# Rule that runs tfcompile to produce the header and object file.
|
||||||
header_file = name + ".h"
|
header_file = name + ".h"
|
||||||
object_file = name + ".o"
|
object_file = name + ".o"
|
||||||
|
session_module_pb = name + "_session_module.pb"
|
||||||
ep = ("__" + PACKAGE_NAME + "__" + name).replace("/", "_")
|
ep = ("__" + PACKAGE_NAME + "__" + name).replace("/", "_")
|
||||||
native.genrule(
|
native.genrule(
|
||||||
name=("gen_" + name),
|
name=("gen_" + name),
|
||||||
|
|
@ -139,6 +140,7 @@ def tf_library(name, graph, config,
|
||||||
outs=[
|
outs=[
|
||||||
header_file,
|
header_file,
|
||||||
object_file,
|
object_file,
|
||||||
|
session_module_pb,
|
||||||
],
|
],
|
||||||
cmd=("$(location " + tfcompile_tool + ")" +
|
cmd=("$(location " + tfcompile_tool + ")" +
|
||||||
" --graph=$(location " + tfcompile_graph + ")" +
|
" --graph=$(location " + tfcompile_graph + ")" +
|
||||||
|
|
@ -148,6 +150,7 @@ def tf_library(name, graph, config,
|
||||||
" --target_triple=" + target_llvm_triple() +
|
" --target_triple=" + target_llvm_triple() +
|
||||||
" --out_header=$(@D)/" + header_file +
|
" --out_header=$(@D)/" + header_file +
|
||||||
" --out_object=$(@D)/" + object_file +
|
" --out_object=$(@D)/" + object_file +
|
||||||
|
" --out_session_module=$(@D)/" + session_module_pb +
|
||||||
" " + (tfcompile_flags or "")),
|
" " + (tfcompile_flags or "")),
|
||||||
tools=[tfcompile_tool],
|
tools=[tfcompile_tool],
|
||||||
visibility=visibility,
|
visibility=visibility,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user