diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index a4da77b4c98..9c23c3e4a49 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -8556,64 +8556,15 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
         self.assertEqual(seen_frames[1].name, "uwu_inline_me")
         self.assertEqual(seen_frames[2].line, "r2 = uwu_inline_me_deep(y, z)")
 
-    def test_recompile_on_disable_1(self):
-        # fix https://github.com/pytorch/pytorch/issues/157399
+    def test_error_on_recompile(self):
         @torch.compile(backend="eager")
-        def fn(x):
-            @torch._dynamo.disable
-            def inner(x):
-                return x + 10
-
-            return inner(x) + 1
-
-        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
-            try:
-                for i in range(5):
-                    fn(torch.rand(2, 3))
-            except torch._dynamo.exc.RecompileError as e:
-                self.fail("RecompileError raised unexpectedly: " + str(e))
-
-    def test_recompile_on_disable_2(self):
-        def outer(x, cond):
-            @torch._dynamo.disable()
-            def fn0(y):
-                return y + 1
-
-            @torch._dynamo.disable()
-            def fn1(y):
-                return y + 2
-
-            if cond:
-                f = fn0
-            else:
-                f = fn1
-
-            torch._dynamo.graph_break()
-            # there will be a resume function here
-            return f(x)
+        def fn(a, b):
+            return a + b
 
         with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
             with self.assertRaises(torch._dynamo.exc.RecompileError):
-                x = torch.rand(2, 3)
-                self.assertEqual(outer(x, True), torch.compile(outer)(x, True))
-                self.assertEqual(outer(x, False), torch.compile(outer)(x, False))
-
-    def test_create_nested_fn_cache_clear(self):
-        def outer(x):
-            @torch._dynamo.disable()
-            def f(y):
-                return y + 2
-
-            return f(x) + 1
-
-        outer = torch.compile(outer)
-        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
-            with self.assertRaises(torch._dynamo.exc.RecompileError):
-                outer(torch.randn(3, 3))
-                from torch._dynamo.utils import create_nested_fn_cache
-
-                create_nested_fn_cache.clear()
-                outer(torch.randn(3, 3))
+                fn(torch.rand(2, 3), torch.rand(2, 3))
+                fn(torch.rand(2, 3), (1, 2, 3))
 
     def test_guards_strip_function_call(self):
         from torch._dynamo.guards import strip_function_call
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 5d7f81eeb4f..01929a276f5 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -610,8 +610,6 @@ class TestAutograd(TestCase):
 
         with disable_gc():
             unpack_hook_ref = scope()
-        if torch._dynamo.is_compiling():
-            torch._dynamo.reset()
         self.assertIsNone(unpack_hook_ref())
 
     def test_will_engine_execute_node(self):
diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py
index 59c11803bb9..02b921b30ee 100644
--- a/torch/_dynamo/__init__.py
+++ b/torch/_dynamo/__init__.py
@@ -51,7 +51,6 @@ from .mutation_guard import GenerationTracker
 from .pgo import reset_code_state
 from .symbolic_convert import TensorifyState
 from .utils import (
-    create_nested_fn_cache,
     graph_break_reasons,
     guard_failures,
     orig_code_map,
@@ -145,7 +144,6 @@ def reset() -> None:
     torch._dynamo.utils.warn_once_cache.clear()
     torch._dynamo.utils.user_obj_id_to_weakref.clear()
     torch._C._autograd._saved_tensors_hooks_set_tracing(False)
-    create_nested_fn_cache.clear()
 
 
 def reset_code_caches() -> None:
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index fc42934d98d..575fe901fc1 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -4771,22 +4771,3 @@ def get_traced_code() -> Optional[list[CodeType]]:
     from torch._guards import TracingContext
 
     return TracingContext.get_traced_code()
-
-
-class CreateNestedFnCache:
-    cache: dict[str, types.FunctionType] = {}
-
-    @classmethod
-    def get(cls, key):
-        return cls.cache.get(key, None)
-
-    @classmethod
-    def set(cls, key, value):
-        cls.cache[key] = value
-
-    @classmethod
-    def clear(cls):
-        cls.cache.clear()
-
-
-create_nested_fn_cache: CreateNestedFnCache = CreateNestedFnCache()
diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py
index 1cf015d1076..31dbec48401 100644
--- a/torch/_dynamo/variables/functions.py
+++ b/torch/_dynamo/variables/functions.py
@@ -62,7 +62,6 @@ from ..utils import (
     check_unspec_or_constant_args,
     cmp_name_to_op_mapping,
     counters,
-    create_nested_fn_cache,
     identity,
     is_function,
     is_wrapper_or_member_descriptor,
@@ -270,11 +269,6 @@ def _create_nested_fn(
 ):
     from types import FunctionType
 
-    # Add caching for the actual IDs of user functions so that we can use them in the ID_MATCH guard.
-    cache_key = str(id(code)) + str(id(closure)) + str(id(f_globals))
-    if create_nested_fn_cache.get(cache_key):
-        return create_nested_fn_cache.get(cache_key)
-
     func = FunctionType(code, f_globals, name, defaults, closure)
     func.__kwdefaults__ = kwdefaults
 
@@ -286,7 +280,7 @@ def _create_nested_fn(
     # TypeError: __annotations__ must be set to a dict object
     assert annotations is None or isinstance(annotations, dict)
     func.__annotations__ = annotations
-    create_nested_fn_cache.set(cache_key, func)
+
     return func
 
 
@@ -1433,13 +1427,7 @@ class SkipFunctionVariable(VariableTracker):
 
     @classmethod
     def create_with_source(cls, value, source):
-        if inspect.getattr_static(value, "_torchdynamo_orig_callable", False):
-            install_guard(
-                AttrSource(source, "_torchdynamo_orig_callable").make_guard(
-                    GuardBuilder.FUNCTION_MATCH
-                )
-            )
-        elif not is_wrapper_or_member_descriptor(value):
+        if not is_wrapper_or_member_descriptor(value):
            # These descriptors are not guaranteed to return the same object on
            # attribute lookup. They are unlikely to be changed, so we can skip
            # guarding them.