[Dynamo] Trace enter/exit of TorchFunctionModes (#135422)

This PR implements tracing of with-contexts for TorchFunction modes that have the default enter/exit behavior (i.e. pushing the mode on enter and popping it on exit).

Typically the bytecode for a context manager looks like this during a graph break:
1. graph call
2. enter context
3. unsupported code
4. exit context
5. resume call

resume fn structure (a source-level sketch of both follows these lists):
1. enter context
2. jump
...
3. exit context
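
To make the two structures above concrete, here is a source-level sketch of what the transformed code effectively does for an ordinary context manager around a graph break. The real transformation happens at the bytecode level; SomeCtx, compiled_graph, unsupported_code and resume_fn are schematic stand-ins, not real Dynamo symbols.

```python
# Source-level sketch only; Dynamo emits bytecode, not Python source.
class SomeCtx:
    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False


def compiled_graph(x):       # the traced portion before the break
    return x + 1


def unsupported_code(x):     # whatever triggered the graph break, run eagerly
    return x


def resume_fn(x):
    with SomeCtx():          # 1. enter context
        x = x * 2            # 2. jump target: the rest of the traced code
    return x                 # 3. exit context ran when the with block ended


def transformed_fn(x):
    x = compiled_graph(x)    # 1. graph call
    with SomeCtx():          # 2. enter context
        unsupported_code(x)  # 3. unsupported code
    # 4. exit context
    return resume_fn(x)      # 5. resume call
```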

The issue with torch function modes is that the side effects machinery will replay any mutations to the torch function stack performed during tracing. So we do not need to enter and exit around the unsupported code in the original function (doing so would produce a duplicate torch function mode entry while the unsupported code executes), and we do not need to enter again in the resume function (the mode pushed by the side-effects bytecode is still on the stack).

So for torch function modes, the structure of our output code is this (again with a source-level sketch after the two lists below):

1. graph call
2. mutate tf mode stack to replay mutations
3. unsupported code
4. on exception, restore the stack
5. resume function

Then our resume fn looks like this:

1. no-op enter torch function mode
2. jump
3. exit tf mode
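
And a matching source-level sketch for the torch function mode case. get_torch_function_mode_stack / set_torch_function_mode_stack are the torch._dynamo.utils helpers the generated bytecode calls, and NoEnterTorchFunctionMode is the polyfill described below; the other names are the same schematic stand-ins as above, so treat this as an approximation of the generated bytecode rather than its exact shape.

```python
import torch
from torch._dynamo.polyfills import NoEnterTorchFunctionMode
from torch._dynamo.utils import (
    get_torch_function_mode_stack,
    set_torch_function_mode_stack,
)

compiled_graph = lambda x: x + 1      # schematic stand-ins, as in the sketch above
unsupported_code = lambda x: x


def transformed_fn(x, mode):
    x = compiled_graph(x)                        # 1. graph call
    prev = get_torch_function_mode_stack()       # saved so the stack can be restored
    set_torch_function_mode_stack([mode])        # 2. replay mutations to the tf mode stack
    try:
        unsupported_code(x)                      # 3. unsupported code (mode already active)
    except Exception:
        set_torch_function_mode_stack(prev)      # 4. on exception, restore the stack
        raise
    return resume_fn(x)                          # 5. resume function


def resume_fn(x):
    with NoEnterTorchFunctionMode():             # 1. no-op enter: the mode is already on the stack
        x = x * 2                                # 2. jump target: the rest of the traced code
    return x                                     # 3. the inherited __exit__ pops the mode exactly once
```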

To implement the no-op enter of the torch function mode, I added a torch function mode to the polyfills (NoEnterTorchFunctionMode) which no-ops on enter but exits normally. This is needed because we still want to trace the with-context in the resume function and exit properly: the exit instructions are still present in that function, so we need to generate the instructions that set up the context.
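
For reference, the polyfill is essentially the class added in the diff below (only __enter__ is overridden; __exit__ comes from BaseTorchFunctionMode / TorchFunctionMode). The small check underneath is a sketch of the intended semantics, assuming the stack was already mutated by the replayed side effects:

```python
import torch
from torch.overrides import BaseTorchFunctionMode


class NoEnterTorchFunctionMode(BaseTorchFunctionMode):
    """Enter is a no-op (the mode was already pushed by the replayed side
    effects); the inherited __exit__ still pops the mode from the stack."""

    def __enter__(self):
        pass


mode = NoEnterTorchFunctionMode()
torch._C._push_on_torch_function_stack(mode)      # stand-in for the replayed side effect
with mode:                                        # nothing is pushed a second time
    assert torch._C._len_torch_function_stack() == 1
assert torch._C._len_torch_function_stack() == 0  # __exit__ popped the mode
```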

Separately from the bytecode, Dynamo also tracks contexts on the block stack, which is how the SETUP_* instructions are implemented. Normally, at a graph break we exit these block stack entries to fully reset the contexts, so that we can soundly re-enter them around the unsupported code. Once again, though, in the torch function mode case we do not want to run any exit side effects at the graph break, because we want to preserve the state of the mode stack as-is so that the bytecode described above can update it. If we exited here, Dynamo would pop the mode off of the symbolic stack and never update the real Python torch function mode stack via the suffix bytecode.

All in all, for torch function modes we enter exactly once, update the global torch function mode stack with side-effects bytecode, re-read this stack when compiling the resume function, and exit exactly once in the resume function. This matches the semantics of eager exactly.
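
Condensed from the diff below, these are the two hooks that encode that policy: supports_graph_breaks() says whether Dynamo may keep compiling across a graph break inside the context (for arbitrary user context managers it cannot), and exit_on_graph_break() says whether the block stack should run the context's exit at the break. The class bodies here are trimmed to just those hooks.

```python
class ContextWrappingVariable:            # built-in contexts such as torch.no_grad()
    def supports_graph_breaks(self):
        return True                       # graph breaks inside are allowed

    def exit_on_graph_break(self):
        return True                       # exit at the break, re-enter around the unsupported code


class GenericContextWrappingVariable:     # arbitrary user-defined context managers
    def supports_graph_breaks(self):
        return False                      # unknown side effects: fall back to eager

    def exit_on_graph_break(self):
        return True


class TorchFunctionModeVariable(GenericContextWrappingVariable):  # supported TF modes
    def supports_graph_breaks(self):
        return True                       # graph breaks are handled...

    def exit_on_graph_break(self):
        return False                      # ...but the exit is skipped; the side-effects
                                          # bytecode replays the stack instead


class BlockStackEntry:
    def __init__(self, with_context):
        self.with_context = with_context

    def exit(self, tx, is_graph_break):
        # Only run the context's exit side effects when that is allowed.
        if not is_graph_break or self.with_context.exit_on_graph_break():
            return self.with_context.exit(tx)
```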

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135422
Approved by: https://github.com/williamwen42
ghstack dependencies: #134732, #133137, #135443, #135444
Michael Lazos 2024-09-14 03:23:36 -07:00 committed by PyTorch MergeBot
parent 06caa2d560
commit 1b9daeb240
14 changed files with 408 additions and 33 deletions


@ -461,6 +461,94 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
self.assertEqual(expected, actual)
def test_torch_function_mode_enter_exit(self):
def fn(x, y):
with TestMode():
o = torch.add(x, 3)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn, fullgraph=True)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_torch_function_mode_graph_break(self):
def fn(x, y):
with TestMode():
torch._dynamo.graph_break()
o = torch.add(x, 3)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_torch_function_mode_and_pop_graph_break(self):
def fn(x, y):
with TestMode():
z = _pop_torch_function_stack()
torch._dynamo.graph_break()
_push_on_torch_function_stack(z)
o = torch.add(x, 3)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_torch_function_mode_restore_on_exc(self):
@torch._dynamo.disable()
def err():
raise RuntimeError("test")
@torch.compile()
def fn(x):
with TestMode():
x += 1
err()
x += 2
return x
try:
fn(torch.ones(2, 2))
except RuntimeError:
pass
self.assertEqual(_len_torch_function_stack(), 0)
def test_torch_function_mode_and_pop_graph_break_mutation(self):
def fn(x, y):
with TestMode():
z = _pop_torch_function_stack()
z.y = 5
torch._dynamo.graph_break()
_push_on_torch_function_stack(z)
o = torch.add(x, 3)
o = torch.mul(o, z.y)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests


@ -112,6 +112,7 @@ from .utils import (
troubleshooting_url,
write_record_to_file,
)
from .variables.torch_function import torch_function_mode_stack_state_mgr
np: Optional[ModuleType]
@ -210,15 +211,18 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
prior_fwd_from_src = torch.fx.graph_module._forward_from_src
torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
cleanup = setup_compile_debug()
exit_stack = contextlib.ExitStack()
exit_stack.enter_context(
torch.fx._symbolic_trace._maybe_revert_all_patches()
)
exit_stack.enter_context(torch_function_mode_stack_state_mgr)
try:
return fn(*args, **kwargs)
finally:
cleanup.close()
assert (
torch._C._len_torch_function_stack() == 0
), "Torch function mode stack state changed while dynamo tracing, please report a bug"
exit_stack.close()
torch._C._set_grad_enabled(prior_grad_mode)
torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)


@ -78,7 +78,6 @@ from .utils import (
get_instruction_source_311,
get_locals_to_steal,
get_static_address_type,
get_torch_function_mode_stack,
graph_break_reasons,
increment_op_count,
lazy_format_graph_code,
@ -250,6 +249,7 @@ class OutputGraph:
local_scope: Scope,
global_scope: Scope,
f_code,
torch_function_mode_stack,
):
super().__init__()
self.tracers = [SubgraphTracer(self, export_root=export)]
@ -368,7 +368,7 @@ class OutputGraph:
# This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty
self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled()
# This records the initial torch function mode stack for guarding
self.torch_function_mode_stack = get_torch_function_mode_stack()
self.torch_function_mode_stack = torch_function_mode_stack
# Tracks if the output graph has a user defined allowed function in the
# graph. This is used later to determine if we should fallback to eager
@ -1020,7 +1020,7 @@ class OutputGraph:
prefix_insts.clear()
for block in reversed(tx.block_stack):
block.exit(tx)
block.exit(tx, is_graph_break=reason.graph_break)
self.cleanup_graph()
tx.prune_dead_locals()


@ -25,6 +25,26 @@ if TYPE_CHECKING:
sys as sys,
)
from torch.overrides import BaseTorchFunctionMode
# These classes handle support for TorchFunctionModes across
# graph breaks
# Today the TorchFunctionMode enter (for the classes we support)
# simply pushes the mode onto the stack. Since after this occurs
# the stack is mutated, and we replay these mutations, we don't need
# any cleanup logic to be run once the graph break occurs, we simply replay
# these mutations to ensure at the graph break the torch function mode stack is correct
# and reconstruct the torch function mode stack normally
# when we compile the resume function on the other side of the break.
# However, to ensure we exit properly
# in the resume function, we need to re-enter the contexts as we do other contexts.
# These contexts do nothing on enter, but provide the correct exit logic to ensure
# the stack state is correct.
class NoEnterTorchFunctionMode(BaseTorchFunctionMode):
def __enter__(self):
pass
def index(iterator, item, start=0, end=None):
from itertools import islice


@ -6,6 +6,7 @@ import types
from typing import Any, cast, Dict, List, Optional, Tuple
from .bytecode_transformation import (
add_push_null,
create_call_function,
create_call_method,
create_dup_top,
@ -48,6 +49,109 @@ class ReenterWith:
stack_index: int
target_values: Optional[Tuple[Any, ...]] = None
def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]):
"""
Codegen based off of:
try:
(rest)
except:
(restore previous stack)
"""
from .variables.torch_function import get_prev_stack_var_name
except_jump_target = create_instruction(
"NOP" if sys.version_info < (3, 11) else "PUSH_EXC_INFO"
)
cleanup_complete_jump_target = create_instruction("NOP")
setup_finally: List[Instruction] = []
if sys.version_info < (3, 11):
setup_finally.append(
create_instruction("SETUP_FINALLY", target=except_jump_target)
)
else:
exn_tab_begin = create_instruction("NOP")
exn_tab_end = create_instruction("NOP")
exn_tab_begin.exn_tab_entry = InstructionExnTabEntry(
exn_tab_begin,
exn_tab_end,
except_jump_target,
self.stack_index + 1,
False,
)
setup_finally.append(exn_tab_begin)
def create_reset():
insts = [
create_instruction(
"LOAD_GLOBAL", argval="__import_torch_dot__dynamo_dot_utils"
),
create_instruction("LOAD_ATTR", argval="set_torch_function_mode_stack"),
]
add_push_null(insts)
return [
*insts,
create_instruction("LOAD_FAST", argval=get_prev_stack_var_name()),
*create_call_function(1, False),
create_instruction("POP_TOP"),
]
if sys.version_info < (3, 9):
epilogue = [
create_instruction("POP_BLOCK"),
create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
except_jump_target,
*create_reset(),
create_instruction("POP_TOP"),
create_instruction("POP_TOP"),
create_instruction("POP_TOP"),
*create_reset(),
create_instruction("RAISE_VARARGS", argval=0),
create_instruction("POP_EXCEPT", argval=0),
create_instruction("END_FINALLY"),
cleanup_complete_jump_target,
]
elif sys.version_info < (3, 11):
epilogue = [
create_instruction("POP_BLOCK"),
create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
except_jump_target,
create_instruction("POP_TOP"),
create_instruction("POP_TOP"),
create_instruction("POP_TOP"),
*create_reset(),
create_instruction("RAISE_VARARGS", argval=0),
create_instruction("POP_EXCEPT", argval=0),
cleanup_complete_jump_target,
]
else:
finally_exn_tab_end = create_instruction("RAISE_VARARGS", argval=0)
finally_exn_tab_target = create_instruction("COPY", arg=3)
except_jump_target.exn_tab_entry = InstructionExnTabEntry(
except_jump_target,
finally_exn_tab_end,
finally_exn_tab_target,
self.stack_index + 2,
True,
)
epilogue = [
exn_tab_end,
create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
except_jump_target, # PUSH_EXC_INFO
create_instruction("POP_TOP"),
*create_reset(),
finally_exn_tab_end,
finally_exn_tab_target, # COPY 3
create_instruction("POP_EXCEPT"),
create_instruction("RERAISE", arg=1), # RERAISE 1
cleanup_complete_jump_target,
]
cleanup[:] = epilogue + cleanup
return setup_finally
# If we do not want to destroy the stack, we can do the same thing as a
# `SETUP_WITH` block, only that we store the context manager in a local_symbol
def try_except(self, code_options, cleanup: List[Instruction]):


@ -593,11 +593,22 @@ class SideEffects:
elif isinstance(
var, variables.torch_function.TorchFunctionModeStackVariable
):
# Needed in the finally block for stack restoration
cg.add_push_null(
lambda: cg.load_import_from(
utils.__name__, "get_torch_function_mode_stack"
)
)
cg.call_function(0, False)
name = variables.torch_function.get_prev_stack_var_name()
cg.code_options["co_varnames"] += (name,)
cg.append_output(create_instruction("STORE_FAST", argval=name))
cg.add_push_null(
lambda: cg.load_import_from(
utils.__name__, "set_torch_function_mode_stack"
)
)
cg.foreach(var.symbolic_stack)
cg.append_output(
create_instruction("BUILD_LIST", arg=len(var.symbolic_stack))


@ -267,13 +267,12 @@ class BlockStackEntry:
else:
return ReenterWith(self.stack_index)
def exit(self, tx):
if hasattr(self, "graph_break") and isinstance(
self.with_context, TorchFunctionModeVariable
):
return
def exit(self, tx, is_graph_break):
assert self.with_context is not None
return self.with_context.exit(tx)
if (
is_graph_break and self.with_context.exit_on_graph_break()
) or not is_graph_break:
return self.with_context.exit(tx)
class ReturnValueOp(Exception):
@ -639,10 +638,17 @@ def break_graph_if_unsupported(*, push):
cleanup: List[Instruction] = []
# Reconstruct the context variable CLASS in the block stack
for b in self.block_stack:
# Don't exit any modes we have entered,
# output bytecode will mutate the tf mode stack accordingly
if isinstance(b.with_context, TorchFunctionModeVariable):
cg.extend_output(
b.resume_fn().try_except_torch_function_mode(
cg.code_options, cleanup
)
)
continue
assert b.with_context is not None
assert isinstance(
b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable)
)
assert isinstance(b.with_context, (ContextWrappingVariable))
b.with_context.reconstruct_type(cg)
cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup))
self.output.add_output_instructions(cg.get_instructions())
@ -2295,7 +2301,10 @@ class InstructionTranslatorBase(
):
unimplemented(f"{inst.opname} {ctx}")
if isinstance(ctx, GenericContextWrappingVariable):
if (
isinstance(ctx, GenericContextWrappingVariable)
and not ctx.supports_graph_breaks()
):
self.generic_context_manager_depth += 1
# Need this redundant check for mypy
@ -2668,6 +2677,7 @@ class InstructionTranslator(InstructionTranslatorBase):
local_scope=f_locals,
global_scope=f_globals,
f_code=f_code,
torch_function_mode_stack=torch_function_mode_stack,
),
instructions=instructions,
f_locals=f_locals,


@ -163,6 +163,7 @@ def debug_insert_nops(
local_scope=locals(),
global_scope=globals(),
f_code=frame.f_code,
torch_function_mode_stack=[],
)
return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0))


@ -304,6 +304,7 @@ manual_torch_name_rule_map = {
"torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable,
"torch.cuda._get_device_properties": TorchInGraphFunctionVariable,
"torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable,
"torch.set_default_device": UserFunctionVariable,
"torch.sparse_bsc_tensor": SkipFunctionVariable,
"torch.sparse_bsr_tensor": SkipFunctionVariable,
"torch.sparse_csc_tensor": SkipFunctionVariable,
@ -2797,7 +2798,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
"torch.random.initial_seed",
"torch.random.seed",
"torch.return_types.pytree_register_structseq",
"torch.set_default_device",
"torch.set_default_dtype",
"torch.set_default_tensor_type",
"torch.set_deterministic_debug_mode",


@ -204,6 +204,7 @@ from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
from .torch_function import (
build_torch_function_fn,
TensorWithTFOverrideVariable,
torch_function_mode_stack_state_mgr,
TorchFunctionModeVariable,
)
from .user_defined import (
@ -1668,15 +1669,16 @@ class VariableBuilder:
# but warning is not the end of the world
assert isinstance(value.base, np.nditer)
try:
tensor_value = _util._try_convert_to_tensor(value)
if readonly:
from torch._prims_common import clone_preserve_strides
with torch_function_mode_stack_state_mgr.temp_restore_stack():
try:
tensor_value = _util._try_convert_to_tensor(value)
if readonly:
from torch._prims_common import clone_preserve_strides
tensor_value = clone_preserve_strides(tensor_value)
except NotImplementedError as e:
# failed to convert to tensor, graph break
unimplemented(str(e))
tensor_value = clone_preserve_strides(tensor_value)
except NotImplementedError as e:
# failed to convert to tensor, graph break
unimplemented(str(e))
# We do this because we want the full behavior of guarding the numpy ndarray as if it were
# a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here


@ -125,6 +125,12 @@ class ContextWrappingVariable(VariableTracker):
if isinstance(args[0], UserFunctionVariable):
return WrappedUserFunctionVariable(args[0], self)
def supports_graph_breaks(self):
return True
def exit_on_graph_break(self):
return True
class GenericContextWrappingVariable(UserDefinedObjectVariable):
# Some methods in ContextWrappingVariable assumes the arguments are
@ -183,6 +189,12 @@ class GenericContextWrappingVariable(UserDefinedObjectVariable):
tx.generic_context_manager_depth -= 1
return x
def supports_graph_breaks(self):
return False
def exit_on_graph_break(self):
return True
class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
"""represents torch grad requries grad"""


@ -162,7 +162,17 @@ def get_overridable_functions():
from torch.overrides import get_overridable_functions as get_overridable_functions_
return set(chain(*get_overridable_functions_().values()))
funcs = set(chain(*get_overridable_functions_().values()))
more = {
torch.ones,
torch.ones_like,
torch.zeros,
torch.zeros_like,
torch.empty,
torch.full,
}
funcs.update(more)
return funcs
class BaseTorchVariable(VariableTracker):
@ -838,6 +848,13 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
len(tx.symbolic_torch_function_state.mode_stack)
)
@register(torch._C._get_function_stack_at)
def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs):
assert len(args) == 1 and not kwargs
ind = args[0].as_python_constant()
assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack)
return tx.symbolic_torch_function_state.mode_stack[ind]
@register(torch.set_default_device)
def handle_set_default_device(
self, tx: "InstructionTranslator", *args, **kwargs
@ -855,7 +872,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
else:
TorchFunctionModeStackVariable.register_device_context_insertion(tx)
return None
return ConstantVariable.create(None)
return handlers


@ -2,22 +2,35 @@
import collections
import contextlib
import functools
import inspect
from typing import Deque, Dict, List, TYPE_CHECKING
import torch._C
import torch.utils._pytree as pytree
from torch._guards import Source
from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
from torch.overrides import (
_get_overloaded_args,
get_default_nowrap_functions,
TorchFunctionMode,
)
from torch.utils._device import DeviceContext
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..polyfills import NoEnterTorchFunctionMode
from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter
from ..utils import (
class_has_getattribute,
clear_torch_function_mode_stack,
get_safe_global_name,
has_torch_function,
is_tensor_base_attr_getter,
set_torch_function_mode_stack,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import ContextWrappingVariable
from .ctx_manager import GenericContextWrappingVariable
from .lazy import LazyVariableTracker
from .lists import TupleVariable
from .tensor import TensorSubclassVariable, TensorVariable
@ -63,6 +76,39 @@ banned_attrs = [
IGNORED_MODES = {DeviceContext}
@functools.lru_cache(None)
def get_prev_stack_var_name():
from ..bytecode_transformation import unique_id
return unique_id("___prev_torch_function_mode_stack")
# Used to clear/restore the python torch function mode stack and temporarily restore it as needed
class TorchFunctionModeStackStateManager:
def __init__(self):
self.stack = []
def __enter__(self):
self.stack = torch.overrides._get_current_function_mode_stack()
clear_torch_function_mode_stack()
def __exit__(self, exc_type, exc_value, traceback):
set_torch_function_mode_stack(self.stack)
self.stack = []
@contextlib.contextmanager
def temp_restore_stack(self):
prev = torch.overrides._get_current_function_mode_stack()
set_torch_function_mode_stack(self.stack)
try:
yield
finally:
set_torch_function_mode_stack(prev)
torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager()
class SymbolicTorchFunctionState:
def __init__(self, py_stack):
# This is annoyingly complicated because of how the torch function subclass + mode C API was designed
@ -189,9 +235,26 @@ class TorchFunctionModeStackVariable(VariableTracker):
return ind + cls.offset
class TorchFunctionModeVariable(ContextWrappingVariable):
class TorchFunctionModeVariable(GenericContextWrappingVariable):
@staticmethod
def is_supported_torch_function_mode(ty):
# Supported in this sense means we can support graph breaks under the
# context.
# We are able to trace custom modes but if there are graph breaks under them
# and they have a custom __enter__/__exit__ we don't handle this for the
# same reason we don't handle generic context managers: there may be side effects
that are now affected by executing the function across two frames instead of one
# Today we support the enter/exit of the default TorchFunctionMode as well as
# DeviceContext (which is used for set_default_device)
return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or (
not class_has_getattribute(ty)
and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__
and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__
)
def __init__(self, value, source=None, **kwargs):
super().__init__(value, **kwargs)
if value is not None:
super().__init__(value, **kwargs)
self.value = value
self.cm_obj = value # needed for BC with calling enter from CM code
self.source = source
@ -221,8 +284,39 @@ class TorchFunctionModeVariable(ContextWrappingVariable):
kwargs,
)
def _call_func(self, tx: "InstructionTranslator", values):
unimplemented("enter/exit for torch function mode NYI")
def enter(self, tx):
from .torch import TorchInGraphFunctionVariable
if isinstance(self.value, NoEnterTorchFunctionMode):
return ConstantVariable.create(None)
TorchInGraphFunctionVariable(
torch._C._push_on_torch_function_stack
).call_function(tx, [self], {})
return ConstantVariable.create(None)
def exit(self, tx: "InstructionTranslator", *args):
from .torch import TorchInGraphFunctionVariable
TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function(
tx, [], {}
)
return ConstantVariable.create(None)
def reconstruct_type(self, codegen):
ty = NoEnterTorchFunctionMode
codegen(
AttrSource(
codegen.tx.import_source(ty.__module__),
ty.__name__,
)
)
def supports_graph_breaks(self):
return True
def exit_on_graph_break(self):
return False
def _get_all_args(args, kwargs):


@ -409,10 +409,22 @@ class UserDefinedClassVariable(UserDefinedVariable):
and self.source
and not is_forbidden_context_manager(self.value)
):
from torch.overrides import TorchFunctionMode
from .ctx_manager import GenericContextWrappingVariable
from .torch_function import TorchFunctionModeVariable
if issubclass(
self.value, TorchFunctionMode
) and TorchFunctionModeVariable.is_supported_torch_function_mode(
self.value
):
var_cls = TorchFunctionModeVariable
else:
var_cls = GenericContextWrappingVariable
cm_obj = tx.output.side_effects.track_object_new(
self.source, self.value, GenericContextWrappingVariable, {}
self.source, self.value, var_cls, {}
)
cm_obj.call_method(tx, "__init__", args, kwargs)
return cm_obj