Infer output shape for restore op.

PiperOrigin-RevId: 163762216
This commit is contained in:
Yuefeng Zhou 2017-07-31 16:02:28 -07:00 committed by TensorFlower Gardener
parent 2e2a8536d7
commit 8b1365bb40
2 changed files with 54 additions and 0 deletions

View File

@ -207,6 +207,35 @@ Status GraphProperties::InferStatically() {
}
}
}
// Infer output shape for Restore op.
if (node->op_def().name() == "Restore") {
// TODO(yuefengz): deal with RestoreSlice and RestoreV2 ops.
auto ctx = shape_refiner.GetContext(node);
int output_idx = 0;
for (const Node* output : node->out_nodes()) {
if (!ctx->FullyDefined(ctx->output(output_idx)) &&
output->op_def().name() == "Assign") {
if (!output->attrs().Find("validate_shape") ||
!output->attrs().Find("validate_shape")->b()) {
continue;
}
auto output_ctx = shape_refiner.GetContext(output);
if (output_ctx->FullyDefined(output_ctx->output(0))) {
ctx->set_output(output_idx, output_ctx->output(0));
} else {
const Node* var;
TF_CHECK_OK(node->input_node(0, &var));
if (node->IsVariable()) {
auto var_ctx = shape_refiner.GetContext(var);
CHECK(var_ctx->FullyDefined(var_ctx->output(0)));
ctx->set_output(output_idx, var_ctx->output(0));
}
}
}
++output_idx;
}
}
}
// Propagate the initial shapes of Enter nodes manually (the Enter shape

View File

@ -550,6 +550,31 @@ TEST_F(GraphPropertiesTest, QueuesAndLoops) {
EXPECT_EQ("float: [-1,4]", PropToString(prop));
}
TEST_F(GraphPropertiesTest, InferRestoreOpShape) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output var = ops::Variable(s.WithOpName("var"), TensorShape({128, 256}),
DataType::DT_FLOAT);
Output filename =
ops::Const(s.WithOpName("filename"), string("model"), TensorShape());
Output tensor_name =
ops::Const(s.WithOpName("tensorname"), string("a"), TensorShape());
Output restore = ops::Restore(s.WithOpName("restore"), filename, tensor_name,
DataType::DT_FLOAT);
Output init = ops::Assign(s.WithOpName("init"), var, restore);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
item.fetch.push_back("init");
GraphProperties properties(item);
TF_CHECK_OK(properties.InferStatically());
const auto props = properties.GetOutputProperties("restore");
const OpInfo::TensorProperties& prop = props[0];
EXPECT_EQ(DT_FLOAT, prop.dtype());
EXPECT_EQ("float: [128,256]", PropToString(prop));
}
} // namespace
} // namespace grappler
} // namespace tensorflow