mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Dynamo] Trace enter/exit of TorchFunctionModes (#135422)
This PR implements tracing of with contexts with TorchFunction modes which have the default enter/exit behavior (ie pushing/popping the mode) Typically the bytecode for a context manager looks like this during a graph break: 1. graph call 2. enter context 3. unsupported code 4. exit context 5. resume call resume fn structure: 1. enter context 2. jump ... 3. exit context The issue with torch function modes is that side effects will replay any mutations to the torch function stack performed during tracing. So, we do not need to enter and exit around the unsupported code in the original function (doing so would result in a duplicate torch function mode entry during execution of the unsupported code), and we don't need to enter again in the resume function (the mode that was pushed from the side effects bytecode would still be on the stack). So for torch function modes the structure of our output code is this: 1. graph call 2. mutate tf mode stack to replay mutations 4. unsupported code 5. on exception restore stack 6. resume function Then our resume fn looks like this: 1. no-op enter torch function mode 2. jump 3. exit tf mode To implement the no-op enter of the torch function mode I added torch function mode in polyfill which no-op enters, but normally exits. This is needed because we still want to trace the with context in the resume function, and exit properly (the exit instructions will still be in the function, so we need to generate instructions to set up the context). Separately from the bytecode, dynamo also tracks contexts on the block stack, which is how the SETUP_* instructions are implemented. Naturally at a graph break, we exit these block stacks to properly reset the contexts entirely, so that we can re-enter around the unsupported code soundly. However once again, in the torch function mode case, in the event of a graph we do not want to perform any exit side effects because we want to preserve the state of the mode stack as is so that we will properly update the stack with bytecode mentioned in the first section. If we exited here, dynamo would pop the mode off of the symbolic stack, and not update the true python torch function mode stack with the suffix bytecode. All in all, for torch function modes we enter exactly once, update the global torch function mode stack with side effects bytecode, re-read this stack when compiling the resume function, and exit exactly once in the resume function. This matches the semantics of eager exactly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135422 Approved by: https://github.com/williamwen42 ghstack dependencies: #134732, #133137, #135443, #135444
This commit is contained in:
parent
06caa2d560
commit
1b9daeb240
|
|
@ -461,6 +461,94 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_torch_function_mode_enter_exit(self):
|
||||
def fn(x, y):
|
||||
with TestMode():
|
||||
o = torch.add(x, 3)
|
||||
|
||||
return torch.add(o, y)
|
||||
|
||||
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
|
||||
fn_opt = torch.compile(fn, fullgraph=True)
|
||||
|
||||
expected = fn(*inp)
|
||||
actual = fn_opt(*inp)
|
||||
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_torch_function_mode_graph_break(self):
|
||||
def fn(x, y):
|
||||
with TestMode():
|
||||
torch._dynamo.graph_break()
|
||||
o = torch.add(x, 3)
|
||||
|
||||
return torch.add(o, y)
|
||||
|
||||
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
|
||||
fn_opt = torch.compile(fn)
|
||||
|
||||
expected = fn(*inp)
|
||||
actual = fn_opt(*inp)
|
||||
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_torch_function_mode_and_pop_graph_break(self):
|
||||
def fn(x, y):
|
||||
with TestMode():
|
||||
z = _pop_torch_function_stack()
|
||||
torch._dynamo.graph_break()
|
||||
_push_on_torch_function_stack(z)
|
||||
o = torch.add(x, 3)
|
||||
|
||||
return torch.add(o, y)
|
||||
|
||||
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
|
||||
fn_opt = torch.compile(fn)
|
||||
|
||||
expected = fn(*inp)
|
||||
actual = fn_opt(*inp)
|
||||
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_torch_function_mode_restore_on_exc(self):
|
||||
@torch._dynamo.disable()
|
||||
def err():
|
||||
raise RuntimeError("test")
|
||||
|
||||
@torch.compile()
|
||||
def fn(x):
|
||||
with TestMode():
|
||||
x += 1
|
||||
err()
|
||||
x += 2
|
||||
return x
|
||||
|
||||
try:
|
||||
fn(torch.ones(2, 2))
|
||||
except RuntimeError:
|
||||
pass
|
||||
self.assertEqual(_len_torch_function_stack(), 0)
|
||||
|
||||
def test_torch_function_mode_and_pop_graph_break_mutation(self):
|
||||
def fn(x, y):
|
||||
with TestMode():
|
||||
z = _pop_torch_function_stack()
|
||||
z.y = 5
|
||||
torch._dynamo.graph_break()
|
||||
_push_on_torch_function_stack(z)
|
||||
o = torch.add(x, 3)
|
||||
o = torch.mul(o, z.y)
|
||||
|
||||
return torch.add(o, y)
|
||||
|
||||
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
|
||||
fn_opt = torch.compile(fn)
|
||||
|
||||
expected = fn(*inp)
|
||||
actual = fn_opt(*inp)
|
||||
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
|
|
|||
|
|
@ -112,6 +112,7 @@ from .utils import (
|
|||
troubleshooting_url,
|
||||
write_record_to_file,
|
||||
)
|
||||
from .variables.torch_function import torch_function_mode_stack_state_mgr
|
||||
|
||||
|
||||
np: Optional[ModuleType]
|
||||
|
|
@ -210,15 +211,18 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
|
|||
prior_fwd_from_src = torch.fx.graph_module._forward_from_src
|
||||
torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
|
||||
cleanup = setup_compile_debug()
|
||||
|
||||
exit_stack = contextlib.ExitStack()
|
||||
exit_stack.enter_context(
|
||||
torch.fx._symbolic_trace._maybe_revert_all_patches()
|
||||
)
|
||||
exit_stack.enter_context(torch_function_mode_stack_state_mgr)
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
finally:
|
||||
cleanup.close()
|
||||
assert (
|
||||
torch._C._len_torch_function_stack() == 0
|
||||
), "Torch function mode stack state changed while dynamo tracing, please report a bug"
|
||||
exit_stack.close()
|
||||
torch._C._set_grad_enabled(prior_grad_mode)
|
||||
torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)
|
||||
|
|
|
|||
|
|
@ -78,7 +78,6 @@ from .utils import (
|
|||
get_instruction_source_311,
|
||||
get_locals_to_steal,
|
||||
get_static_address_type,
|
||||
get_torch_function_mode_stack,
|
||||
graph_break_reasons,
|
||||
increment_op_count,
|
||||
lazy_format_graph_code,
|
||||
|
|
@ -250,6 +249,7 @@ class OutputGraph:
|
|||
local_scope: Scope,
|
||||
global_scope: Scope,
|
||||
f_code,
|
||||
torch_function_mode_stack,
|
||||
):
|
||||
super().__init__()
|
||||
self.tracers = [SubgraphTracer(self, export_root=export)]
|
||||
|
|
@ -368,7 +368,7 @@ class OutputGraph:
|
|||
# This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty
|
||||
self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled()
|
||||
# This records the initial torch function mode stack for guarding
|
||||
self.torch_function_mode_stack = get_torch_function_mode_stack()
|
||||
self.torch_function_mode_stack = torch_function_mode_stack
|
||||
|
||||
# Tracks if the output graph has a user defined allowed function in the
|
||||
# graph. This is used later to determine if we should fallback to eager
|
||||
|
|
@ -1020,7 +1020,7 @@ class OutputGraph:
|
|||
prefix_insts.clear()
|
||||
|
||||
for block in reversed(tx.block_stack):
|
||||
block.exit(tx)
|
||||
block.exit(tx, is_graph_break=reason.graph_break)
|
||||
|
||||
self.cleanup_graph()
|
||||
tx.prune_dead_locals()
|
||||
|
|
|
|||
|
|
@ -25,6 +25,26 @@ if TYPE_CHECKING:
|
|||
sys as sys,
|
||||
)
|
||||
|
||||
from torch.overrides import BaseTorchFunctionMode
|
||||
|
||||
|
||||
# These classes handle support for TorchFunctionModes across
|
||||
# graph breaks
|
||||
# Today the TorchFunctionMode enter (for the classes we support)
|
||||
# simply pushes the mode onto the stack. Since after this occurs
|
||||
# the stack is mutated, and we replay these mutations, we don't need
|
||||
# any cleanup logic to be run once the graph break occurs, we simply replay
|
||||
# these mutations to ensure at the graph break the torch function mode stack is correct
|
||||
# and reconstruct the torch function mode stack normally
|
||||
# when we compile the resume function on the other side of the break.
|
||||
# However, to ensure we exit properly
|
||||
# in the resume function, we need to re-enter the contexts as we do other contexts.
|
||||
# These contexts do nothing on enter, but provide the correct exit logic to ensure
|
||||
# the stack state is correct.
|
||||
class NoEnterTorchFunctionMode(BaseTorchFunctionMode):
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
||||
|
||||
def index(iterator, item, start=0, end=None):
|
||||
from itertools import islice
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import types
|
|||
from typing import Any, cast, Dict, List, Optional, Tuple
|
||||
|
||||
from .bytecode_transformation import (
|
||||
add_push_null,
|
||||
create_call_function,
|
||||
create_call_method,
|
||||
create_dup_top,
|
||||
|
|
@ -48,6 +49,109 @@ class ReenterWith:
|
|||
stack_index: int
|
||||
target_values: Optional[Tuple[Any, ...]] = None
|
||||
|
||||
def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]):
|
||||
"""
|
||||
Codegen based off of:
|
||||
try:
|
||||
(rest)
|
||||
except:
|
||||
(restore previous stack)
|
||||
|
||||
"""
|
||||
from .variables.torch_function import get_prev_stack_var_name
|
||||
|
||||
except_jump_target = create_instruction(
|
||||
"NOP" if sys.version_info < (3, 11) else "PUSH_EXC_INFO"
|
||||
)
|
||||
cleanup_complete_jump_target = create_instruction("NOP")
|
||||
|
||||
setup_finally: List[Instruction] = []
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
setup_finally.append(
|
||||
create_instruction("SETUP_FINALLY", target=except_jump_target)
|
||||
)
|
||||
else:
|
||||
exn_tab_begin = create_instruction("NOP")
|
||||
exn_tab_end = create_instruction("NOP")
|
||||
exn_tab_begin.exn_tab_entry = InstructionExnTabEntry(
|
||||
exn_tab_begin,
|
||||
exn_tab_end,
|
||||
except_jump_target,
|
||||
self.stack_index + 1,
|
||||
False,
|
||||
)
|
||||
setup_finally.append(exn_tab_begin)
|
||||
|
||||
def create_reset():
|
||||
insts = [
|
||||
create_instruction(
|
||||
"LOAD_GLOBAL", argval="__import_torch_dot__dynamo_dot_utils"
|
||||
),
|
||||
create_instruction("LOAD_ATTR", argval="set_torch_function_mode_stack"),
|
||||
]
|
||||
add_push_null(insts)
|
||||
return [
|
||||
*insts,
|
||||
create_instruction("LOAD_FAST", argval=get_prev_stack_var_name()),
|
||||
*create_call_function(1, False),
|
||||
create_instruction("POP_TOP"),
|
||||
]
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
epilogue = [
|
||||
create_instruction("POP_BLOCK"),
|
||||
create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
|
||||
except_jump_target,
|
||||
*create_reset(),
|
||||
create_instruction("POP_TOP"),
|
||||
create_instruction("POP_TOP"),
|
||||
create_instruction("POP_TOP"),
|
||||
*create_reset(),
|
||||
create_instruction("RAISE_VARARGS", argval=0),
|
||||
create_instruction("POP_EXCEPT", argval=0),
|
||||
create_instruction("END_FINALLY"),
|
||||
cleanup_complete_jump_target,
|
||||
]
|
||||
elif sys.version_info < (3, 11):
|
||||
epilogue = [
|
||||
create_instruction("POP_BLOCK"),
|
||||
create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
|
||||
except_jump_target,
|
||||
create_instruction("POP_TOP"),
|
||||
create_instruction("POP_TOP"),
|
||||
create_instruction("POP_TOP"),
|
||||
*create_reset(),
|
||||
create_instruction("RAISE_VARARGS", argval=0),
|
||||
create_instruction("POP_EXCEPT", argval=0),
|
||||
cleanup_complete_jump_target,
|
||||
]
|
||||
else:
|
||||
finally_exn_tab_end = create_instruction("RAISE_VARARGS", argval=0)
|
||||
finally_exn_tab_target = create_instruction("COPY", arg=3)
|
||||
except_jump_target.exn_tab_entry = InstructionExnTabEntry(
|
||||
except_jump_target,
|
||||
finally_exn_tab_end,
|
||||
finally_exn_tab_target,
|
||||
self.stack_index + 2,
|
||||
True,
|
||||
)
|
||||
epilogue = [
|
||||
exn_tab_end,
|
||||
create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
|
||||
except_jump_target, # PUSH_EXC_INFO
|
||||
create_instruction("POP_TOP"),
|
||||
*create_reset(),
|
||||
finally_exn_tab_end,
|
||||
finally_exn_tab_target, # COPY 3
|
||||
create_instruction("POP_EXCEPT"),
|
||||
create_instruction("RERAISE", arg=1), # RERAISE 1
|
||||
cleanup_complete_jump_target,
|
||||
]
|
||||
|
||||
cleanup[:] = epilogue + cleanup
|
||||
return setup_finally
|
||||
|
||||
# If we do not want to destroy the stack, we can do the same thing as a
|
||||
# `SETUP_WITH` block, only that we store the context manager in a local_symbol
|
||||
def try_except(self, code_options, cleanup: List[Instruction]):
|
||||
|
|
|
|||
|
|
@ -593,11 +593,22 @@ class SideEffects:
|
|||
elif isinstance(
|
||||
var, variables.torch_function.TorchFunctionModeStackVariable
|
||||
):
|
||||
# Needed in the finally block for stack restoration
|
||||
cg.add_push_null(
|
||||
lambda: cg.load_import_from(
|
||||
utils.__name__, "get_torch_function_mode_stack"
|
||||
)
|
||||
)
|
||||
cg.call_function(0, False)
|
||||
name = variables.torch_function.get_prev_stack_var_name()
|
||||
cg.code_options["co_varnames"] += (name,)
|
||||
cg.append_output(create_instruction("STORE_FAST", argval=name))
|
||||
cg.add_push_null(
|
||||
lambda: cg.load_import_from(
|
||||
utils.__name__, "set_torch_function_mode_stack"
|
||||
)
|
||||
)
|
||||
|
||||
cg.foreach(var.symbolic_stack)
|
||||
cg.append_output(
|
||||
create_instruction("BUILD_LIST", arg=len(var.symbolic_stack))
|
||||
|
|
|
|||
|
|
@ -267,13 +267,12 @@ class BlockStackEntry:
|
|||
else:
|
||||
return ReenterWith(self.stack_index)
|
||||
|
||||
def exit(self, tx):
|
||||
if hasattr(self, "graph_break") and isinstance(
|
||||
self.with_context, TorchFunctionModeVariable
|
||||
):
|
||||
return
|
||||
def exit(self, tx, is_graph_break):
|
||||
assert self.with_context is not None
|
||||
return self.with_context.exit(tx)
|
||||
if (
|
||||
is_graph_break and self.with_context.exit_on_graph_break()
|
||||
) or not is_graph_break:
|
||||
return self.with_context.exit(tx)
|
||||
|
||||
|
||||
class ReturnValueOp(Exception):
|
||||
|
|
@ -639,10 +638,17 @@ def break_graph_if_unsupported(*, push):
|
|||
cleanup: List[Instruction] = []
|
||||
# Reconstruct the context variable CLASS in the block stack
|
||||
for b in self.block_stack:
|
||||
# Don't exit any modes we have entered,
|
||||
# output bytecode will mutate the tf mode stack accordingly
|
||||
if isinstance(b.with_context, TorchFunctionModeVariable):
|
||||
cg.extend_output(
|
||||
b.resume_fn().try_except_torch_function_mode(
|
||||
cg.code_options, cleanup
|
||||
)
|
||||
)
|
||||
continue
|
||||
assert b.with_context is not None
|
||||
assert isinstance(
|
||||
b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable)
|
||||
)
|
||||
assert isinstance(b.with_context, (ContextWrappingVariable))
|
||||
b.with_context.reconstruct_type(cg)
|
||||
cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup))
|
||||
self.output.add_output_instructions(cg.get_instructions())
|
||||
|
|
@ -2295,7 +2301,10 @@ class InstructionTranslatorBase(
|
|||
):
|
||||
unimplemented(f"{inst.opname} {ctx}")
|
||||
|
||||
if isinstance(ctx, GenericContextWrappingVariable):
|
||||
if (
|
||||
isinstance(ctx, GenericContextWrappingVariable)
|
||||
and not ctx.supports_graph_breaks()
|
||||
):
|
||||
self.generic_context_manager_depth += 1
|
||||
|
||||
# Need this redundant check for mypy
|
||||
|
|
@ -2668,6 +2677,7 @@ class InstructionTranslator(InstructionTranslatorBase):
|
|||
local_scope=f_locals,
|
||||
global_scope=f_globals,
|
||||
f_code=f_code,
|
||||
torch_function_mode_stack=torch_function_mode_stack,
|
||||
),
|
||||
instructions=instructions,
|
||||
f_locals=f_locals,
|
||||
|
|
|
|||
|
|
@ -163,6 +163,7 @@ def debug_insert_nops(
|
|||
local_scope=locals(),
|
||||
global_scope=globals(),
|
||||
f_code=frame.f_code,
|
||||
torch_function_mode_stack=[],
|
||||
)
|
||||
|
||||
return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0))
|
||||
|
|
|
|||
|
|
@ -304,6 +304,7 @@ manual_torch_name_rule_map = {
|
|||
"torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable,
|
||||
"torch.cuda._get_device_properties": TorchInGraphFunctionVariable,
|
||||
"torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable,
|
||||
"torch.set_default_device": UserFunctionVariable,
|
||||
"torch.sparse_bsc_tensor": SkipFunctionVariable,
|
||||
"torch.sparse_bsr_tensor": SkipFunctionVariable,
|
||||
"torch.sparse_csc_tensor": SkipFunctionVariable,
|
||||
|
|
@ -2797,7 +2798,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
|
|||
"torch.random.initial_seed",
|
||||
"torch.random.seed",
|
||||
"torch.return_types.pytree_register_structseq",
|
||||
"torch.set_default_device",
|
||||
"torch.set_default_dtype",
|
||||
"torch.set_default_tensor_type",
|
||||
"torch.set_deterministic_debug_mode",
|
||||
|
|
|
|||
|
|
@ -204,6 +204,7 @@ from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
|
|||
from .torch_function import (
|
||||
build_torch_function_fn,
|
||||
TensorWithTFOverrideVariable,
|
||||
torch_function_mode_stack_state_mgr,
|
||||
TorchFunctionModeVariable,
|
||||
)
|
||||
from .user_defined import (
|
||||
|
|
@ -1668,15 +1669,16 @@ class VariableBuilder:
|
|||
# but warning is not the end of the world
|
||||
assert isinstance(value.base, np.nditer)
|
||||
|
||||
try:
|
||||
tensor_value = _util._try_convert_to_tensor(value)
|
||||
if readonly:
|
||||
from torch._prims_common import clone_preserve_strides
|
||||
with torch_function_mode_stack_state_mgr.temp_restore_stack():
|
||||
try:
|
||||
tensor_value = _util._try_convert_to_tensor(value)
|
||||
if readonly:
|
||||
from torch._prims_common import clone_preserve_strides
|
||||
|
||||
tensor_value = clone_preserve_strides(tensor_value)
|
||||
except NotImplementedError as e:
|
||||
# failed to convert to tensor, graph break
|
||||
unimplemented(str(e))
|
||||
tensor_value = clone_preserve_strides(tensor_value)
|
||||
except NotImplementedError as e:
|
||||
# failed to convert to tensor, graph break
|
||||
unimplemented(str(e))
|
||||
|
||||
# We do this because we want the full behavior of guarding the numpy ndarray as if it were
|
||||
# a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here
|
||||
|
|
|
|||
|
|
@ -125,6 +125,12 @@ class ContextWrappingVariable(VariableTracker):
|
|||
if isinstance(args[0], UserFunctionVariable):
|
||||
return WrappedUserFunctionVariable(args[0], self)
|
||||
|
||||
def supports_graph_breaks(self):
|
||||
return True
|
||||
|
||||
def exit_on_graph_break(self):
|
||||
return True
|
||||
|
||||
|
||||
class GenericContextWrappingVariable(UserDefinedObjectVariable):
|
||||
# Some methods in ContextWrappingVariable assumes the arguments are
|
||||
|
|
@ -183,6 +189,12 @@ class GenericContextWrappingVariable(UserDefinedObjectVariable):
|
|||
tx.generic_context_manager_depth -= 1
|
||||
return x
|
||||
|
||||
def supports_graph_breaks(self):
|
||||
return False
|
||||
|
||||
def exit_on_graph_break(self):
|
||||
return True
|
||||
|
||||
|
||||
class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
|
||||
"""represents torch grad requries grad"""
|
||||
|
|
|
|||
|
|
@ -162,7 +162,17 @@ def get_overridable_functions():
|
|||
|
||||
from torch.overrides import get_overridable_functions as get_overridable_functions_
|
||||
|
||||
return set(chain(*get_overridable_functions_().values()))
|
||||
funcs = set(chain(*get_overridable_functions_().values()))
|
||||
more = {
|
||||
torch.ones,
|
||||
torch.ones_like,
|
||||
torch.zeros,
|
||||
torch.zeros_like,
|
||||
torch.empty,
|
||||
torch.full,
|
||||
}
|
||||
funcs.update(more)
|
||||
return funcs
|
||||
|
||||
|
||||
class BaseTorchVariable(VariableTracker):
|
||||
|
|
@ -838,6 +848,13 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
|||
len(tx.symbolic_torch_function_state.mode_stack)
|
||||
)
|
||||
|
||||
@register(torch._C._get_function_stack_at)
|
||||
def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs):
|
||||
assert len(args) == 1 and not kwargs
|
||||
ind = args[0].as_python_constant()
|
||||
assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack)
|
||||
return tx.symbolic_torch_function_state.mode_stack[ind]
|
||||
|
||||
@register(torch.set_default_device)
|
||||
def handle_set_default_device(
|
||||
self, tx: "InstructionTranslator", *args, **kwargs
|
||||
|
|
@ -855,7 +872,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
|||
else:
|
||||
TorchFunctionModeStackVariable.register_device_context_insertion(tx)
|
||||
|
||||
return None
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
return handlers
|
||||
|
||||
|
|
|
|||
|
|
@ -2,22 +2,35 @@
|
|||
|
||||
import collections
|
||||
import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
from typing import Deque, Dict, List, TYPE_CHECKING
|
||||
|
||||
import torch._C
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._guards import Source
|
||||
from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
|
||||
from torch.overrides import (
|
||||
_get_overloaded_args,
|
||||
get_default_nowrap_functions,
|
||||
TorchFunctionMode,
|
||||
)
|
||||
from torch.utils._device import DeviceContext
|
||||
|
||||
from ..exc import unimplemented
|
||||
from ..guards import GuardBuilder, install_guard
|
||||
from ..polyfills import NoEnterTorchFunctionMode
|
||||
from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
|
||||
from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter
|
||||
from ..utils import (
|
||||
class_has_getattribute,
|
||||
clear_torch_function_mode_stack,
|
||||
get_safe_global_name,
|
||||
has_torch_function,
|
||||
is_tensor_base_attr_getter,
|
||||
set_torch_function_mode_stack,
|
||||
)
|
||||
from .base import VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
from .ctx_manager import ContextWrappingVariable
|
||||
from .ctx_manager import GenericContextWrappingVariable
|
||||
from .lazy import LazyVariableTracker
|
||||
from .lists import TupleVariable
|
||||
from .tensor import TensorSubclassVariable, TensorVariable
|
||||
|
|
@ -63,6 +76,39 @@ banned_attrs = [
|
|||
IGNORED_MODES = {DeviceContext}
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_prev_stack_var_name():
|
||||
from ..bytecode_transformation import unique_id
|
||||
|
||||
return unique_id("___prev_torch_function_mode_stack")
|
||||
|
||||
|
||||
# Used to clear/restore the python torch function mode stack and temporarily restore it as needed
|
||||
class TorchFunctionModeStackStateManager:
|
||||
def __init__(self):
|
||||
self.stack = []
|
||||
|
||||
def __enter__(self):
|
||||
self.stack = torch.overrides._get_current_function_mode_stack()
|
||||
clear_torch_function_mode_stack()
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
set_torch_function_mode_stack(self.stack)
|
||||
self.stack = []
|
||||
|
||||
@contextlib.contextmanager
|
||||
def temp_restore_stack(self):
|
||||
prev = torch.overrides._get_current_function_mode_stack()
|
||||
set_torch_function_mode_stack(self.stack)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
set_torch_function_mode_stack(prev)
|
||||
|
||||
|
||||
torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager()
|
||||
|
||||
|
||||
class SymbolicTorchFunctionState:
|
||||
def __init__(self, py_stack):
|
||||
# This is annoyingly complicated because of how the torch function subclass + mode C API was designed
|
||||
|
|
@ -189,9 +235,26 @@ class TorchFunctionModeStackVariable(VariableTracker):
|
|||
return ind + cls.offset
|
||||
|
||||
|
||||
class TorchFunctionModeVariable(ContextWrappingVariable):
|
||||
class TorchFunctionModeVariable(GenericContextWrappingVariable):
|
||||
@staticmethod
|
||||
def is_supported_torch_function_mode(ty):
|
||||
# Supported in this sense means we can support graph breaks under the
|
||||
# context.
|
||||
# We are able to trace custom modes but if there are graph breaks under them
|
||||
# and they have a custom __enter__/__exit__ we don't handle this for the
|
||||
# same reason we don't handle generic context managers: there may be side effects
|
||||
# that are now affected by executing the funtion across two frames instead of one
|
||||
# Today we support the enter/exit of the default TorchFunctionMode as well as
|
||||
# DeviceContext (which is used for set_default_device)
|
||||
return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or (
|
||||
not class_has_getattribute(ty)
|
||||
and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__
|
||||
and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__
|
||||
)
|
||||
|
||||
def __init__(self, value, source=None, **kwargs):
|
||||
super().__init__(value, **kwargs)
|
||||
if value is not None:
|
||||
super().__init__(value, **kwargs)
|
||||
self.value = value
|
||||
self.cm_obj = value # needed for BC with calling enter from CM code
|
||||
self.source = source
|
||||
|
|
@ -221,8 +284,39 @@ class TorchFunctionModeVariable(ContextWrappingVariable):
|
|||
kwargs,
|
||||
)
|
||||
|
||||
def _call_func(self, tx: "InstructionTranslator", values):
|
||||
unimplemented("enter/exit for torch function mode NYI")
|
||||
def enter(self, tx):
|
||||
from .torch import TorchInGraphFunctionVariable
|
||||
|
||||
if isinstance(self.value, NoEnterTorchFunctionMode):
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
TorchInGraphFunctionVariable(
|
||||
torch._C._push_on_torch_function_stack
|
||||
).call_function(tx, [self], {})
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
def exit(self, tx: "InstructionTranslator", *args):
|
||||
from .torch import TorchInGraphFunctionVariable
|
||||
|
||||
TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function(
|
||||
tx, [], {}
|
||||
)
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
def reconstruct_type(self, codegen):
|
||||
ty = NoEnterTorchFunctionMode
|
||||
codegen(
|
||||
AttrSource(
|
||||
codegen.tx.import_source(ty.__module__),
|
||||
ty.__name__,
|
||||
)
|
||||
)
|
||||
|
||||
def supports_graph_breaks(self):
|
||||
return True
|
||||
|
||||
def exit_on_graph_break(self):
|
||||
return False
|
||||
|
||||
|
||||
def _get_all_args(args, kwargs):
|
||||
|
|
|
|||
|
|
@ -409,10 +409,22 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||
and self.source
|
||||
and not is_forbidden_context_manager(self.value)
|
||||
):
|
||||
from torch.overrides import TorchFunctionMode
|
||||
|
||||
from .ctx_manager import GenericContextWrappingVariable
|
||||
from .torch_function import TorchFunctionModeVariable
|
||||
|
||||
if issubclass(
|
||||
self.value, TorchFunctionMode
|
||||
) and TorchFunctionModeVariable.is_supported_torch_function_mode(
|
||||
self.value
|
||||
):
|
||||
var_cls = TorchFunctionModeVariable
|
||||
else:
|
||||
var_cls = GenericContextWrappingVariable
|
||||
|
||||
cm_obj = tx.output.side_effects.track_object_new(
|
||||
self.source, self.value, GenericContextWrappingVariable, {}
|
||||
self.source, self.value, var_cls, {}
|
||||
)
|
||||
cm_obj.call_method(tx, "__init__", args, kwargs)
|
||||
return cm_obj
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user