mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] Guard serialization for DICT_KEYS_MATCH (#152723)
DICT_KEYS_MATCH Differential Revision: [D74091886](https://our.internmc.facebook.com/intern/diff/D74091886/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/152723 Approved by: https://github.com/jansel ghstack dependencies: #152615, #152616, #152687, #152716, #152721
This commit is contained in:
parent
2da9ab4b1c
commit
fd6d4a6a24
|
|
@ -495,6 +495,24 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
|||
ref, loaded, {"mp": types.MappingProxyType({"a": torch.randn(3, 2)})}, False
|
||||
)
|
||||
|
||||
def test_dict_keys_match(self):
|
||||
def fn(x):
|
||||
ret = 1
|
||||
for k in x:
|
||||
ret += x[k]
|
||||
return ret
|
||||
|
||||
x = {"a": torch.randn(3, 2), "b": torch.randn(3, 2)}
|
||||
ref, loaded = self._test_serialization("DICT_KEYS_MATCH", fn, x)
|
||||
self._test_check_fn(ref, loaded, {"x": x}, True)
|
||||
self._test_check_fn(
|
||||
ref,
|
||||
loaded,
|
||||
{"x": {"b": torch.randn(3, 2), "a": torch.randn(3, 2)}},
|
||||
False,
|
||||
)
|
||||
self._test_check_fn(ref, loaded, {"x": {"a": torch.randn(3, 2)}}, False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user