mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Infer output shape for restore op.
PiperOrigin-RevId: 163762216
This commit is contained in:
parent
2e2a8536d7
commit
8b1365bb40
|
|
@ -207,6 +207,35 @@ Status GraphProperties::InferStatically() {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Infer output shape for Restore op.
|
||||
if (node->op_def().name() == "Restore") {
|
||||
// TODO(yuefengz): deal with RestoreSlice and RestoreV2 ops.
|
||||
auto ctx = shape_refiner.GetContext(node);
|
||||
int output_idx = 0;
|
||||
for (const Node* output : node->out_nodes()) {
|
||||
if (!ctx->FullyDefined(ctx->output(output_idx)) &&
|
||||
output->op_def().name() == "Assign") {
|
||||
if (!output->attrs().Find("validate_shape") ||
|
||||
!output->attrs().Find("validate_shape")->b()) {
|
||||
continue;
|
||||
}
|
||||
auto output_ctx = shape_refiner.GetContext(output);
|
||||
if (output_ctx->FullyDefined(output_ctx->output(0))) {
|
||||
ctx->set_output(output_idx, output_ctx->output(0));
|
||||
} else {
|
||||
const Node* var;
|
||||
TF_CHECK_OK(node->input_node(0, &var));
|
||||
if (node->IsVariable()) {
|
||||
auto var_ctx = shape_refiner.GetContext(var);
|
||||
CHECK(var_ctx->FullyDefined(var_ctx->output(0)));
|
||||
ctx->set_output(output_idx, var_ctx->output(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
++output_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Propagate the initial shapes of Enter nodes manually (the Enter shape
|
||||
|
|
|
|||
|
|
@ -550,6 +550,31 @@ TEST_F(GraphPropertiesTest, QueuesAndLoops) {
|
|||
EXPECT_EQ("float: [-1,4]", PropToString(prop));
|
||||
}
|
||||
|
||||
TEST_F(GraphPropertiesTest, InferRestoreOpShape) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output var = ops::Variable(s.WithOpName("var"), TensorShape({128, 256}),
|
||||
DataType::DT_FLOAT);
|
||||
Output filename =
|
||||
ops::Const(s.WithOpName("filename"), string("model"), TensorShape());
|
||||
Output tensor_name =
|
||||
ops::Const(s.WithOpName("tensorname"), string("a"), TensorShape());
|
||||
Output restore = ops::Restore(s.WithOpName("restore"), filename, tensor_name,
|
||||
DataType::DT_FLOAT);
|
||||
Output init = ops::Assign(s.WithOpName("init"), var, restore);
|
||||
|
||||
GrapplerItem item;
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
item.fetch.push_back("init");
|
||||
|
||||
GraphProperties properties(item);
|
||||
TF_CHECK_OK(properties.InferStatically());
|
||||
|
||||
const auto props = properties.GetOutputProperties("restore");
|
||||
const OpInfo::TensorProperties& prop = props[0];
|
||||
EXPECT_EQ(DT_FLOAT, prop.dtype());
|
||||
EXPECT_EQ("float: [128,256]", PropToString(prop));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user