[dynamo] Guard serialization for DICT_KEYS_MATCH (#152723)

DICT_KEYS_MATCH

Differential Revision: [D74091886](https://our.internmc.facebook.com/intern/diff/D74091886/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152723
Approved by: https://github.com/jansel
ghstack dependencies: #152615, #152616, #152687, #152716, #152721
This commit is contained in:
zhxchen17 2025-05-02 12:25:30 -07:00 committed by PyTorch MergeBot
parent 2da9ab4b1c
commit fd6d4a6a24

View File

@ -495,6 +495,24 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
ref, loaded, {"mp": types.MappingProxyType({"a": torch.randn(3, 2)})}, False
)
def test_dict_keys_match(self):
def fn(x):
ret = 1
for k in x:
ret += x[k]
return ret
x = {"a": torch.randn(3, 2), "b": torch.randn(3, 2)}
ref, loaded = self._test_serialization("DICT_KEYS_MATCH", fn, x)
self._test_check_fn(ref, loaded, {"x": x}, True)
self._test_check_fn(
ref,
loaded,
{"x": {"b": torch.randn(3, 2), "a": torch.randn(3, 2)}},
False,
)
self._test_check_fn(ref, loaded, {"x": {"a": torch.randn(3, 2)}}, False)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests