diff --git a/test/dynamo/test_guard_serialization.py b/test/dynamo/test_guard_serialization.py index 7197c5ac773..67b1aa46b1c 100644 --- a/test/dynamo/test_guard_serialization.py +++ b/test/dynamo/test_guard_serialization.py @@ -621,6 +621,22 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase): ): self._test_serialization("NN_MODULE", fn, m, x) + def test_function_match(self): + def fn(x): + # usage of this context manager installs a FUNCTION_MATCH guard + with torch.no_grad(): + y = x * 2 + return y + + x = torch.randn(3) + + # we don't support FUNCTION_MATCH because it adds an ID_MATCH guard, and we don't + # support that in serialization + with self.assertRaisesRegex( + RuntimeError, "FUNCTION_MATCH guard cannot be serialized." + ): + self._test_serialization("FUNCTION_MATCH", fn, x) + def test_dict_version(self): def fn(x): return pytree.tree_leaves(x)[0] + 1 diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 9a29f3899b3..ab5aff54829 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -1761,6 +1761,9 @@ class GuardBuilder(GuardBuilderBase): def FUNCTION_MATCH(self, guard: Guard): """things like torch.add and user defined functions""" + # don't support this in serialization because it uses unsupported ID_MATCH + if self.serialization_mode == "save": + raise RuntimeError("FUNCTION_MATCH guard cannot be serialized.") return self.ID_MATCH(guard) def CLOSURE_MATCH(self, guard: Guard):