[dynamo] Guard serialization for NONE_MATCH. (#152329)

Differential Revision: [D73780435](https://our.internmc.facebook.com/intern/diff/D73780435/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152329
Approved by: https://github.com/jansel
ghstack dependencies: #152325, #152326, #152327, #152328
This commit is contained in:
zhxchen17 2025-04-29 07:56:17 -07:00 committed by PyTorch MergeBot
parent ab4091a9fa
commit 0b39124ea3

View File

@ -257,6 +257,19 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
self._test_check_fn(ref, loaded, {"x": torch.randn(3), "b": False}, False)
self._test_check_fn(ref, loaded, {"x": torch.randn(3), "b": None}, False)
def test_none_match(self):
def fn(x, b):
if b is None:
return x + 1
else:
return x + 2
ref, loaded = self._test_serialization("NONE_MATCH", fn, torch.randn(3), None)
self._test_check_fn(ref, loaded, {"x": torch.randn(3), "b": None}, True)
self._test_check_fn(ref, loaded, {"x": torch.randn(3), "b": False}, False)
self._test_check_fn(ref, loaded, {"x": torch.randn(3), "b": True}, False)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests