Revert "[Dynamo] Trace enter/exit of TorchFunctionModes (#135422)" (#136590)

This reverts commit 7743149b2b.

Reverts
* https://github.com/pytorch/pytorch/pull/135503
* https://github.com/pytorch/pytorch/pull/135502
* https://github.com/pytorch/pytorch/pull/135422

This passes this test. Earlier, the getitem would stay like a getitem in the Fx graph. But now the fake tensor propagations fails saying that .item is called. It seems that torch function is not getting triggered while fake tensor propagation.

```
import torch
from torch.nn.attention.flex_attention import BlockMask, _mask_mod_signature, _score_mod_signature, flex_attention
from torch._inductor.lowering import make_pointwise, register_lowering
from torch._inductor.virtualized import ops
from torch.nn.attention.flex_attention import create_block_mask

torch.set_default_device('cuda')

flex_attention = torch.compile(flex_attention, dynamic=False)

prefix_lengths = torch.arange(8)
def prefix_lm(b, h, q, kv):
    return prefix_lengths[b] >= kv

mask = create_block_mask(prefix_lm, 8, None, 512, 512, _compile=True)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136590
Approved by: https://github.com/Chillee
This commit is contained in:
Animesh Jain 2024-09-25 09:36:53 -07:00 committed by PyTorch MergeBot
parent 529b6ab0bb
commit 289df45cee
18 changed files with 201 additions and 336 deletions

View File

@ -1,4 +1,5 @@
# Owner(s): ["module: dynamo"] # Owner(s): ["module: dynamo"]
from unittest.mock import patch
import torch import torch
import torch._dynamo.test_case import torch._dynamo.test_case
@ -106,6 +107,70 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
fn(inp) fn(inp)
self.assertEqual(cnt.frame_count, 4) self.assertEqual(cnt.frame_count, 4)
def _run_ignored_mode_types_test(self):
class IgnoredMode(BaseTorchFunctionMode):
pass
cnt = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnt.__call__, fullgraph=True)
def fn(x):
return x + 1
inp = torch.ones(2, 2)
with patch(
"torch._dynamo.variables.torch_function.IGNORED_MODES", {IgnoredMode}
):
# initial compile
fn(inp)
# no recompile, mode ignored
# note: the ref stack is length 0, and the stack we are checking against has length 2
# we want to check both ref stack len > runtime stack, and ref stack len < runtime stack
with IgnoredMode(), IgnoredMode():
fn(inp)
self.assertEqual(cnt.frame_count, 1)
# recompile due to new mode on the stack
with BaseTorchFunctionMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode():
fn(inp)
self.assertEqual(cnt.frame_count, 2)
# recompile
# tests both ref stack len > runtime stack len for the above guard check
# and ref stack len < runtime stack len for the initial zero mode case
with BaseTorchFunctionMode(), IgnoredMode(), BaseTorchFunctionMode():
fn(inp)
self.assertEqual(cnt.frame_count, 3)
# no recompile
with IgnoredMode(), IgnoredMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode():
fn(inp)
self.assertEqual(cnt.frame_count, 3)
# This is tricky, basically the ignored modes are baked into the guard
# IgnoredMode will be ignored forever by that guard.
# This is okay since we don't expect to be modifying IGNORED_MODES
# in the middle of execution except for the purposes of testing.
torch._dynamo.reset()
with IgnoredMode():
fn(inp)
self.assertEqual(cnt.frame_count, 4)
@torch._dynamo.config.patch("enable_cpp_guard_manager", False)
def test_torch_function_mode_guards_ignored_types_py(self):
self._run_ignored_mode_types_test()
def test_torch_function_mode_guards_ignored_types_cpp(self):
self._run_ignored_mode_types_test()
@torch._dynamo.config.patch("enable_cpp_guard_manager", False) @torch._dynamo.config.patch("enable_cpp_guard_manager", False)
def test_torch_function_mode_guards_py(self): def test_torch_function_mode_guards_py(self):
self._run_torch_function_mode_guard_test() self._run_torch_function_mode_guard_test()
@ -396,94 +461,6 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
self.assertEqual(expected, actual) self.assertEqual(expected, actual)
def test_torch_function_mode_enter_exit(self):
def fn(x, y):
with TestMode():
o = torch.add(x, 3)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn, fullgraph=True)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_torch_function_mode_graph_break(self):
def fn(x, y):
with TestMode():
torch._dynamo.graph_break()
o = torch.add(x, 3)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_torch_function_mode_and_pop_graph_break(self):
def fn(x, y):
with TestMode():
z = _pop_torch_function_stack()
torch._dynamo.graph_break()
_push_on_torch_function_stack(z)
o = torch.add(x, 3)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_torch_function_mode_restore_on_exc(self):
@torch._dynamo.disable()
def err():
raise RuntimeError("test")
@torch.compile()
def fn(x):
with TestMode():
x += 1
err()
x += 2
return x
try:
fn(torch.ones(2, 2))
except RuntimeError:
pass
self.assertEqual(_len_torch_function_stack(), 0)
def test_torch_function_mode_and_pop_graph_break_mutation(self):
def fn(x, y):
with TestMode():
z = _pop_torch_function_stack()
z.y = 5
torch._dynamo.graph_break()
_push_on_torch_function_stack(z)
o = torch.add(x, 3)
o = torch.mul(o, z.y)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
if __name__ == "__main__": if __name__ == "__main__":
from torch._dynamo.test_case import run_tests from torch._dynamo.test_case import run_tests

View File

@ -67,7 +67,7 @@ class GuardManager:
) -> None: ... ) -> None: ...
def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ... def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ...
def add_torch_function_mode_stack_guard( def add_torch_function_mode_stack_guard(
self, initial_stack, verbose_code_parts: list[str] self, initial_stack, ignored_types, verbose_code_parts: list[str]
) -> None: ... ) -> None: ...
class RootGuardManager(GuardManager): class RootGuardManager(GuardManager):

View File

@ -112,7 +112,6 @@ from .utils import (
troubleshooting_url, troubleshooting_url,
write_record_to_file, write_record_to_file,
) )
from .variables.torch_function import torch_function_mode_stack_state_mgr
np: Optional[ModuleType] np: Optional[ModuleType]
@ -211,18 +210,15 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
prior_fwd_from_src = torch.fx.graph_module._forward_from_src prior_fwd_from_src = torch.fx.graph_module._forward_from_src
torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
cleanup = setup_compile_debug() cleanup = setup_compile_debug()
exit_stack = contextlib.ExitStack() exit_stack = contextlib.ExitStack()
exit_stack.enter_context( exit_stack.enter_context(
torch.fx._symbolic_trace._maybe_revert_all_patches() torch.fx._symbolic_trace._maybe_revert_all_patches()
) )
exit_stack.enter_context(torch_function_mode_stack_state_mgr)
try: try:
return fn(*args, **kwargs) return fn(*args, **kwargs)
finally: finally:
cleanup.close() cleanup.close()
assert (
torch._C._len_torch_function_stack() == 0
), "Torch function mode stack state changed while dynamo tracing, please report a bug"
exit_stack.close() exit_stack.close()
torch._C._set_grad_enabled(prior_grad_mode) torch._C._set_grad_enabled(prior_grad_mode)
torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode) torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)

View File

@ -2344,12 +2344,15 @@ class CheckFunctionManager:
) )
if config.enable_cpp_guard_manager: if config.enable_cpp_guard_manager:
from .variables.torch_function import IGNORED_MODES
# Insert the global_state guard # Insert the global_state guard
assert self.guard_manager # to make mypy happy assert self.guard_manager # to make mypy happy
self.guard_manager.root.add_global_state_guard(["___check_global_state()"]) self.guard_manager.root.add_global_state_guard(["___check_global_state()"])
self.guard_manager.root.add_torch_function_mode_stack_guard( self.guard_manager.root.add_torch_function_mode_stack_guard(
self.torch_function_mode_stack, self.torch_function_mode_stack,
list(IGNORED_MODES),
["___check_torch_function_mode_stack()"], ["___check_torch_function_mode_stack()"],
) )
# Clear references to torch_function modes held in the list # Clear references to torch_function modes held in the list
@ -2656,14 +2659,18 @@ def is_recompiles_verbose_enabled():
# this will only be used if cpp guards are disabled # this will only be used if cpp guards are disabled
def make_torch_function_mode_stack_guard(intial_stack): def make_torch_function_mode_stack_guard(intial_stack):
types = [type(x) for x in intial_stack] types = [type(x) for x in intial_stack]
from .variables.torch_function import IGNORED_MODES
def check_torch_function_mode_stack(): def check_torch_function_mode_stack():
cur_stack = get_torch_function_mode_stack() cur_stack = get_torch_function_mode_stack()
if len(cur_stack) != len(types): types_ = [ty for ty in types if ty not in IGNORED_MODES]
cur_stack_ = [mode for mode in cur_stack if type(mode) not in IGNORED_MODES]
if len(cur_stack_) != len(types_):
return False return False
for ty, mode in zip(types, cur_stack): for ty, mode in zip(types_, cur_stack_):
if ty != type(mode): if ty != type(mode):
return False return False

View File

@ -78,6 +78,7 @@ from .utils import (
get_instruction_source_311, get_instruction_source_311,
get_locals_to_steal, get_locals_to_steal,
get_static_address_type, get_static_address_type,
get_torch_function_mode_stack,
graph_break_reasons, graph_break_reasons,
increment_op_count, increment_op_count,
lazy_format_graph_code, lazy_format_graph_code,
@ -249,7 +250,6 @@ class OutputGraph:
local_scope: Scope, local_scope: Scope,
global_scope: Scope, global_scope: Scope,
f_code, f_code,
torch_function_mode_stack,
): ):
super().__init__() super().__init__()
self.tracers = [SubgraphTracer(self, export_root=export)] self.tracers = [SubgraphTracer(self, export_root=export)]
@ -368,7 +368,7 @@ class OutputGraph:
# This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty # This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty
self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled() self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled()
# This records the initial torch function mode stack for guarding # This records the initial torch function mode stack for guarding
self.torch_function_mode_stack = torch_function_mode_stack self.torch_function_mode_stack = get_torch_function_mode_stack()
# Tracks if the output graph has a user defined allowed function in the # Tracks if the output graph has a user defined allowed function in the
# graph. This is used later to determine if we should fallback to eager # graph. This is used later to determine if we should fallback to eager
@ -1020,7 +1020,7 @@ class OutputGraph:
prefix_insts.clear() prefix_insts.clear()
for block in reversed(tx.block_stack): for block in reversed(tx.block_stack):
block.exit(tx, is_graph_break=reason.graph_break) block.exit(tx)
self.cleanup_graph() self.cleanup_graph()
tx.prune_dead_locals() tx.prune_dead_locals()

View File

@ -25,26 +25,6 @@ if TYPE_CHECKING:
sys as sys, sys as sys,
) )
from torch.overrides import BaseTorchFunctionMode
# These classes handle support for TorchFunctionModes across
# graph breaks
# Today the TorchFunctionMode enter (for the classes we support)
# simply pushes the mode onto the stack. Since after this occurs
# the stack is mutated, and we replay these mutations, we don't need
# any cleanup logic to be run once the graph break occurs, we simply replay
# these mutations to ensure at the graph break the torch function mode stack is correct
# and reconstruct the torch function mode stack normally
# when we compile the resume function on the other side of the break.
# However, to ensure we exit properly
# in the resume function, we need to re-enter the contexts as we do other contexts.
# These contexts do nothing on enter, but provide the correct exit logic to ensure
# the stack state is correct.
class NoEnterTorchFunctionMode(BaseTorchFunctionMode):
def __enter__(self):
pass
def index(iterator, item, start=0, end=None): def index(iterator, item, start=0, end=None):
from itertools import islice from itertools import islice

View File

@ -90,26 +90,27 @@ class ReenterWith:
stack_index: int stack_index: int
target_values: Optional[Tuple[Any, ...]] = None target_values: Optional[Tuple[Any, ...]] = None
def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]): # TODO(mlazos) - Uncomment with the reland of torch function mode support
""" # def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]):
Codegen based off of: # """
try: # Codegen based off of:
(rest) # try:
except: # (rest)
(restore previous tf mode stack) # except:
raise # (restore previous tf mode stack)
# raise
""" # """
from .variables.torch_function import get_prev_stack_var_name # from .variables.torch_function import get_prev_stack_var_name
setup_try_except, epilogue = _bytecode_from_template_with_split( # setup_try_except, epilogue = _bytecode_from_template_with_split(
_try_except_tf_mode_template, # _try_except_tf_mode_template,
self.stack_index, # self.stack_index,
varname_map={"stack_var_name": get_prev_stack_var_name()}, # varname_map={"stack_var_name": get_prev_stack_var_name()},
) # )
cleanup[:] = epilogue + cleanup # cleanup[:] = epilogue + cleanup
return setup_try_except # return setup_try_except
# If we do not want to destroy the stack, we can do the same thing as a # If we do not want to destroy the stack, we can do the same thing as a
# `SETUP_WITH` block, only that we store the context manager in a local_symbol # `SETUP_WITH` block, only that we store the context manager in a local_symbol

View File

@ -623,22 +623,11 @@ class SideEffects:
elif isinstance( elif isinstance(
var, variables.torch_function.TorchFunctionModeStackVariable var, variables.torch_function.TorchFunctionModeStackVariable
): ):
# Needed in the finally block for stack restoration
cg.add_push_null(
lambda: cg.load_import_from(
utils.__name__, "get_torch_function_mode_stack"
)
)
cg.call_function(0, False)
name = variables.torch_function.get_prev_stack_var_name()
cg.code_options["co_varnames"] += (name,)
cg.append_output(create_instruction("STORE_FAST", argval=name))
cg.add_push_null( cg.add_push_null(
lambda: cg.load_import_from( lambda: cg.load_import_from(
utils.__name__, "set_torch_function_mode_stack" utils.__name__, "set_torch_function_mode_stack"
) )
) )
cg.foreach(var.symbolic_stack) cg.foreach(var.symbolic_stack)
cg.append_output( cg.append_output(
create_instruction("BUILD_LIST", arg=len(var.symbolic_stack)) create_instruction("BUILD_LIST", arg=len(var.symbolic_stack))

View File

@ -267,12 +267,13 @@ class BlockStackEntry:
else: else:
return ReenterWith(self.stack_index) return ReenterWith(self.stack_index)
def exit(self, tx, is_graph_break): def exit(self, tx):
if hasattr(self, "graph_break") and isinstance(
self.with_context, TorchFunctionModeVariable
):
return
assert self.with_context is not None assert self.with_context is not None
if ( return self.with_context.exit(tx)
is_graph_break and self.with_context.exit_on_graph_break()
) or not is_graph_break:
return self.with_context.exit(tx)
class ReturnValueOp(Exception): class ReturnValueOp(Exception):
@ -638,17 +639,10 @@ def break_graph_if_unsupported(*, push):
cleanup: List[Instruction] = [] cleanup: List[Instruction] = []
# Reconstruct the context variable CLASS in the block stack # Reconstruct the context variable CLASS in the block stack
for b in self.block_stack: for b in self.block_stack:
# Don't exit any modes we have entered,
# output bytecode will mutate the tf mode stack accordingly
if isinstance(b.with_context, TorchFunctionModeVariable):
cg.extend_output(
b.resume_fn().try_except_torch_function_mode(
cg.code_options, cleanup
)
)
continue
assert b.with_context is not None assert b.with_context is not None
assert isinstance(b.with_context, (ContextWrappingVariable)) assert isinstance(
b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable)
)
b.with_context.reconstruct_type(cg) b.with_context.reconstruct_type(cg)
cg.extend_output(b.resume_fn().try_finally(cg.code_options, cleanup)) cg.extend_output(b.resume_fn().try_finally(cg.code_options, cleanup))
self.output.add_output_instructions(cg.get_instructions()) self.output.add_output_instructions(cg.get_instructions())
@ -2301,10 +2295,7 @@ class InstructionTranslatorBase(
): ):
unimplemented(f"{inst.opname} {ctx}") unimplemented(f"{inst.opname} {ctx}")
if ( if isinstance(ctx, GenericContextWrappingVariable):
isinstance(ctx, GenericContextWrappingVariable)
and not ctx.supports_graph_breaks()
):
self.generic_context_manager_depth += 1 self.generic_context_manager_depth += 1
# Need this redundant check for mypy # Need this redundant check for mypy
@ -2677,7 +2668,6 @@ class InstructionTranslator(InstructionTranslatorBase):
local_scope=f_locals, local_scope=f_locals,
global_scope=f_globals, global_scope=f_globals,
f_code=f_code, f_code=f_code,
torch_function_mode_stack=torch_function_mode_stack,
), ),
instructions=instructions, instructions=instructions,
f_locals=f_locals, f_locals=f_locals,

View File

@ -187,7 +187,6 @@ def debug_insert_nops(
local_scope=locals(), local_scope=locals(),
global_scope=globals(), global_scope=globals(),
f_code=frame.f_code, f_code=frame.f_code,
torch_function_mode_stack=[],
) )
return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0)) return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0))

View File

@ -304,7 +304,6 @@ manual_torch_name_rule_map = {
"torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable, "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable,
"torch.cuda._get_device_properties": TorchInGraphFunctionVariable, "torch.cuda._get_device_properties": TorchInGraphFunctionVariable,
"torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable, "torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable,
"torch.set_default_device": UserFunctionVariable,
"torch.sparse_bsc_tensor": SkipFunctionVariable, "torch.sparse_bsc_tensor": SkipFunctionVariable,
"torch.sparse_bsr_tensor": SkipFunctionVariable, "torch.sparse_bsr_tensor": SkipFunctionVariable,
"torch.sparse_csc_tensor": SkipFunctionVariable, "torch.sparse_csc_tensor": SkipFunctionVariable,
@ -2802,6 +2801,7 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
"torch.random.initial_seed", "torch.random.initial_seed",
"torch.random.seed", "torch.random.seed",
"torch.return_types.pytree_register_structseq", "torch.return_types.pytree_register_structseq",
"torch.set_default_device",
"torch.set_default_dtype", "torch.set_default_dtype",
"torch.set_default_tensor_type", "torch.set_default_tensor_type",
"torch.set_deterministic_debug_mode", "torch.set_deterministic_debug_mode",

View File

@ -3097,10 +3097,16 @@ def is_parameter_freezing():
return torch._inductor.config.freezing and not torch.is_grad_enabled() return torch._inductor.config.freezing and not torch.is_grad_enabled()
def get_torch_function_mode_stack(): def get_torch_function_mode_stack(filter_ignored=True):
return [ from .variables.torch_function import IGNORED_MODES
stack = [
get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack()) get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack())
] ]
if filter_ignored:
stack = [mode for mode in stack if type(mode) not in IGNORED_MODES]
return stack
def get_torch_function_mode_stack_at(ind): def get_torch_function_mode_stack_at(ind):

View File

@ -204,7 +204,6 @@ from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
from .torch_function import ( from .torch_function import (
build_torch_function_fn, build_torch_function_fn,
TensorWithTFOverrideVariable, TensorWithTFOverrideVariable,
torch_function_mode_stack_state_mgr,
TorchFunctionModeVariable, TorchFunctionModeVariable,
) )
from .user_defined import ( from .user_defined import (
@ -1671,16 +1670,15 @@ class VariableBuilder:
# but warning is not the end of the world # but warning is not the end of the world
assert isinstance(value.base, np.nditer) assert isinstance(value.base, np.nditer)
with torch_function_mode_stack_state_mgr.temp_restore_stack(): try:
try: tensor_value = _util._try_convert_to_tensor(value)
tensor_value = _util._try_convert_to_tensor(value) if readonly:
if readonly: from torch._prims_common import clone_preserve_strides
from torch._prims_common import clone_preserve_strides
tensor_value = clone_preserve_strides(tensor_value) tensor_value = clone_preserve_strides(tensor_value)
except NotImplementedError as e: except NotImplementedError as e:
# failed to convert to tensor, graph break # failed to convert to tensor, graph break
unimplemented(str(e)) unimplemented(str(e))
# We do this because we want the full behavior of guarding the numpy ndarray as if it were # We do this because we want the full behavior of guarding the numpy ndarray as if it were
# a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here

View File

@ -125,12 +125,6 @@ class ContextWrappingVariable(VariableTracker):
if isinstance(args[0], UserFunctionVariable): if isinstance(args[0], UserFunctionVariable):
return WrappedUserFunctionVariable(args[0], self) return WrappedUserFunctionVariable(args[0], self)
def supports_graph_breaks(self):
return True
def exit_on_graph_break(self):
return True
class GenericContextWrappingVariable(UserDefinedObjectVariable): class GenericContextWrappingVariable(UserDefinedObjectVariable):
# Some methods in ContextWrappingVariable assumes the arguments are # Some methods in ContextWrappingVariable assumes the arguments are
@ -189,12 +183,6 @@ class GenericContextWrappingVariable(UserDefinedObjectVariable):
tx.generic_context_manager_depth -= 1 tx.generic_context_manager_depth -= 1
return x return x
def supports_graph_breaks(self):
return False
def exit_on_graph_break(self):
return True
class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable): class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
"""represents torch grad requries grad""" """represents torch grad requries grad"""

View File

@ -160,17 +160,7 @@ def get_overridable_functions():
from torch.overrides import get_overridable_functions as get_overridable_functions_ from torch.overrides import get_overridable_functions as get_overridable_functions_
funcs = set(chain(*get_overridable_functions_().values())) return set(chain(*get_overridable_functions_().values()))
more = {
torch.ones,
torch.ones_like,
torch.zeros,
torch.zeros_like,
torch.empty,
torch.full,
}
funcs.update(more)
return funcs
class BaseTorchVariable(VariableTracker): class BaseTorchVariable(VariableTracker):
@ -846,13 +836,6 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
len(tx.symbolic_torch_function_state.mode_stack) len(tx.symbolic_torch_function_state.mode_stack)
) )
@register(torch._C._get_function_stack_at)
def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs):
assert len(args) == 1 and not kwargs
ind = args[0].as_python_constant()
assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack)
return tx.symbolic_torch_function_state.mode_stack[ind]
@register(torch.set_default_device) @register(torch.set_default_device)
def handle_set_default_device( def handle_set_default_device(
self, tx: "InstructionTranslator", *args, **kwargs self, tx: "InstructionTranslator", *args, **kwargs
@ -870,7 +853,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
else: else:
TorchFunctionModeStackVariable.register_device_context_insertion(tx) TorchFunctionModeStackVariable.register_device_context_insertion(tx)
return ConstantVariable.create(None) return None
return handlers return handlers

View File

@ -2,35 +2,22 @@
import collections import collections
import contextlib import contextlib
import functools
import inspect import inspect
from typing import Deque, Dict, List, TYPE_CHECKING from typing import Deque, Dict, List, TYPE_CHECKING
import torch._C import torch._C
import torch.utils._pytree as pytree import torch.utils._pytree as pytree
from torch._guards import Source from torch._guards import Source
from torch.overrides import ( from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
_get_overloaded_args,
get_default_nowrap_functions,
TorchFunctionMode,
)
from torch.utils._device import DeviceContext from torch.utils._device import DeviceContext
from ..exc import unimplemented from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard from ..guards import GuardBuilder, install_guard
from ..polyfills import NoEnterTorchFunctionMode
from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
from ..utils import ( from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter
class_has_getattribute,
clear_torch_function_mode_stack,
get_safe_global_name,
has_torch_function,
is_tensor_base_attr_getter,
set_torch_function_mode_stack,
)
from .base import VariableTracker from .base import VariableTracker
from .constant import ConstantVariable from .constant import ConstantVariable
from .ctx_manager import GenericContextWrappingVariable from .ctx_manager import ContextWrappingVariable
from .lazy import LazyVariableTracker from .lazy import LazyVariableTracker
from .lists import TupleVariable from .lists import TupleVariable
from .tensor import TensorSubclassVariable, TensorVariable from .tensor import TensorSubclassVariable, TensorVariable
@ -69,38 +56,11 @@ banned_attrs = [
if is_tensor_base_attr_getter(fn) if is_tensor_base_attr_getter(fn)
] ]
# Today set default device is placed in the graph and guarded on separately
@functools.lru_cache(None) # so we should not trace through it. In the future we can trace it once
def get_prev_stack_var_name(): # mode tracing is implemented and not put in the graph, but this is more
from ..bytecode_transformation import unique_id # of a BE project and can be evaluated later
IGNORED_MODES = {DeviceContext}
return unique_id("___prev_torch_function_mode_stack")
# Used to clear/restore the python torch function mode stack and temporarily restore it as needed
class TorchFunctionModeStackStateManager:
def __init__(self):
self.stack = []
def __enter__(self):
self.stack = torch.overrides._get_current_function_mode_stack()
clear_torch_function_mode_stack()
def __exit__(self, exc_type, exc_value, traceback):
set_torch_function_mode_stack(self.stack)
self.stack = []
@contextlib.contextmanager
def temp_restore_stack(self):
prev = torch.overrides._get_current_function_mode_stack()
set_torch_function_mode_stack(self.stack)
try:
yield
finally:
set_torch_function_mode_stack(prev)
torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager()
class SymbolicTorchFunctionState: class SymbolicTorchFunctionState:
@ -229,26 +189,9 @@ class TorchFunctionModeStackVariable(VariableTracker):
return ind + cls.offset return ind + cls.offset
class TorchFunctionModeVariable(GenericContextWrappingVariable): class TorchFunctionModeVariable(ContextWrappingVariable):
@staticmethod
def is_supported_torch_function_mode(ty):
# Supported in this sense means we can support graph breaks under the
# context.
# We are able to trace custom modes but if there are graph breaks under them
# and they have a custom __enter__/__exit__ we don't handle this for the
# same reason we don't handle generic context managers: there may be side effects
# that are now affected by executing the funtion across two frames instead of one
# Today we support the enter/exit of the default TorchFunctionMode as well as
# DeviceContext (which is used for set_default_device)
return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or (
not class_has_getattribute(ty)
and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__
and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__
)
def __init__(self, value, source=None, **kwargs): def __init__(self, value, source=None, **kwargs):
if value is not None: super().__init__(value, **kwargs)
super().__init__(value, **kwargs)
self.value = value self.value = value
self.cm_obj = value # needed for BC with calling enter from CM code self.cm_obj = value # needed for BC with calling enter from CM code
self.source = source self.source = source
@ -278,39 +221,8 @@ class TorchFunctionModeVariable(GenericContextWrappingVariable):
kwargs, kwargs,
) )
def enter(self, tx): def _call_func(self, tx: "InstructionTranslator", values):
from .torch import TorchInGraphFunctionVariable unimplemented("enter/exit for torch function mode NYI")
if isinstance(self.value, NoEnterTorchFunctionMode):
return ConstantVariable.create(None)
TorchInGraphFunctionVariable(
torch._C._push_on_torch_function_stack
).call_function(tx, [self], {})
return ConstantVariable.create(None)
def exit(self, tx: "InstructionTranslator", *args):
from .torch import TorchInGraphFunctionVariable
TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function(
tx, [], {}
)
return ConstantVariable.create(None)
def reconstruct_type(self, codegen):
ty = NoEnterTorchFunctionMode
codegen(
AttrSource(
codegen.tx.import_source(ty.__module__),
ty.__name__,
)
)
def supports_graph_breaks(self):
return True
def exit_on_graph_break(self):
return False
def _get_all_args(args, kwargs): def _get_all_args(args, kwargs):

View File

@ -417,22 +417,10 @@ class UserDefinedClassVariable(UserDefinedVariable):
and self.source and self.source
and not is_forbidden_context_manager(self.value) and not is_forbidden_context_manager(self.value)
): ):
from torch.overrides import TorchFunctionMode
from .ctx_manager import GenericContextWrappingVariable from .ctx_manager import GenericContextWrappingVariable
from .torch_function import TorchFunctionModeVariable
if issubclass(
self.value, TorchFunctionMode
) and TorchFunctionModeVariable.is_supported_torch_function_mode(
self.value
):
var_cls = TorchFunctionModeVariable
else:
var_cls = GenericContextWrappingVariable
cm_obj = tx.output.side_effects.track_object_new( cm_obj = tx.output.side_effects.track_object_new(
self.source, self.value, var_cls, {} self.source, self.value, GenericContextWrappingVariable, {}
) )
cm_obj.call_method(tx, "__init__", args, kwargs) cm_obj.call_method(tx, "__init__", args, kwargs)
return cm_obj return cm_obj

View File

@ -2537,40 +2537,90 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard {
public: public:
TORCH_FUNCTION_MODE_STACK( TORCH_FUNCTION_MODE_STACK(
const py::list& initial_stack, const py::list& initial_stack,
const py::list& ignored_types,
py::object verbose_code_parts) py::object verbose_code_parts)
: LeafGuard(std::move(verbose_code_parts)), _ref_stack() { : LeafGuard(std::move(verbose_code_parts)),
_ref_stack(),
_ignored_types() {
Py_ssize_t len = PyList_Size(initial_stack.ptr()); Py_ssize_t len = PyList_Size(initial_stack.ptr());
for (Py_ssize_t idx = 0; idx < len; idx++) { for (Py_ssize_t idx = 0; idx < len; idx++) {
PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref
auto type = Py_TYPE(mode); auto type = Py_TYPE(mode);
this->_ref_stack.push_back(type); this->_ref_stack.push_back(type);
} }
len = PyList_Size(ignored_types.ptr());
for (Py_ssize_t idx = 0; idx < len; idx++) {
PyObject* type_obj =
PyList_GetItem(ignored_types.ptr(), idx); // borrowed ref
if (PyType_Check(type_obj) == 0) {
PyErr_SetString(
PyExc_TypeError, "ignored_types should contain a list of types");
return;
}
PyTypeObject* type = (PyTypeObject*)type_obj;
this->_ignored_types.insert(type);
}
} }
bool check_nopybind(PyObject* value) override { bool check_nopybind(PyObject* value) override {
// Ignore value arg, only used to satisfy the interface // Ignore value arg, only used to satisfy the interface
const size_t len = (size_t)at::impl::PythonTorchFunctionTLS::stack_len(); size_t ref_ind = 0;
const int64_t len = at::impl::PythonTorchFunctionTLS::stack_len();
const size_t ref_stack_size = this->_ref_stack.size(); const size_t ref_stack_size = this->_ref_stack.size();
if (len != ref_stack_size) { int64_t idx = 0;
return false; while ((idx < len) && (ref_ind < ref_stack_size)) {
}
for (int64_t idx = 0; (size_t)idx < len; idx++) {
std::shared_ptr<c10::SafePyObject> mode = std::shared_ptr<c10::SafePyObject> mode =
at::impl::PythonTorchFunctionTLS::get_stack_at(idx); at::impl::PythonTorchFunctionTLS::get_stack_at(idx);
PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter()));
if (mode_type != _ref_stack.at(idx)) { bool act_ignored = this->_ignored_types.count(mode_type) > 0;
bool ref_ignored =
this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0;
// skip ignored types
if (act_ignored && ref_ignored) {
idx++;
ref_ind++;
continue;
} else if (ref_ignored) {
ref_ind++;
continue;
} else if (act_ignored) {
idx++;
continue;
}
// if we already have more non-ignored modes than the ref stack
// or if the mode doesn't match at the current index, return false
else if (mode_type != _ref_stack.at(ref_ind)) {
return false;
}
ref_ind++;
idx++;
}
for (; ref_ind < ref_stack_size; ref_ind++) {
if (!(this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0)) {
return false; return false;
} }
} }
return true; for (; idx < len; idx++) {
std::shared_ptr<c10::SafePyObject> mode =
at::impl::PythonTorchFunctionTLS::get_stack_at(idx);
PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter()));
if (!(this->_ignored_types.count(mode_type) > 0)) {
return false;
}
}
return ref_ind == ref_stack_size && idx == len;
} }
private: private:
std::vector<PyTypeObject*> _ref_stack; std::vector<PyTypeObject*> _ref_stack;
std::set<PyTypeObject*> _ignored_types;
}; };
class TENSOR_MATCH : public LeafGuard { class TENSOR_MATCH : public LeafGuard {
@ -3735,7 +3785,7 @@ PyObject* torch_c_dynamo_guards_init() {
LeafGuard, LeafGuard,
std::shared_ptr<TORCH_FUNCTION_MODE_STACK>>( std::shared_ptr<TORCH_FUNCTION_MODE_STACK>>(
py_m, "TORCH_FUNCTION_MODE_STACK") py_m, "TORCH_FUNCTION_MODE_STACK")
.def(py::init<py::list, py::list>()) .def(py::init<py::list, py::list, py::list>())
.def("__call__", &TORCH_FUNCTION_MODE_STACK::check); .def("__call__", &TORCH_FUNCTION_MODE_STACK::check);
py::class_<DATA_PTR_MATCH, LeafGuard, std::shared_ptr<DATA_PTR_MATCH>>( py::class_<DATA_PTR_MATCH, LeafGuard, std::shared_ptr<DATA_PTR_MATCH>>(
py_m, "DATA_PTR_MATCH") py_m, "DATA_PTR_MATCH")
@ -3972,9 +4022,10 @@ PyObject* torch_c_dynamo_guards_init() {
"add_torch_function_mode_stack_guard", "add_torch_function_mode_stack_guard",
[](GuardManager& self, [](GuardManager& self,
const py::list& initial_stack, const py::list& initial_stack,
const py::list& ignored_types,
py::object verbose_code_parts) -> void { py::object verbose_code_parts) -> void {
self.add_leaf_guard(std::make_shared<TORCH_FUNCTION_MODE_STACK>( self.add_leaf_guard(std::make_shared<TORCH_FUNCTION_MODE_STACK>(
initial_stack, std::move(verbose_code_parts))); initial_stack, ignored_types, std::move(verbose_code_parts)));
}) })
.def( .def(
"add_data_ptr_guard", "add_data_ptr_guard",