mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
[Grappler] Move InferOutputShapes to GraphProperties.
So it can be used by other optimizers. No functional changes. PiperOrigin-RevId: 171010106
This commit is contained in:
parent
2114fd51e9
commit
7db7a890c0
|
|
@ -455,6 +455,20 @@ Status GraphProperties::InferDynamically(Cluster* cluster) {
|
|||
return InferFromCostGraph(metadata.cost_graph());
|
||||
}
|
||||
|
||||
// Copies `item_.graph` into `*output_graph_def` and sets an "_output_shapes"
// attribute on every node of the copy. The attribute's shape list receives one
// entry per element returned by GetOutputProperties() for that node, in the
// same order. The attribute is assigned unconditionally, so any existing
// "_output_shapes" value on a node is overwritten, and a node with no output
// properties still gets the attribute with no shapes added.
//
// Always returns Status::OK().
Status GraphProperties::AnnotateOutputShapes(GraphDef* output_graph_def) {
  *output_graph_def = item_.graph;
  for (int i = 0; i < output_graph_def->node_size(); i++) {
    auto* node = output_graph_def->mutable_node(i);
    AttrValue attr_output_shape;
    // Bind by const reference instead of `auto`: the properties are only read
    // here, so there is no reason to copy the whole vector for every node.
    // (If GetOutputProperties() returns by value, the temporary's lifetime is
    // extended by the reference, so this is safe either way.)
    const auto& tensor_properties = GetOutputProperties(node->name());
    for (const auto& tensor_property : tensor_properties) {
      *attr_output_shape.mutable_list()->add_shape() = tensor_property.shape();
    }
    (*node->mutable_attr())["_output_shapes"] = attr_output_shape;
  }
  return Status::OK();
}
|
||||
|
||||
Status GraphProperties::InferFromCostGraph(const CostGraphDef& cost_graph) {
|
||||
std::unordered_map<string, const CostGraphDef::Node*> name_to_cost;
|
||||
std::unordered_map<string, const NodeDef*> name_to_node; // Empty
|
||||
|
|
|
|||
|
|
@ -39,6 +39,9 @@ class GraphProperties {
|
|||
Status InferDynamically(Cluster* cluster);
|
||||
Status InferFromCostGraph(const CostGraphDef& cost_graph);
|
||||
|
||||
// Copies `item_.graph` into `output_graph_def` and annotates every node of the
// copy with an `_output_shapes` attribute holding its inferred output shapes.
|
||||
Status AnnotateOutputShapes(GraphDef* output_graph_def);
|
||||
|
||||
bool HasInputProperties(const string& name) const;
|
||||
bool HasOutputProperties(const string& name) const;
|
||||
const std::vector<OpInfo::TensorProperties>& GetInputProperties(
|
||||
|
|
|
|||
|
|
@ -1385,21 +1385,6 @@ int GetNumTranspose(const GraphDef& graph) {
|
|||
return number;
|
||||
}
|
||||
|
||||
// Runs static shape inference over `*item` and records the result on the
// item's own graph: every node in `item->graph` gets an "_output_shapes"
// attribute whose shape list has one entry per element returned by
// GetOutputProperties() for that node. Any existing "_output_shapes" value is
// overwritten.
//
// Returns the error from InferStatically() if inference fails, in which case
// no node is modified; otherwise returns Status::OK().
Status LayoutOptimizer::InferOutputShapes(GrapplerItem* item) {
  GraphProperties graph_properties(*item);
  TF_RETURN_IF_ERROR(graph_properties.InferStatically());
  for (int i = 0; i < item->graph.node_size(); i++) {
    auto* node = item->graph.mutable_node(i);
    AttrValue attr_output_shape;
    // Bind by const reference instead of `auto`: the properties are only read
    // here, so there is no reason to copy the whole vector for every node.
    // (If GetOutputProperties() returns by value, the temporary's lifetime is
    // extended by the reference, so this is safe either way.)
    const auto& tensor_properties =
        graph_properties.GetOutputProperties(node->name());
    for (const auto& tensor_property : tensor_properties) {
      *attr_output_shape.mutable_list()->add_shape() = tensor_property.shape();
    }
    (*node->mutable_attr())["_output_shapes"] = attr_output_shape;
  }
  return Status::OK();
}
|
||||
|
||||
Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* output) {
|
||||
if (num_gpus_ == 0) {
|
||||
|
|
@ -1411,14 +1396,18 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
GrapplerItem new_item = item;
|
||||
auto status = InferOutputShapes(&new_item);
|
||||
GraphProperties graph_properties(item);
|
||||
auto status = graph_properties.InferStatically();
|
||||
if (!status.ok()) {
|
||||
*output = item.graph;
|
||||
return status;
|
||||
}
|
||||
status = graph_properties.AnnotateOutputShapes(output);
|
||||
if (!status.ok()) {
|
||||
*output = item.graph;
|
||||
return status;
|
||||
}
|
||||
|
||||
*output = new_item.graph;
|
||||
TuningConfig config;
|
||||
config.no_gemm = false;
|
||||
string default_device = "/job:localhost/replica:0/task:0/cpu:0";
|
||||
|
|
@ -1435,7 +1424,6 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||
// nodes is more than 30, not using GEMM implementation would result in better
|
||||
// performance.
|
||||
if (status.ok() && GetNumTranspose(*output) > 30) {
|
||||
*output = new_item.graph;
|
||||
config.no_gemm = true;
|
||||
node_map.reset(new NodeMap(output));
|
||||
layout_optimizer.reset(new DataLayoutOptimizer(default_device, output,
|
||||
|
|
|
|||
|
|
@ -39,7 +39,6 @@ class LayoutOptimizer : public GraphOptimizer {
|
|||
const GraphDef& optimize_output, double result) override;
|
||||
|
||||
private:
|
||||
Status InferOutputShapes(GrapplerItem* item);
|
||||
int num_gpus_ = 0;
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user