diff --git a/test/dynamo/test_guard_serialization.py b/test/dynamo/test_guard_serialization.py index 6add972904d..313a07f4f18 100644 --- a/test/dynamo/test_guard_serialization.py +++ b/test/dynamo/test_guard_serialization.py @@ -655,6 +655,39 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase): ): self._test_serialization("CLOSURE_MATCH", fn, x) + def test_sequence_length(self): + # tuple input installs a SEQUENCE_LENGTH guard + def fn(t, x): + return t[1] + x + + t = tuple(torch.randn(3) for _ in range(3)) + x = torch.randn(3) + + ref, loaded = self._test_serialization("SEQUENCE_LENGTH", fn, t, x) + self._test_check_fn(ref, loaded, {"x": x, "t": t}, True) + self._test_check_fn( + ref, + loaded, + { + "x": torch.randn(3), + "t": tuple(torch.randn(3) for _ in range(3)), + }, + True, + ) + # different types in tuple of same length shouldn't fail SEQUENCE_LENGTH guard + # (it should fail the separate TYPE_MATCH guard but that isn't tested here) + self._test_check_fn(ref, loaded, {"x": torch.randn(3), "t": (0, 1, 2)}, True) + # different length tuple + self._test_check_fn( + ref, + loaded, + { + "x": torch.randn(3), + "t": tuple(torch.randn(3) for _ in range(4)), + }, + False, + ) + def test_dict_version(self): def fn(x): return pytree.tree_leaves(x)[0] + 1