[Grappler] Move InferOutputShapes to GraphProperties.

So it can be used by other optimizers. No functional changes.

PiperOrigin-RevId: 171010106
This commit is contained in:
Jingyue Wu 2017-10-04 08:04:48 -07:00 committed by TensorFlower Gardener
parent 2114fd51e9
commit 7db7a890c0
4 changed files with 24 additions and 20 deletions

View File

@ -455,6 +455,20 @@ Status GraphProperties::InferDynamically(Cluster* cluster) {
return InferFromCostGraph(metadata.cost_graph());
}
// Writes a copy of item_.graph to `output_graph_def`, stamping every node
// with an "_output_shapes" attribute that lists the inferred shape of each
// of the node's outputs. Callers must run one of the Infer* methods first
// so that GetOutputProperties has data to report.
Status GraphProperties::AnnotateOutputShapes(GraphDef* output_graph_def) {
// Start from the original graph so all nodes and existing attrs are kept.
*output_graph_def = item_.graph;
for (int i = 0; i < output_graph_def->node_size(); i++) {
auto node = output_graph_def->mutable_node(i);
AttrValue attr_output_shape;
// NOTE(review): presumably GetOutputProperties returns an empty list for
// nodes with no inferred shapes, leaving an empty "_output_shapes" attr
// on such nodes — confirm against GraphProperties' accessor contract.
auto tensor_properties = GetOutputProperties(node->name());
for (const auto& tensor_property : tensor_properties) {
*attr_output_shape.mutable_list()->add_shape() = tensor_property.shape();
}
// Overwrites any pre-existing "_output_shapes" attribute on the node.
(*node->mutable_attr())["_output_shapes"] = attr_output_shape;
}
return Status::OK();
}
Status GraphProperties::InferFromCostGraph(const CostGraphDef& cost_graph) {
std::unordered_map<string, const CostGraphDef::Node*> name_to_cost;
std::unordered_map<string, const NodeDef*> name_to_node; // Empty

View File

@ -39,6 +39,9 @@ class GraphProperties {
Status InferDynamically(Cluster* cluster);
Status InferFromCostGraph(const CostGraphDef& cost_graph);
// Writes `item_.graph`, annotated with the inferred output shapes, to `output_graph_def`.
Status AnnotateOutputShapes(GraphDef* output_graph_def);
bool HasInputProperties(const string& name) const;
bool HasOutputProperties(const string& name) const;
const std::vector<OpInfo::TensorProperties>& GetInputProperties(

View File

@ -1385,21 +1385,6 @@ int GetNumTranspose(const GraphDef& graph) {
return number;
}
// Runs static shape inference over `item`'s graph and records the result on
// each node as an "_output_shapes" attribute, mutating `item` in place.
// (This is the code the commit removes: the same logic now lives in
// GraphProperties::AnnotateOutputShapes so other optimizers can reuse it.)
Status LayoutOptimizer::InferOutputShapes(GrapplerItem* item) {
GraphProperties graph_properties(*item);
// Static inference only: no cluster/cost-graph information is consulted.
TF_RETURN_IF_ERROR(graph_properties.InferStatically());
for (int i = 0; i < item->graph.node_size(); i++) {
auto node = item->graph.mutable_node(i);
AttrValue attr_output_shape;
auto tensor_properties = graph_properties.GetOutputProperties(node->name());
for (const auto& tensor_property : tensor_properties) {
*attr_output_shape.mutable_list()->add_shape() = tensor_property.shape();
}
// Overwrites any pre-existing "_output_shapes" attribute on the node.
(*node->mutable_attr())["_output_shapes"] = attr_output_shape;
}
return Status::OK();
}
Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) {
if (num_gpus_ == 0) {
@ -1411,14 +1396,18 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
return Status::OK();
}
GrapplerItem new_item = item;
auto status = InferOutputShapes(&new_item);
GraphProperties graph_properties(item);
auto status = graph_properties.InferStatically();
if (!status.ok()) {
*output = item.graph;
return status;
}
status = graph_properties.AnnotateOutputShapes(output);
if (!status.ok()) {
*output = item.graph;
return status;
}
*output = new_item.graph;
TuningConfig config;
config.no_gemm = false;
string default_device = "/job:localhost/replica:0/task:0/cpu:0";
@ -1435,7 +1424,6 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// nodes is more than 30, not using GEMM implementation would result in better
// performance.
if (status.ok() && GetNumTranspose(*output) > 30) {
*output = new_item.graph;
config.no_gemm = true;
node_map.reset(new NodeMap(output));
layout_optimizer.reset(new DataLayoutOptimizer(default_device, output,

View File

@ -39,7 +39,6 @@ class LayoutOptimizer : public GraphOptimizer {
const GraphDef& optimize_output, double result) override;
private:
Status InferOutputShapes(GrapplerItem* item);
int num_gpus_ = 0;
};