mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-08 07:38:39 +01:00
Fix discrepancy between measured and analytical cost graph. Use tf_cuda_library for utils.
PiperOrigin-RevId: 157660745
This commit is contained in:
parent
787381ca52
commit
e5088cb823
|
|
@ -1,5 +1,7 @@
|
|||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cuda_library")
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
|
@ -20,7 +22,7 @@ config_setting(
|
|||
},
|
||||
)
|
||||
|
||||
cc_library(
|
||||
tf_cuda_library(
|
||||
name = "utils",
|
||||
srcs = ["utils.cc"],
|
||||
hdrs = [
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cuda_library")
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
|
@ -108,25 +110,21 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
tf_cuda_library(
|
||||
name = "utils",
|
||||
srcs = ["utils.cc"],
|
||||
hdrs = ["utils.h"],
|
||||
defines = if_cuda(["GOOGLE_CUDA=1"]),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":op_performance_data_cc",
|
||||
"//third_party/eigen3",
|
||||
"//tensorflow/core/grappler/clusters:utils",
|
||||
"//tensorflow/core:core_cpu_base",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
] + if_cuda([
|
||||
"//tensorflow/core:cuda",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
]),
|
||||
"//tensorflow/core/grappler/clusters:utils",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
|
|
|||
|
|
@ -167,6 +167,9 @@ std::vector<OpInfo::TensorProperties> FindInputFeatures(
|
|||
inputs.push_back(UnknownInput());
|
||||
} else {
|
||||
const CostGraphDef::Node* input_cost = it->second;
|
||||
if (input_cost->output_info_size() == 0) {
|
||||
inputs.push_back(UnknownInput());
|
||||
} else {
|
||||
const CostGraphDef::Node::OutputInfo& output =
|
||||
input_cost->output_info(output_index);
|
||||
OpInfo::TensorProperties input;
|
||||
|
|
@ -175,6 +178,7 @@ std::vector<OpInfo::TensorProperties> FindInputFeatures(
|
|||
inputs.push_back(input);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return inputs;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -56,8 +56,8 @@ void CostAnalyzer::GatherCosts() {
|
|||
CostGraphDef cost_graph_measured;
|
||||
PredictCosts(&measure_estimator_, &cost_graph_measured,
|
||||
&total_time_measured_);
|
||||
VLOG(1) << "Graph size: " << item_->graph.node_size();
|
||||
VLOG(1) << "cost_graph_measured size: " << cost_graph_measured.node_size();
|
||||
op_perf_ = CostGraphToOpPerformanceData(cost_graph_measured, item_->graph);
|
||||
|
||||
CostGraphDef cost_graph_analytical;
|
||||
PredictCosts(&analytical_estimator_, &cost_graph_analytical,
|
||||
|
|
@ -66,25 +66,32 @@ void CostAnalyzer::GatherCosts() {
|
|||
<< cost_graph_analytical.node_size();
|
||||
|
||||
CostGraphDef cost_graph_analytical_filtered;
|
||||
std::set<string> cost_nodes;
|
||||
for (auto& node : cost_graph_measured.node()) {
|
||||
cost_nodes.insert(node.name());
|
||||
CostGraphDef cost_graph_measured_filtered;
|
||||
std::map<string, const CostGraphDef_Node*> measured_nodes;
|
||||
for (const auto& node : cost_graph_measured.node()) {
|
||||
measured_nodes[node.name()] = &node;
|
||||
}
|
||||
for (const auto& node : cost_graph_analytical.node()) {
|
||||
auto it = cost_nodes.find(node.name());
|
||||
auto it = measured_nodes.find(node.name());
|
||||
// Filter the nodes that are not the cost nodes returned by
|
||||
// MeasuringCostEstimator.
|
||||
if (it == cost_nodes.end()) {
|
||||
if (it == measured_nodes.end()) {
|
||||
continue;
|
||||
}
|
||||
auto added_node = cost_graph_analytical_filtered.add_node();
|
||||
*added_node = node;
|
||||
auto added_node_analytical = cost_graph_analytical_filtered.add_node();
|
||||
auto added_node_measured = cost_graph_measured_filtered.add_node();
|
||||
*added_node_analytical = node;
|
||||
*added_node_measured = *(it->second);
|
||||
}
|
||||
VLOG(1) << "cost_graph_analytical_filtered size: "
|
||||
<< cost_graph_analytical_filtered.node_size();
|
||||
|
||||
// TODO(yaozhang): add a test to make sure that op_perf_analytical_ and
|
||||
// op_perf_ cover the same set of nodes.
|
||||
op_perf_analytical_ = CostGraphToOpPerformanceData(
|
||||
cost_graph_analytical_filtered, item_->graph);
|
||||
op_perf_ =
|
||||
CostGraphToOpPerformanceData(cost_graph_measured_filtered, item_->graph);
|
||||
}
|
||||
|
||||
void CostAnalyzer::PreprocessCosts() {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user