diff --git a/test/dynamo/test_guard_serialization.py b/test/dynamo/test_guard_serialization.py index 96b61f62800..674bd466f5a 100644 --- a/test/dynamo/test_guard_serialization.py +++ b/test/dynamo/test_guard_serialization.py @@ -495,6 +495,24 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase): ref, loaded, {"mp": types.MappingProxyType({"a": torch.randn(3, 2)})}, False ) + def test_dict_keys_match(self): + def fn(x): + ret = 1 + for k in x: + ret += x[k] + return ret + + x = {"a": torch.randn(3, 2), "b": torch.randn(3, 2)} + ref, loaded = self._test_serialization("DICT_KEYS_MATCH", fn, x) + self._test_check_fn(ref, loaded, {"x": x}, True) + self._test_check_fn( + ref, + loaded, + {"x": {"b": torch.randn(3, 2), "a": torch.randn(3, 2)}}, + False, + ) + self._test_check_fn(ref, loaded, {"x": {"a": torch.randn(3, 2)}}, False) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests