From 1b9daeb240e83b927df16e69dfb30583da600216 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Sat, 14 Sep 2024 03:23:36 -0700 Subject: [PATCH] [Dynamo] Trace enter/exit of TorchFunctionModes (#135422) This PR implements tracing of with contexts with TorchFunction modes which have the default enter/exit behavior (ie pushing/popping the mode) Typically the bytecode for a context manager looks like this during a graph break: 1. graph call 2. enter context 3. unsupported code 4. exit context 5. resume call resume fn structure: 1. enter context 2. jump ... 3. exit context The issue with torch function modes is that side effects will replay any mutations to the torch function stack performed during tracing. So, we do not need to enter and exit around the unsupported code in the original function (doing so would result in a duplicate torch function mode entry during execution of the unsupported code), and we don't need to enter again in the resume function (the mode that was pushed from the side effects bytecode would still be on the stack). So for torch function modes the structure of our output code is this: 1. graph call 2. mutate tf mode stack to replay mutations 4. unsupported code 5. on exception restore stack 6. resume function Then our resume fn looks like this: 1. no-op enter torch function mode 2. jump 3. exit tf mode To implement the no-op enter of the torch function mode I added torch function mode in polyfill which no-op enters, but normally exits. This is needed because we still want to trace the with context in the resume function, and exit properly (the exit instructions will still be in the function, so we need to generate instructions to set up the context). Separately from the bytecode, dynamo also tracks contexts on the block stack, which is how the SETUP_* instructions are implemented. Naturally at a graph break, we exit these block stacks to properly reset the contexts entirely, so that we can re-enter around the unsupported code soundly. However once again, in the torch function mode case, in the event of a graph we do not want to perform any exit side effects because we want to preserve the state of the mode stack as is so that we will properly update the stack with bytecode mentioned in the first section. If we exited here, dynamo would pop the mode off of the symbolic stack, and not update the true python torch function mode stack with the suffix bytecode. All in all, for torch function modes we enter exactly once, update the global torch function mode stack with side effects bytecode, re-read this stack when compiling the resume function, and exit exactly once in the resume function. This matches the semantics of eager exactly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135422 Approved by: https://github.com/williamwen42 ghstack dependencies: #134732, #133137, #135443, #135444 --- test/dynamo/test_modes.py | 88 ++++++++++++++++++ torch/_dynamo/convert_frame.py | 6 +- torch/_dynamo/output_graph.py | 6 +- torch/_dynamo/polyfills/__init__.py | 20 ++++ torch/_dynamo/resume_execution.py | 104 +++++++++++++++++++++ torch/_dynamo/side_effects.py | 11 +++ torch/_dynamo/symbolic_convert.py | 30 ++++-- torch/_dynamo/testing.py | 1 + torch/_dynamo/trace_rules.py | 2 +- torch/_dynamo/variables/builder.py | 18 ++-- torch/_dynamo/variables/ctx_manager.py | 12 +++ torch/_dynamo/variables/torch.py | 21 ++++- torch/_dynamo/variables/torch_function.py | 108 ++++++++++++++++++++-- torch/_dynamo/variables/user_defined.py | 14 ++- 14 files changed, 408 insertions(+), 33 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index fa4c23fd320..63e06a515ed 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -461,6 +461,94 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase): self.assertEqual(expected, actual) + def test_torch_function_mode_enter_exit(self): + def fn(x, y): + with TestMode(): + o = torch.add(x, 3) + + return torch.add(o, y) + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) + fn_opt = torch.compile(fn, fullgraph=True) + + expected = fn(*inp) + actual = fn_opt(*inp) + + self.assertEqual(expected, actual) + + def test_torch_function_mode_graph_break(self): + def fn(x, y): + with TestMode(): + torch._dynamo.graph_break() + o = torch.add(x, 3) + + return torch.add(o, y) + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) + fn_opt = torch.compile(fn) + + expected = fn(*inp) + actual = fn_opt(*inp) + + self.assertEqual(expected, actual) + + def test_torch_function_mode_and_pop_graph_break(self): + def fn(x, y): + with TestMode(): + z = _pop_torch_function_stack() + torch._dynamo.graph_break() + _push_on_torch_function_stack(z) + o = torch.add(x, 3) + + return torch.add(o, y) + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) + fn_opt = torch.compile(fn) + + expected = fn(*inp) + actual = fn_opt(*inp) + + self.assertEqual(expected, actual) + + def test_torch_function_mode_restore_on_exc(self): + @torch._dynamo.disable() + def err(): + raise RuntimeError("test") + + @torch.compile() + def fn(x): + with TestMode(): + x += 1 + err() + x += 2 + return x + + try: + fn(torch.ones(2, 2)) + except RuntimeError: + pass + self.assertEqual(_len_torch_function_stack(), 0) + + def test_torch_function_mode_and_pop_graph_break_mutation(self): + def fn(x, y): + with TestMode(): + z = _pop_torch_function_stack() + z.y = 5 + torch._dynamo.graph_break() + _push_on_torch_function_stack(z) + o = torch.add(x, 3) + o = torch.mul(o, z.y) + + return torch.add(o, y) + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) + fn_opt = torch.compile(fn) + + expected = fn(*inp) + actual = fn_opt(*inp) + + self.assertEqual(expected, actual) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index b3e4a7c268c..171af02e564 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -112,6 +112,7 @@ from .utils import ( troubleshooting_url, write_record_to_file, ) +from .variables.torch_function import torch_function_mode_stack_state_mgr np: Optional[ModuleType] @@ -210,15 +211,18 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]: prior_fwd_from_src = torch.fx.graph_module._forward_from_src torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result cleanup = setup_compile_debug() - exit_stack = contextlib.ExitStack() exit_stack.enter_context( torch.fx._symbolic_trace._maybe_revert_all_patches() ) + exit_stack.enter_context(torch_function_mode_stack_state_mgr) try: return fn(*args, **kwargs) finally: cleanup.close() + assert ( + torch._C._len_torch_function_stack() == 0 + ), "Torch function mode stack state changed while dynamo tracing, please report a bug" exit_stack.close() torch._C._set_grad_enabled(prior_grad_mode) torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode) diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index ba8cfa5a03d..0ea19e1cad6 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -78,7 +78,6 @@ from .utils import ( get_instruction_source_311, get_locals_to_steal, get_static_address_type, - get_torch_function_mode_stack, graph_break_reasons, increment_op_count, lazy_format_graph_code, @@ -250,6 +249,7 @@ class OutputGraph: local_scope: Scope, global_scope: Scope, f_code, + torch_function_mode_stack, ): super().__init__() self.tracers = [SubgraphTracer(self, export_root=export)] @@ -368,7 +368,7 @@ class OutputGraph: # This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled() # This records the initial torch function mode stack for guarding - self.torch_function_mode_stack = get_torch_function_mode_stack() + self.torch_function_mode_stack = torch_function_mode_stack # Tracks if the output graph has a user defined allowed function in the # graph. This is used later to determine if we should fallback to eager @@ -1020,7 +1020,7 @@ class OutputGraph: prefix_insts.clear() for block in reversed(tx.block_stack): - block.exit(tx) + block.exit(tx, is_graph_break=reason.graph_break) self.cleanup_graph() tx.prune_dead_locals() diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 2b3f38920a0..5b2812bc08c 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -25,6 +25,26 @@ if TYPE_CHECKING: sys as sys, ) +from torch.overrides import BaseTorchFunctionMode + + +# These classes handle support for TorchFunctionModes across +# graph breaks +# Today the TorchFunctionMode enter (for the classes we support) +# simply pushes the mode onto the stack. Since after this occurs +# the stack is mutated, and we replay these mutations, we don't need +# any cleanup logic to be run once the graph break occurs, we simply replay +# these mutations to ensure at the graph break the torch function mode stack is correct +# and reconstruct the torch function mode stack normally +# when we compile the resume function on the other side of the break. +# However, to ensure we exit properly +# in the resume function, we need to re-enter the contexts as we do other contexts. +# These contexts do nothing on enter, but provide the correct exit logic to ensure +# the stack state is correct. +class NoEnterTorchFunctionMode(BaseTorchFunctionMode): + def __enter__(self): + pass + def index(iterator, item, start=0, end=None): from itertools import islice diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py index 132e9e4081b..6f7db66514c 100644 --- a/torch/_dynamo/resume_execution.py +++ b/torch/_dynamo/resume_execution.py @@ -6,6 +6,7 @@ import types from typing import Any, cast, Dict, List, Optional, Tuple from .bytecode_transformation import ( + add_push_null, create_call_function, create_call_method, create_dup_top, @@ -48,6 +49,109 @@ class ReenterWith: stack_index: int target_values: Optional[Tuple[Any, ...]] = None + def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]): + """ + Codegen based off of: + try: + (rest) + except: + (restore previous stack) + + """ + from .variables.torch_function import get_prev_stack_var_name + + except_jump_target = create_instruction( + "NOP" if sys.version_info < (3, 11) else "PUSH_EXC_INFO" + ) + cleanup_complete_jump_target = create_instruction("NOP") + + setup_finally: List[Instruction] = [] + + if sys.version_info < (3, 11): + setup_finally.append( + create_instruction("SETUP_FINALLY", target=except_jump_target) + ) + else: + exn_tab_begin = create_instruction("NOP") + exn_tab_end = create_instruction("NOP") + exn_tab_begin.exn_tab_entry = InstructionExnTabEntry( + exn_tab_begin, + exn_tab_end, + except_jump_target, + self.stack_index + 1, + False, + ) + setup_finally.append(exn_tab_begin) + + def create_reset(): + insts = [ + create_instruction( + "LOAD_GLOBAL", argval="__import_torch_dot__dynamo_dot_utils" + ), + create_instruction("LOAD_ATTR", argval="set_torch_function_mode_stack"), + ] + add_push_null(insts) + return [ + *insts, + create_instruction("LOAD_FAST", argval=get_prev_stack_var_name()), + *create_call_function(1, False), + create_instruction("POP_TOP"), + ] + + if sys.version_info < (3, 9): + epilogue = [ + create_instruction("POP_BLOCK"), + create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), + except_jump_target, + *create_reset(), + create_instruction("POP_TOP"), + create_instruction("POP_TOP"), + create_instruction("POP_TOP"), + *create_reset(), + create_instruction("RAISE_VARARGS", argval=0), + create_instruction("POP_EXCEPT", argval=0), + create_instruction("END_FINALLY"), + cleanup_complete_jump_target, + ] + elif sys.version_info < (3, 11): + epilogue = [ + create_instruction("POP_BLOCK"), + create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), + except_jump_target, + create_instruction("POP_TOP"), + create_instruction("POP_TOP"), + create_instruction("POP_TOP"), + *create_reset(), + create_instruction("RAISE_VARARGS", argval=0), + create_instruction("POP_EXCEPT", argval=0), + cleanup_complete_jump_target, + ] + else: + finally_exn_tab_end = create_instruction("RAISE_VARARGS", argval=0) + finally_exn_tab_target = create_instruction("COPY", arg=3) + except_jump_target.exn_tab_entry = InstructionExnTabEntry( + except_jump_target, + finally_exn_tab_end, + finally_exn_tab_target, + self.stack_index + 2, + True, + ) + epilogue = [ + exn_tab_end, + create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), + except_jump_target, # PUSH_EXC_INFO + create_instruction("POP_TOP"), + *create_reset(), + finally_exn_tab_end, + finally_exn_tab_target, # COPY 3 + create_instruction("POP_EXCEPT"), + create_instruction("RERAISE", arg=1), # RERAISE 1 + cleanup_complete_jump_target, + ] + + cleanup[:] = epilogue + cleanup + return setup_finally + # If we do not want to destroy the stack, we can do the same thing as a # `SETUP_WITH` block, only that we store the context manager in a local_symbol def try_except(self, code_options, cleanup: List[Instruction]): diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index fc1bd976ff5..fd5b54e4f7f 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -593,11 +593,22 @@ class SideEffects: elif isinstance( var, variables.torch_function.TorchFunctionModeStackVariable ): + # Needed in the finally block for stack restoration + cg.add_push_null( + lambda: cg.load_import_from( + utils.__name__, "get_torch_function_mode_stack" + ) + ) + cg.call_function(0, False) + name = variables.torch_function.get_prev_stack_var_name() + cg.code_options["co_varnames"] += (name,) + cg.append_output(create_instruction("STORE_FAST", argval=name)) cg.add_push_null( lambda: cg.load_import_from( utils.__name__, "set_torch_function_mode_stack" ) ) + cg.foreach(var.symbolic_stack) cg.append_output( create_instruction("BUILD_LIST", arg=len(var.symbolic_stack)) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 415179a5ed3..61acaec9957 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -267,13 +267,12 @@ class BlockStackEntry: else: return ReenterWith(self.stack_index) - def exit(self, tx): - if hasattr(self, "graph_break") and isinstance( - self.with_context, TorchFunctionModeVariable - ): - return + def exit(self, tx, is_graph_break): assert self.with_context is not None - return self.with_context.exit(tx) + if ( + is_graph_break and self.with_context.exit_on_graph_break() + ) or not is_graph_break: + return self.with_context.exit(tx) class ReturnValueOp(Exception): @@ -639,10 +638,17 @@ def break_graph_if_unsupported(*, push): cleanup: List[Instruction] = [] # Reconstruct the context variable CLASS in the block stack for b in self.block_stack: + # Don't exit any modes we have entered, + # output bytecode will mutate the tf mode stack accordingly + if isinstance(b.with_context, TorchFunctionModeVariable): + cg.extend_output( + b.resume_fn().try_except_torch_function_mode( + cg.code_options, cleanup + ) + ) + continue assert b.with_context is not None - assert isinstance( - b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable) - ) + assert isinstance(b.with_context, (ContextWrappingVariable)) b.with_context.reconstruct_type(cg) cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup)) self.output.add_output_instructions(cg.get_instructions()) @@ -2295,7 +2301,10 @@ class InstructionTranslatorBase( ): unimplemented(f"{inst.opname} {ctx}") - if isinstance(ctx, GenericContextWrappingVariable): + if ( + isinstance(ctx, GenericContextWrappingVariable) + and not ctx.supports_graph_breaks() + ): self.generic_context_manager_depth += 1 # Need this redundant check for mypy @@ -2668,6 +2677,7 @@ class InstructionTranslator(InstructionTranslatorBase): local_scope=f_locals, global_scope=f_globals, f_code=f_code, + torch_function_mode_stack=torch_function_mode_stack, ), instructions=instructions, f_locals=f_locals, diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 4922b521bad..704a3889707 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -163,6 +163,7 @@ def debug_insert_nops( local_scope=locals(), global_scope=globals(), f_code=frame.f_code, + torch_function_mode_stack=[], ) return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0)) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 1a206ae339c..ed81b26191b 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -304,6 +304,7 @@ manual_torch_name_rule_map = { "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable, "torch.cuda._get_device_properties": TorchInGraphFunctionVariable, "torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable, + "torch.set_default_device": UserFunctionVariable, "torch.sparse_bsc_tensor": SkipFunctionVariable, "torch.sparse_bsr_tensor": SkipFunctionVariable, "torch.sparse_csc_tensor": SkipFunctionVariable, @@ -2797,7 +2798,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys( "torch.random.initial_seed", "torch.random.seed", "torch.return_types.pytree_register_structseq", - "torch.set_default_device", "torch.set_default_dtype", "torch.set_default_tensor_type", "torch.set_deterministic_debug_mode", diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 5a8a9b613b8..193e227267e 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -204,6 +204,7 @@ from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable from .torch_function import ( build_torch_function_fn, TensorWithTFOverrideVariable, + torch_function_mode_stack_state_mgr, TorchFunctionModeVariable, ) from .user_defined import ( @@ -1668,15 +1669,16 @@ class VariableBuilder: # but warning is not the end of the world assert isinstance(value.base, np.nditer) - try: - tensor_value = _util._try_convert_to_tensor(value) - if readonly: - from torch._prims_common import clone_preserve_strides + with torch_function_mode_stack_state_mgr.temp_restore_stack(): + try: + tensor_value = _util._try_convert_to_tensor(value) + if readonly: + from torch._prims_common import clone_preserve_strides - tensor_value = clone_preserve_strides(tensor_value) - except NotImplementedError as e: - # failed to convert to tensor, graph break - unimplemented(str(e)) + tensor_value = clone_preserve_strides(tensor_value) + except NotImplementedError as e: + # failed to convert to tensor, graph break + unimplemented(str(e)) # We do this because we want the full behavior of guarding the numpy ndarray as if it were # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 8c4eb3dc4e7..e19c4e254c6 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -125,6 +125,12 @@ class ContextWrappingVariable(VariableTracker): if isinstance(args[0], UserFunctionVariable): return WrappedUserFunctionVariable(args[0], self) + def supports_graph_breaks(self): + return True + + def exit_on_graph_break(self): + return True + class GenericContextWrappingVariable(UserDefinedObjectVariable): # Some methods in ContextWrappingVariable assumes the arguments are @@ -183,6 +189,12 @@ class GenericContextWrappingVariable(UserDefinedObjectVariable): tx.generic_context_manager_depth -= 1 return x + def supports_graph_breaks(self): + return False + + def exit_on_graph_break(self): + return True + class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable): """represents torch grad requries grad""" diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index f6cb565ef02..c1e0dec0fbc 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -162,7 +162,17 @@ def get_overridable_functions(): from torch.overrides import get_overridable_functions as get_overridable_functions_ - return set(chain(*get_overridable_functions_().values())) + funcs = set(chain(*get_overridable_functions_().values())) + more = { + torch.ones, + torch.ones_like, + torch.zeros, + torch.zeros_like, + torch.empty, + torch.full, + } + funcs.update(more) + return funcs class BaseTorchVariable(VariableTracker): @@ -838,6 +848,13 @@ class TorchInGraphFunctionVariable(BaseTorchVariable): len(tx.symbolic_torch_function_state.mode_stack) ) + @register(torch._C._get_function_stack_at) + def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs): + assert len(args) == 1 and not kwargs + ind = args[0].as_python_constant() + assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack) + return tx.symbolic_torch_function_state.mode_stack[ind] + @register(torch.set_default_device) def handle_set_default_device( self, tx: "InstructionTranslator", *args, **kwargs @@ -855,7 +872,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable): else: TorchFunctionModeStackVariable.register_device_context_insertion(tx) - return None + return ConstantVariable.create(None) return handlers diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 6e52cc688e0..b7fd83a3d24 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -2,22 +2,35 @@ import collections import contextlib +import functools import inspect from typing import Deque, Dict, List, TYPE_CHECKING import torch._C import torch.utils._pytree as pytree from torch._guards import Source -from torch.overrides import _get_overloaded_args, get_default_nowrap_functions +from torch.overrides import ( + _get_overloaded_args, + get_default_nowrap_functions, + TorchFunctionMode, +) from torch.utils._device import DeviceContext from ..exc import unimplemented from ..guards import GuardBuilder, install_guard +from ..polyfills import NoEnterTorchFunctionMode from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource -from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter +from ..utils import ( + class_has_getattribute, + clear_torch_function_mode_stack, + get_safe_global_name, + has_torch_function, + is_tensor_base_attr_getter, + set_torch_function_mode_stack, +) from .base import VariableTracker from .constant import ConstantVariable -from .ctx_manager import ContextWrappingVariable +from .ctx_manager import GenericContextWrappingVariable from .lazy import LazyVariableTracker from .lists import TupleVariable from .tensor import TensorSubclassVariable, TensorVariable @@ -63,6 +76,39 @@ banned_attrs = [ IGNORED_MODES = {DeviceContext} +@functools.lru_cache(None) +def get_prev_stack_var_name(): + from ..bytecode_transformation import unique_id + + return unique_id("___prev_torch_function_mode_stack") + + +# Used to clear/restore the python torch function mode stack and temporarily restore it as needed +class TorchFunctionModeStackStateManager: + def __init__(self): + self.stack = [] + + def __enter__(self): + self.stack = torch.overrides._get_current_function_mode_stack() + clear_torch_function_mode_stack() + + def __exit__(self, exc_type, exc_value, traceback): + set_torch_function_mode_stack(self.stack) + self.stack = [] + + @contextlib.contextmanager + def temp_restore_stack(self): + prev = torch.overrides._get_current_function_mode_stack() + set_torch_function_mode_stack(self.stack) + try: + yield + finally: + set_torch_function_mode_stack(prev) + + +torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager() + + class SymbolicTorchFunctionState: def __init__(self, py_stack): # This is annoyingly complicated because of how the torch function subclass + mode C API was designed @@ -189,9 +235,26 @@ class TorchFunctionModeStackVariable(VariableTracker): return ind + cls.offset -class TorchFunctionModeVariable(ContextWrappingVariable): +class TorchFunctionModeVariable(GenericContextWrappingVariable): + @staticmethod + def is_supported_torch_function_mode(ty): + # Supported in this sense means we can support graph breaks under the + # context. + # We are able to trace custom modes but if there are graph breaks under them + # and they have a custom __enter__/__exit__ we don't handle this for the + # same reason we don't handle generic context managers: there may be side effects + # that are now affected by executing the funtion across two frames instead of one + # Today we support the enter/exit of the default TorchFunctionMode as well as + # DeviceContext (which is used for set_default_device) + return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or ( + not class_has_getattribute(ty) + and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__ + and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__ + ) + def __init__(self, value, source=None, **kwargs): - super().__init__(value, **kwargs) + if value is not None: + super().__init__(value, **kwargs) self.value = value self.cm_obj = value # needed for BC with calling enter from CM code self.source = source @@ -221,8 +284,39 @@ class TorchFunctionModeVariable(ContextWrappingVariable): kwargs, ) - def _call_func(self, tx: "InstructionTranslator", values): - unimplemented("enter/exit for torch function mode NYI") + def enter(self, tx): + from .torch import TorchInGraphFunctionVariable + + if isinstance(self.value, NoEnterTorchFunctionMode): + return ConstantVariable.create(None) + + TorchInGraphFunctionVariable( + torch._C._push_on_torch_function_stack + ).call_function(tx, [self], {}) + return ConstantVariable.create(None) + + def exit(self, tx: "InstructionTranslator", *args): + from .torch import TorchInGraphFunctionVariable + + TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function( + tx, [], {} + ) + return ConstantVariable.create(None) + + def reconstruct_type(self, codegen): + ty = NoEnterTorchFunctionMode + codegen( + AttrSource( + codegen.tx.import_source(ty.__module__), + ty.__name__, + ) + ) + + def supports_graph_breaks(self): + return True + + def exit_on_graph_break(self): + return False def _get_all_args(args, kwargs): diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index ad92277cf70..14499b4d2e4 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -409,10 +409,22 @@ class UserDefinedClassVariable(UserDefinedVariable): and self.source and not is_forbidden_context_manager(self.value) ): + from torch.overrides import TorchFunctionMode + from .ctx_manager import GenericContextWrappingVariable + from .torch_function import TorchFunctionModeVariable + + if issubclass( + self.value, TorchFunctionMode + ) and TorchFunctionModeVariable.is_supported_torch_function_mode( + self.value + ): + var_cls = TorchFunctionModeVariable + else: + var_cls = GenericContextWrappingVariable cm_obj = tx.output.side_effects.track_object_new( - self.source, self.value, GenericContextWrappingVariable, {} + self.source, self.value, var_cls, {} ) cm_obj.call_method(tx, "__init__", args, kwargs) return cm_obj