[Dynamo] Guard serialization for SEQUENCE_LENGTH (#152730)

Tests only; no other changes needed. Test logic uses a tuple function input to trigger installation of a SEQUENCE_LENGTH guard.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152730
Approved by: https://github.com/jansel
ghstack dependencies: #152725, #152727, #152728
This commit is contained in:
Joel Schlosser 2025-05-07 10:57:50 -04:00 committed by PyTorch MergeBot
parent 42954ab28e
commit fb500d0b1c

View File

@ -655,6 +655,39 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
):
self._test_serialization("CLOSURE_MATCH", fn, x)
def test_sequence_length(self):
# tuple input installs a SEQUENCE_LENGTH guard
def fn(t, x):
return t[1] + x
t = tuple(torch.randn(3) for _ in range(3))
x = torch.randn(3)
ref, loaded = self._test_serialization("SEQUENCE_LENGTH", fn, t, x)
self._test_check_fn(ref, loaded, {"x": x, "t": t}, True)
self._test_check_fn(
ref,
loaded,
{
"x": torch.randn(3),
"t": tuple(torch.randn(3) for _ in range(3)),
},
True,
)
# different types in tuple of same length shouldn't fail SEQUENCE_LENGTH guard
# (it should fail the separate TYPE_MATCH guard but that isn't tested here)
self._test_check_fn(ref, loaded, {"x": torch.randn(3), "t": (0, 1, 2)}, True)
# different length tuple
self._test_check_fn(
ref,
loaded,
{
"x": torch.randn(3),
"t": tuple(torch.randn(3) for _ in range(4)),
},
False,
)
def test_dict_version(self):
def fn(x):
return pytree.tree_leaves(x)[0] + 1