Fix the shape information propagation for Enter op.

PiperOrigin-RevId: 165653579
This commit is contained in:
A. Unique TensorFlower 2017-08-17 17:44:32 -07:00 committed by TensorFlower Gardener
parent 641943fd71
commit 465c408196
2 changed files with 20 additions and 0 deletions

View File

@ -204,6 +204,13 @@ REGISTER_OP("Enter")
auto* handle_data = c->input_handle_shapes_and_types(0);
if (handle_data != nullptr) {
c->set_output_handle_shapes_and_types(0, *handle_data);
} else {
// Otherwise, propagate shape if output is a constant.
bool is_constant;
TF_RETURN_IF_ERROR(c->GetAttr("is_constant", &is_constant));
if (is_constant) {
c->set_output(0, c->input(0));
}
}
return Status::OK();

View File

@ -179,6 +179,19 @@ class ControlFlowTest(test.TestCase):
result = exit_op.eval()
self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result)
def testEnterShapePropagation(self):
with self.test_session():
v = variables.Variable([0.0, 0.0], dtype=dtypes.float32)
# If is_constant=True, the shape information should be propagated.
enter_v_constant = control_flow_ops.enter(v, "frame1", is_constant=True)
self.assertEqual(enter_v_constant.shape, [2])
# Otherwise, the shape should be unknown.
enter_v_non_constant = control_flow_ops.enter(v, "frame2",
is_constant=False)
self.assertEqual(enter_v_non_constant.shape, None)
def testSwitchMergeIndexedSlices(self):
with self.test_session():
values = constant_op.constant([1, 2, 3, 4, 5, 6])