mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] Guard serialization for NONE_MATCH. (#152329)
Differential Revision: [D73780435](https://our.internmc.facebook.com/intern/diff/D73780435/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/152329 Approved by: https://github.com/jansel ghstack dependencies: #152325, #152326, #152327, #152328
This commit is contained in:
parent
ab4091a9fa
commit
0b39124ea3
|
|
@ -257,6 +257,19 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
|||
self._test_check_fn(ref, loaded, {"x": torch.randn(3), "b": False}, False)
|
||||
self._test_check_fn(ref, loaded, {"x": torch.randn(3), "b": None}, False)
|
||||
|
||||
def test_none_match(self):
|
||||
def fn(x, b):
|
||||
if b is None:
|
||||
return x + 1
|
||||
else:
|
||||
return x + 2
|
||||
|
||||
ref, loaded = self._test_serialization("NONE_MATCH", fn, torch.randn(3), None)
|
||||
|
||||
self._test_check_fn(ref, loaded, {"x": torch.randn(3), "b": None}, True)
|
||||
self._test_check_fn(ref, loaded, {"x": torch.randn(3), "b": False}, False)
|
||||
self._test_check_fn(ref, loaded, {"x": torch.randn(3), "b": True}, False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user