diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py
index fa4c23fd320..63e06a515ed 100644
--- a/test/dynamo/test_modes.py
+++ b/test/dynamo/test_modes.py
@@ -461,6 +461,94 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
 
         self.assertEqual(expected, actual)
 
+    def test_torch_function_mode_enter_exit(self):
+        def fn(x, y):
+            with TestMode():
+                o = torch.add(x, 3)
+
+            return torch.add(o, y)
+
+        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
+        fn_opt = torch.compile(fn, fullgraph=True)
+
+        expected = fn(*inp)
+        actual = fn_opt(*inp)
+
+        self.assertEqual(expected, actual)
+
+    def test_torch_function_mode_graph_break(self):
+        def fn(x, y):
+            with TestMode():
+                torch._dynamo.graph_break()
+                o = torch.add(x, 3)
+
+            return torch.add(o, y)
+
+        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
+        fn_opt = torch.compile(fn)
+
+        expected = fn(*inp)
+        actual = fn_opt(*inp)
+
+        self.assertEqual(expected, actual)
+
+    def test_torch_function_mode_and_pop_graph_break(self):
+        def fn(x, y):
+            with TestMode():
+                z = _pop_torch_function_stack()
+                torch._dynamo.graph_break()
+                _push_on_torch_function_stack(z)
+                o = torch.add(x, 3)
+
+            return torch.add(o, y)
+
+        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
+        fn_opt = torch.compile(fn)
+
+        expected = fn(*inp)
+        actual = fn_opt(*inp)
+
+        self.assertEqual(expected, actual)
+
+    def test_torch_function_mode_restore_on_exc(self):
+        @torch._dynamo.disable()
+        def err():
+            raise RuntimeError("test")
+
+        @torch.compile()
+        def fn(x):
+            with TestMode():
+                x += 1
+                err()
+                x += 2
+                return x
+
+        try:
+            fn(torch.ones(2, 2))
+        except RuntimeError:
+            pass
+        self.assertEqual(_len_torch_function_stack(), 0)
+
+    def test_torch_function_mode_and_pop_graph_break_mutation(self):
+        def fn(x, y):
+            with TestMode():
+                z = _pop_torch_function_stack()
+                z.y = 5
+                torch._dynamo.graph_break()
+                _push_on_torch_function_stack(z)
+                o = torch.add(x, 3)
+                o = torch.mul(o, z.y)
+
+            return torch.add(o, y)
+
+        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
+        fn_opt = torch.compile(fn)
+
+        expected = fn(*inp)
+        actual = fn_opt(*inp)
+
+        self.assertEqual(expected, actual)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index b3e4a7c268c..171af02e564 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -112,6 +112,7 @@ from .utils import (
     troubleshooting_url,
     write_record_to_file,
 )
+from .variables.torch_function import torch_function_mode_stack_state_mgr
 
 np: Optional[ModuleType]
 
@@ -210,15 +211,18 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
         prior_fwd_from_src = torch.fx.graph_module._forward_from_src
         torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
         cleanup = setup_compile_debug()
-
         exit_stack = contextlib.ExitStack()
         exit_stack.enter_context(
             torch.fx._symbolic_trace._maybe_revert_all_patches()
         )
+        exit_stack.enter_context(torch_function_mode_stack_state_mgr)
         try:
             return fn(*args, **kwargs)
         finally:
             cleanup.close()
+            assert (
+                torch._C._len_torch_function_stack() == 0
+            ), "Torch function mode stack state changed while dynamo tracing, please report a bug"
             exit_stack.close()
             torch._C._set_grad_enabled(prior_grad_mode)
             torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)
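For context on what `torch_function_mode_stack_state_mgr` does here: the real implementation is `TorchFunctionModeStackStateManager` in `torch/_dynamo/variables/torch_function.py` further down in this diff. A simplified behavioral sketch, expressed only in terms of the C-level stack hooks (not the actual implementation):

```python
import torch


class StackStateManagerSketch:
    """Snapshot/clear the torch function mode stack around tracing, then restore it."""

    def __init__(self):
        self.stack = []

    def __enter__(self):
        # Tracing should see an empty eager mode stack; Dynamo models the stack symbolically.
        self.stack = torch.overrides._get_current_function_mode_stack()
        while torch._C._len_torch_function_stack():
            torch._C._pop_torch_function_stack()

    def __exit__(self, exc_type, exc_value, traceback):
        # The new assert in preserve_global_state relies on tracing leaving the stack empty;
        # afterwards the snapshot is pushed back.
        for mode in self.stack:
            torch._C._push_on_torch_function_stack(mode)
        self.stack = []
```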
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index ba8cfa5a03d..0ea19e1cad6 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -78,7 +78,6 @@ from .utils import (
     get_instruction_source_311,
     get_locals_to_steal,
     get_static_address_type,
-    get_torch_function_mode_stack,
     graph_break_reasons,
     increment_op_count,
     lazy_format_graph_code,
@@ -250,6 +249,7 @@ class OutputGraph:
         local_scope: Scope,
         global_scope: Scope,
         f_code,
+        torch_function_mode_stack,
     ):
         super().__init__()
         self.tracers = [SubgraphTracer(self, export_root=export)]
@@ -368,7 +368,7 @@ class OutputGraph:
         # This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty
         self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled()
         # This records the initial torch function mode stack for guarding
-        self.torch_function_mode_stack = get_torch_function_mode_stack()
+        self.torch_function_mode_stack = torch_function_mode_stack
 
         # Tracks if the output graph has a user defined allowed function in the
         # graph. This is used later to determine if we should fallback to eager
@@ -1020,7 +1020,7 @@ class OutputGraph:
             prefix_insts.clear()
 
         for block in reversed(tx.block_stack):
-            block.exit(tx)
+            block.exit(tx, is_graph_break=reason.graph_break)
 
         self.cleanup_graph()
         tx.prune_dead_locals()
diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py
index 2b3f38920a0..5b2812bc08c 100644
--- a/torch/_dynamo/polyfills/__init__.py
+++ b/torch/_dynamo/polyfills/__init__.py
@@ -25,6 +25,26 @@ if TYPE_CHECKING:
         sys as sys,
     )
 
+from torch.overrides import BaseTorchFunctionMode
+
+
+# These classes handle support for TorchFunctionModes across
+# graph breaks.
+# Today, the TorchFunctionMode enter (for the classes we support)
+# simply pushes the mode onto the stack. Since the stack is mutated
+# after this occurs, and we replay these mutations, we don't need
+# any cleanup logic to run once the graph break occurs; we simply replay
+# these mutations so that the torch function mode stack is correct at the
+# graph break, and reconstruct the torch function mode stack normally
+# when we compile the resume function on the other side of the break.
+# However, to ensure we exit properly
+# in the resume function, we need to re-enter the contexts as we do other contexts.
+# These contexts do nothing on enter, but provide the correct exit logic to ensure
+# the stack state is correct.
+class NoEnterTorchFunctionMode(BaseTorchFunctionMode):
+    def __enter__(self):
+        pass
+
 
 def index(iterator, item, start=0, end=None):
     from itertools import islice
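To make the comment above concrete: by the time the resume function runs, the bytecode emitted for the original frame has already replayed the stack mutation that pushed the mode, so re-entering the context must not push it a second time, while exiting must still pop it. A small illustration (the `mode` object and the plain `with` statement stand in for the generated resume bytecode; this assumes the inherited `__exit__` simply pops the innermost mode):

```python
import torch
from torch._dynamo.polyfills import NoEnterTorchFunctionMode
from torch.overrides import BaseTorchFunctionMode

mode = BaseTorchFunctionMode()
torch._C._push_on_torch_function_stack(mode)  # replayed stack mutation from before the break
assert torch._C._len_torch_function_stack() == 1

with NoEnterTorchFunctionMode():
    # __enter__ is a no-op, so the mode is not pushed a second time
    assert torch._C._len_torch_function_stack() == 1

# the inherited __exit__ pops the active mode, matching the original `with` block's exit
assert torch._C._len_torch_function_stack() == 0
```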
diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py
index 132e9e4081b..6f7db66514c 100644
--- a/torch/_dynamo/resume_execution.py
+++ b/torch/_dynamo/resume_execution.py
@@ -6,6 +6,7 @@ import types
 from typing import Any, cast, Dict, List, Optional, Tuple
 
 from .bytecode_transformation import (
+    add_push_null,
     create_call_function,
     create_call_method,
     create_dup_top,
@@ -48,6 +49,109 @@ class ReenterWith:
     stack_index: int
     target_values: Optional[Tuple[Any, ...]] = None
 
+    def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]):
+        """
+        Codegen based off of:
+        try:
+            (rest)
+        except:
+            (restore previous stack)
+
+        """
+        from .variables.torch_function import get_prev_stack_var_name
+
+        except_jump_target = create_instruction(
+            "NOP" if sys.version_info < (3, 11) else "PUSH_EXC_INFO"
+        )
+        cleanup_complete_jump_target = create_instruction("NOP")
+
+        setup_finally: List[Instruction] = []
+
+        if sys.version_info < (3, 11):
+            setup_finally.append(
+                create_instruction("SETUP_FINALLY", target=except_jump_target)
+            )
+        else:
+            exn_tab_begin = create_instruction("NOP")
+            exn_tab_end = create_instruction("NOP")
+            exn_tab_begin.exn_tab_entry = InstructionExnTabEntry(
+                exn_tab_begin,
+                exn_tab_end,
+                except_jump_target,
+                self.stack_index + 1,
+                False,
+            )
+            setup_finally.append(exn_tab_begin)
+
+        def create_reset():
+            insts = [
+                create_instruction(
+                    "LOAD_GLOBAL", argval="__import_torch_dot__dynamo_dot_utils"
+                ),
+                create_instruction("LOAD_ATTR", argval="set_torch_function_mode_stack"),
+            ]
+            add_push_null(insts)
+            return [
+                *insts,
+                create_instruction("LOAD_FAST", argval=get_prev_stack_var_name()),
+                *create_call_function(1, False),
+                create_instruction("POP_TOP"),
+            ]
+
+        if sys.version_info < (3, 9):
+            epilogue = [
+                create_instruction("POP_BLOCK"),
+                create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
+                except_jump_target,
+                *create_reset(),
+                create_instruction("POP_TOP"),
+                create_instruction("POP_TOP"),
+                create_instruction("POP_TOP"),
+                *create_reset(),
+                create_instruction("RAISE_VARARGS", argval=0),
+                create_instruction("POP_EXCEPT", argval=0),
+                create_instruction("END_FINALLY"),
+                cleanup_complete_jump_target,
+            ]
+        elif sys.version_info < (3, 11):
+            epilogue = [
+                create_instruction("POP_BLOCK"),
+                create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
+                except_jump_target,
+                create_instruction("POP_TOP"),
+                create_instruction("POP_TOP"),
+                create_instruction("POP_TOP"),
+                *create_reset(),
+                create_instruction("RAISE_VARARGS", argval=0),
+                create_instruction("POP_EXCEPT", argval=0),
+                cleanup_complete_jump_target,
+            ]
+        else:
+            finally_exn_tab_end = create_instruction("RAISE_VARARGS", argval=0)
+            finally_exn_tab_target = create_instruction("COPY", arg=3)
+            except_jump_target.exn_tab_entry = InstructionExnTabEntry(
+                except_jump_target,
+                finally_exn_tab_end,
+                finally_exn_tab_target,
+                self.stack_index + 2,
+                True,
+            )
+            epilogue = [
+                exn_tab_end,
+                create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
+                except_jump_target,  # PUSH_EXC_INFO
+                create_instruction("POP_TOP"),
+                *create_reset(),
+                finally_exn_tab_end,
+                finally_exn_tab_target,  # COPY 3
+                create_instruction("POP_EXCEPT"),
+                create_instruction("RERAISE", arg=1),  # RERAISE 1
+                cleanup_complete_jump_target,
+            ]
+
+        cleanup[:] = epilogue + cleanup
+
+        return setup_finally
+
     # If we do not want to destroy the stack, we can do the same thing as a
     # `SETUP_WITH` block, only that we store the context manager in a local_symbol
     def try_except(self, code_options, cleanup: List[Instruction]):
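The docstring's `(rest)` / `(restore previous stack)` skeleton corresponds roughly to the Python below. `prev_stack` stands in for the unique local named by `get_prev_stack_var_name()` (it is stored by the `SideEffects` codegen in the next file), and this is only an approximation of the emitted bytecode:

```python
from torch._dynamo.utils import set_torch_function_mode_stack


def try_except_sketch(rest, prev_stack):
    # approximate source form of the instructions returned by
    # try_except_torch_function_mode() plus the epilogue spliced into `cleanup`
    try:
        return rest()  # "(rest)"
    except BaseException:
        set_torch_function_mode_stack(prev_stack)  # "(restore previous stack)"
        raise
```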
diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py
index fc1bd976ff5..fd5b54e4f7f 100644
--- a/torch/_dynamo/side_effects.py
+++ b/torch/_dynamo/side_effects.py
@@ -593,11 +593,22 @@ class SideEffects:
             elif isinstance(
                 var, variables.torch_function.TorchFunctionModeStackVariable
             ):
+                # Needed in the finally block for stack restoration
+                cg.add_push_null(
+                    lambda: cg.load_import_from(
+                        utils.__name__, "get_torch_function_mode_stack"
+                    )
+                )
+                cg.call_function(0, False)
+                name = variables.torch_function.get_prev_stack_var_name()
+                cg.code_options["co_varnames"] += (name,)
+                cg.append_output(create_instruction("STORE_FAST", argval=name))
                 cg.add_push_null(
                     lambda: cg.load_import_from(
                         utils.__name__, "set_torch_function_mode_stack"
                     )
                 )
+
                 cg.foreach(var.symbolic_stack)
                 cg.append_output(
                     create_instruction("BUILD_LIST", arg=len(var.symbolic_stack))
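In source form, the extra instructions generated here amount to saving the current eager stack into that unique local before the symbolic stack is installed, roughly as follows (again an approximation, with `symbolic_stack` standing in for `var.symbolic_stack` after reconstruction):

```python
from torch._dynamo.utils import (
    get_torch_function_mode_stack,
    set_torch_function_mode_stack,
)


def stack_replay_sketch(symbolic_stack):
    prev = get_torch_function_mode_stack()  # new: saved via STORE_FAST <unique name>
    set_torch_function_mode_stack(list(symbolic_stack))  # existing: BUILD_LIST + CALL
    return prev  # read back by try_except_torch_function_mode on the error path
```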
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 415179a5ed3..61acaec9957 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -267,13 +267,12 @@ class BlockStackEntry:
         else:
             return ReenterWith(self.stack_index)
 
-    def exit(self, tx):
-        if hasattr(self, "graph_break") and isinstance(
-            self.with_context, TorchFunctionModeVariable
-        ):
-            return
+    def exit(self, tx, is_graph_break):
         assert self.with_context is not None
-        return self.with_context.exit(tx)
+        if (
+            is_graph_break and self.with_context.exit_on_graph_break()
+        ) or not is_graph_break:
+            return self.with_context.exit(tx)
 
 
 class ReturnValueOp(Exception):
@@ -639,10 +638,17 @@ def break_graph_if_unsupported(*, push):
             cleanup: List[Instruction] = []
             # Reconstruct the context variable CLASS in the block stack
             for b in self.block_stack:
+                # Don't exit any modes we have entered,
+                # output bytecode will mutate the tf mode stack accordingly
+                if isinstance(b.with_context, TorchFunctionModeVariable):
+                    cg.extend_output(
+                        b.resume_fn().try_except_torch_function_mode(
+                            cg.code_options, cleanup
+                        )
+                    )
+                    continue
                 assert b.with_context is not None
-                assert isinstance(
-                    b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable)
-                )
+                assert isinstance(b.with_context, (ContextWrappingVariable))
                 b.with_context.reconstruct_type(cg)
                 cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup))
             self.output.add_output_instructions(cg.get_instructions())
@@ -2295,7 +2301,10 @@ class InstructionTranslatorBase(
         ):
             unimplemented(f"{inst.opname} {ctx}")
 
-        if isinstance(ctx, GenericContextWrappingVariable):
+        if (
+            isinstance(ctx, GenericContextWrappingVariable)
+            and not ctx.supports_graph_breaks()
+        ):
             self.generic_context_manager_depth += 1
 
         # Need this redundant check for mypy
@@ -2668,6 +2677,7 @@ class InstructionTranslator(InstructionTranslatorBase):
                 local_scope=f_locals,
                 global_scope=f_globals,
                 f_code=f_code,
+                torch_function_mode_stack=torch_function_mode_stack,
             ),
             instructions=instructions,
             f_locals=f_locals,
diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py
index 4922b521bad..704a3889707 100644
--- a/torch/_dynamo/testing.py
+++ b/torch/_dynamo/testing.py
@@ -163,6 +163,7 @@ def debug_insert_nops(
         local_scope=locals(),
         global_scope=globals(),
         f_code=frame.f_code,
+        torch_function_mode_stack=[],
     )
 
     return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0))
diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py
index 1a206ae339c..ed81b26191b 100644
--- a/torch/_dynamo/trace_rules.py
+++ b/torch/_dynamo/trace_rules.py
@@ -304,6 +304,7 @@ manual_torch_name_rule_map = {
     "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable,
     "torch.cuda._get_device_properties": TorchInGraphFunctionVariable,
     "torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable,
+    "torch.set_default_device": UserFunctionVariable,
     "torch.sparse_bsc_tensor": SkipFunctionVariable,
     "torch.sparse_bsr_tensor": SkipFunctionVariable,
     "torch.sparse_csc_tensor": SkipFunctionVariable,
@@ -2797,7 +2798,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
         "torch.random.initial_seed",
         "torch.random.seed",
         "torch.return_types.pytree_register_structseq",
-        "torch.set_default_device",
         "torch.set_default_dtype",
         "torch.set_default_tensor_type",
         "torch.set_deterministic_debug_mode",
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 5a8a9b613b8..193e227267e 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -204,6 +204,7 @@ from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
 from .torch_function import (
     build_torch_function_fn,
     TensorWithTFOverrideVariable,
+    torch_function_mode_stack_state_mgr,
     TorchFunctionModeVariable,
 )
 from .user_defined import (
@@ -1668,15 +1669,16 @@ class VariableBuilder:
             # but warning is not the end of the world
             assert isinstance(value.base, np.nditer)
 
-        try:
-            tensor_value = _util._try_convert_to_tensor(value)
-            if readonly:
-                from torch._prims_common import clone_preserve_strides
+        with torch_function_mode_stack_state_mgr.temp_restore_stack():
+            try:
+                tensor_value = _util._try_convert_to_tensor(value)
+                if readonly:
+                    from torch._prims_common import clone_preserve_strides
 
-                tensor_value = clone_preserve_strides(tensor_value)
-        except NotImplementedError as e:
-            # failed to convert to tensor, graph break
-            unimplemented(str(e))
+                    tensor_value = clone_preserve_strides(tensor_value)
+            except NotImplementedError as e:
+                # failed to convert to tensor, graph break
+                unimplemented(str(e))
 
         # We do this because we want the full behavior of guarding the numpy ndarray as if it were
         # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here
diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py
index 8c4eb3dc4e7..e19c4e254c6 100644
--- a/torch/_dynamo/variables/ctx_manager.py
+++ b/torch/_dynamo/variables/ctx_manager.py
@@ -125,6 +125,12 @@ class ContextWrappingVariable(VariableTracker):
         if isinstance(args[0], UserFunctionVariable):
             return WrappedUserFunctionVariable(args[0], self)
 
+    def supports_graph_breaks(self):
+        return True
+
+    def exit_on_graph_break(self):
+        return True
+
 
 class GenericContextWrappingVariable(UserDefinedObjectVariable):
     # Some methods in ContextWrappingVariable assumes the arguments are
@@ -183,6 +189,12 @@ class GenericContextWrappingVariable(UserDefinedObjectVariable):
         tx.generic_context_manager_depth -= 1
         return x
 
+    def supports_graph_breaks(self):
+        return False
+
+    def exit_on_graph_break(self):
+        return True
+
 
 class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
     """represents torch grad requries grad"""
diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py
index f6cb565ef02..c1e0dec0fbc 100644
--- a/torch/_dynamo/variables/torch.py
+++ b/torch/_dynamo/variables/torch.py
@@ -162,7 +162,17 @@ def get_overridable_functions():
 
     from torch.overrides import get_overridable_functions as get_overridable_functions_
 
-    return set(chain(*get_overridable_functions_().values()))
+    funcs = set(chain(*get_overridable_functions_().values()))
+    more = {
+        torch.ones,
+        torch.ones_like,
+        torch.zeros,
+        torch.zeros_like,
+        torch.empty,
+        torch.full,
+    }
+    funcs.update(more)
+    return funcs
 
 
 class BaseTorchVariable(VariableTracker):
@@ -838,6 +848,13 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
                 len(tx.symbolic_torch_function_state.mode_stack)
             )
 
+        @register(torch._C._get_function_stack_at)
+        def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs):
+            assert len(args) == 1 and not kwargs
+            ind = args[0].as_python_constant()
+            assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack)
+            return tx.symbolic_torch_function_state.mode_stack[ind]
+
         @register(torch.set_default_device)
         def handle_set_default_device(
             self, tx: "InstructionTranslator", *args, **kwargs
@@ -855,7 +872,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
             else:
                 TorchFunctionModeStackVariable.register_device_context_insertion(tx)
 
-            return None
+            return ConstantVariable.create(None)
 
         return handlers
diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py
index 6e52cc688e0..b7fd83a3d24 100644
--- a/torch/_dynamo/variables/torch_function.py
+++ b/torch/_dynamo/variables/torch_function.py
@@ -2,22 +2,35 @@
 
 import collections
 import contextlib
+import functools
 import inspect
 from typing import Deque, Dict, List, TYPE_CHECKING
 
 import torch._C
 import torch.utils._pytree as pytree
 from torch._guards import Source
-from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
+from torch.overrides import (
+    _get_overloaded_args,
+    get_default_nowrap_functions,
+    TorchFunctionMode,
+)
 from torch.utils._device import DeviceContext
 
 from ..exc import unimplemented
 from ..guards import GuardBuilder, install_guard
+from ..polyfills import NoEnterTorchFunctionMode
 from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
-from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter
+from ..utils import (
+    class_has_getattribute,
+    clear_torch_function_mode_stack,
+    get_safe_global_name,
+    has_torch_function,
+    is_tensor_base_attr_getter,
+    set_torch_function_mode_stack,
+)
 from .base import VariableTracker
 from .constant import ConstantVariable
-from .ctx_manager import ContextWrappingVariable
+from .ctx_manager import GenericContextWrappingVariable
 from .lazy import LazyVariableTracker
 from .lists import TupleVariable
 from .tensor import TensorSubclassVariable, TensorVariable
@@ -63,6 +76,39 @@ banned_attrs = [
 
 IGNORED_MODES = {DeviceContext}
 
+
+@functools.lru_cache(None)
+def get_prev_stack_var_name():
+    from ..bytecode_transformation import unique_id
+
+    return unique_id("___prev_torch_function_mode_stack")
+
+
+# Used to clear/restore the python torch function mode stack and temporarily restore it as needed
+class TorchFunctionModeStackStateManager:
+    def __init__(self):
+        self.stack = []
+
+    def __enter__(self):
+        self.stack = torch.overrides._get_current_function_mode_stack()
+        clear_torch_function_mode_stack()
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        set_torch_function_mode_stack(self.stack)
+        self.stack = []
+
+    @contextlib.contextmanager
+    def temp_restore_stack(self):
+        prev = torch.overrides._get_current_function_mode_stack()
+        set_torch_function_mode_stack(self.stack)
+        try:
+            yield
+        finally:
+            set_torch_function_mode_stack(prev)
+
+
+torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager()
+
+
 class SymbolicTorchFunctionState:
     def __init__(self, py_stack):
         # This is annoyingly complicated because of how the torch function subclass + mode C API was designed
@@ -189,9 +235,26 @@ class TorchFunctionModeStackVariable(VariableTracker):
         return ind + cls.offset
 
 
-class TorchFunctionModeVariable(ContextWrappingVariable):
+class TorchFunctionModeVariable(GenericContextWrappingVariable):
+    @staticmethod
+    def is_supported_torch_function_mode(ty):
+        # Supported in this sense means we can support graph breaks under the
+        # context.
+        # We are able to trace custom modes, but if there are graph breaks under them
+        # and they have a custom __enter__/__exit__ we don't handle this, for the
+        # same reason we don't handle generic context managers: there may be side effects
+        # that are now affected by executing the function across two frames instead of one.
+        # Today we support the enter/exit of the default TorchFunctionMode as well as
+        # DeviceContext (which is used for set_default_device).
+        return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or (
+            not class_has_getattribute(ty)
+            and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__
+            and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__
+        )
+
     def __init__(self, value, source=None, **kwargs):
-        super().__init__(value, **kwargs)
+        if value is not None:
+            super().__init__(value, **kwargs)
         self.value = value
         self.cm_obj = value  # needed for BC with calling enter from CM code
         self.source = source
@@ -221,8 +284,39 @@ class TorchFunctionModeVariable(ContextWrappingVariable):
             kwargs,
         )
 
-    def _call_func(self, tx: "InstructionTranslator", values):
-        unimplemented("enter/exit for torch function mode NYI")
+    def enter(self, tx):
+        from .torch import TorchInGraphFunctionVariable
+
+        if isinstance(self.value, NoEnterTorchFunctionMode):
+            return ConstantVariable.create(None)
+
+        TorchInGraphFunctionVariable(
+            torch._C._push_on_torch_function_stack
+        ).call_function(tx, [self], {})
+        return ConstantVariable.create(None)
+
+    def exit(self, tx: "InstructionTranslator", *args):
+        from .torch import TorchInGraphFunctionVariable
+
+        TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function(
+            tx, [], {}
+        )
+        return ConstantVariable.create(None)
+
+    def reconstruct_type(self, codegen):
+        ty = NoEnterTorchFunctionMode
+        codegen(
+            AttrSource(
+                codegen.tx.import_source(ty.__module__),
+                ty.__name__,
+            )
+        )
+
+    def supports_graph_breaks(self):
+        return True
+
+    def exit_on_graph_break(self):
+        return False
 
 
 def _get_all_args(args, kwargs):
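As an illustration of the `is_supported_torch_function_mode` criteria (both classes below are hypothetical): a mode that only overrides `__torch_function__` keeps the inherited enter/exit and can be split across a graph break, while a mode with a custom `__enter__` is still handled as a generic context manager:

```python
from torch._dynamo.variables.torch_function import TorchFunctionModeVariable
from torch.overrides import TorchFunctionMode


class SupportedMode(TorchFunctionMode):
    # inherited __enter__/__exit__ -> graph breaks under the `with` block are OK
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        return func(*args, **kwargs)


class UnsupportedMode(TorchFunctionMode):
    # a custom __enter__ may have side effects that can't be split across two frames
    def __enter__(self):
        print("entering")
        return super().__enter__()

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        return func(*args, **kwargs)


assert TorchFunctionModeVariable.is_supported_torch_function_mode(SupportedMode)
assert not TorchFunctionModeVariable.is_supported_torch_function_mode(UnsupportedMode)
```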
diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py
index ad92277cf70..14499b4d2e4 100644
--- a/torch/_dynamo/variables/user_defined.py
+++ b/torch/_dynamo/variables/user_defined.py
@@ -409,10 +409,22 @@ class UserDefinedClassVariable(UserDefinedVariable):
             and self.source
             and not is_forbidden_context_manager(self.value)
         ):
+            from torch.overrides import TorchFunctionMode
+            from .ctx_manager import GenericContextWrappingVariable
+            from .torch_function import TorchFunctionModeVariable
+
+            if issubclass(
+                self.value, TorchFunctionMode
+            ) and TorchFunctionModeVariable.is_supported_torch_function_mode(
+                self.value
+            ):
+                var_cls = TorchFunctionModeVariable
+            else:
+                var_cls = GenericContextWrappingVariable
 
             cm_obj = tx.output.side_effects.track_object_new(
-                self.source, self.value, GenericContextWrappingVariable, {}
+                self.source, self.value, var_cls, {}
             )
             cm_obj.call_method(tx, "__init__", args, kwargs)
             return cm_obj
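End to end, the behavior the new tests exercise looks like this (with a hypothetical pass-through mode standing in for the test suite's `TestMode`):

```python
import torch
from torch.overrides import TorchFunctionMode


class PassthroughMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        return func(*args, **kwargs)


def fn(x):
    with PassthroughMode():
        x = torch.add(x, 1)
        torch._dynamo.graph_break()  # the mode stack is replayed across the break
        x = torch.mul(x, 2)
    return x


out = torch.compile(fn)(torch.ones(3))
assert torch._C._len_torch_function_stack() == 0  # stack is restored once the `with` exits
```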