mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Fix the shape information propagation for Enter op.
PiperOrigin-RevId: 165653579
This commit is contained in:
parent
641943fd71
commit
465c408196
|
|
@ -204,6 +204,13 @@ REGISTER_OP("Enter")
|
|||
auto* handle_data = c->input_handle_shapes_and_types(0);
|
||||
if (handle_data != nullptr) {
|
||||
c->set_output_handle_shapes_and_types(0, *handle_data);
|
||||
} else {
|
||||
// Otherwise, propagate shape if output is a constant.
|
||||
bool is_constant;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("is_constant", &is_constant));
|
||||
if (is_constant) {
|
||||
c->set_output(0, c->input(0));
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -179,6 +179,19 @@ class ControlFlowTest(test.TestCase):
|
|||
result = exit_op.eval()
|
||||
self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result)
|
||||
|
||||
def testEnterShapePropagation(self):
|
||||
with self.test_session():
|
||||
v = variables.Variable([0.0, 0.0], dtype=dtypes.float32)
|
||||
|
||||
# If is_constant=True, the shape information should be propagated.
|
||||
enter_v_constant = control_flow_ops.enter(v, "frame1", is_constant=True)
|
||||
self.assertEqual(enter_v_constant.shape, [2])
|
||||
|
||||
# Otherwise, the shape should be unknown.
|
||||
enter_v_non_constant = control_flow_ops.enter(v, "frame2",
|
||||
is_constant=False)
|
||||
self.assertEqual(enter_v_non_constant.shape, None)
|
||||
|
||||
def testSwitchMergeIndexedSlices(self):
|
||||
with self.test_session():
|
||||
values = constant_op.constant([1, 2, 3, 4, 5, 6])
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user