mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Dynamo] Guard serialization for SEQUENCE_LENGTH (#152730)
Tests only; no other changes needed. Test logic uses a tuple function input to trigger installation of a SEQUENCE_LENGTH guard. Pull Request resolved: https://github.com/pytorch/pytorch/pull/152730 Approved by: https://github.com/jansel ghstack dependencies: #152725, #152727, #152728
This commit is contained in:
parent
42954ab28e
commit
fb500d0b1c
|
|
@ -655,6 +655,39 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
|||
):
|
||||
self._test_serialization("CLOSURE_MATCH", fn, x)
|
||||
|
||||
def test_sequence_length(self):
|
||||
# tuple input installs a SEQUENCE_LENGTH guard
|
||||
def fn(t, x):
|
||||
return t[1] + x
|
||||
|
||||
t = tuple(torch.randn(3) for _ in range(3))
|
||||
x = torch.randn(3)
|
||||
|
||||
ref, loaded = self._test_serialization("SEQUENCE_LENGTH", fn, t, x)
|
||||
self._test_check_fn(ref, loaded, {"x": x, "t": t}, True)
|
||||
self._test_check_fn(
|
||||
ref,
|
||||
loaded,
|
||||
{
|
||||
"x": torch.randn(3),
|
||||
"t": tuple(torch.randn(3) for _ in range(3)),
|
||||
},
|
||||
True,
|
||||
)
|
||||
# different types in tuple of same length shouldn't fail SEQUENCE_LENGTH guard
|
||||
# (it should fail the separate TYPE_MATCH guard but that isn't tested here)
|
||||
self._test_check_fn(ref, loaded, {"x": torch.randn(3), "t": (0, 1, 2)}, True)
|
||||
# different length tuple
|
||||
self._test_check_fn(
|
||||
ref,
|
||||
loaded,
|
||||
{
|
||||
"x": torch.randn(3),
|
||||
"t": tuple(torch.randn(3) for _ in range(4)),
|
||||
},
|
||||
False,
|
||||
)
|
||||
|
||||
def test_dict_version(self):
|
||||
def fn(x):
|
||||
return pytree.tree_leaves(x)[0] + 1
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user