This reverts commit 7743149b2b.

Reverts
* https://github.com/pytorch/pytorch/pull/135503
* https://github.com/pytorch/pytorch/pull/135502
* https://github.com/pytorch/pytorch/pull/135422

With the revert, the test below passes again. Previously the indexing of `prefix_lengths[b]` stayed a `getitem` node in the FX graph, but with those changes fake tensor propagation fails, reporting that `.item()` is called. It appears the torch function mode is not being triggered during fake tensor propagation.
```
import torch
from torch._inductor.lowering import make_pointwise, register_lowering
from torch._inductor.virtualized import ops
from torch.nn.attention.flex_attention import (
    BlockMask,
    _mask_mod_signature,
    _score_mod_signature,
    create_block_mask,
    flex_attention,
)

torch.set_default_device('cuda')

flex_attention = torch.compile(flex_attention, dynamic=False)

prefix_lengths = torch.arange(8)


def prefix_lm(b, h, q, kv):
    return prefix_lengths[b] >= kv


mask = create_block_mask(prefix_lm, 8, None, 512, 512, _compile=True)
```
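For context, a minimal sketch (not part of this commit; `lengths` and `keeps_getitem` are made-up names) of what "the getitem stays a getitem in the FX graph" means: tracing a mask_mod-style function should record an `operator.getitem` node for the tensor index instead of forcing a concrete value the way `.item()` does.

```
# Hypothetical illustration only; it is not the flex_attention code path.
import operator

import torch
import torch.fx

lengths = torch.arange(8)


def keeps_getitem(b, kv):
    # `lengths[b]` should be captured symbolically as operator.getitem,
    # not evaluated eagerly to a Python scalar
    return lengths[b] >= kv


gm = torch.fx.symbolic_trace(keeps_getitem)
print(any(n.target is operator.getitem for n in gm.graph.nodes))  # expect True
```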
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136590
Approved by: https://github.com/Chillee
This commit is contained in:
parent 529b6ab0bb
commit 289df45cee
@@ -1,4 +1,5 @@
 # Owner(s): ["module: dynamo"]
+from unittest.mock import patch

 import torch
 import torch._dynamo.test_case
@@ -106,6 +107,70 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
             fn(inp)
         self.assertEqual(cnt.frame_count, 4)

+    def _run_ignored_mode_types_test(self):
+        class IgnoredMode(BaseTorchFunctionMode):
+            pass
+
+        cnt = torch._dynamo.testing.CompileCounter()
+
+        @torch.compile(backend=cnt.__call__, fullgraph=True)
+        def fn(x):
+            return x + 1
+
+        inp = torch.ones(2, 2)
+
+        with patch(
+            "torch._dynamo.variables.torch_function.IGNORED_MODES", {IgnoredMode}
+        ):
+            # initial compile
+            fn(inp)
+
+            # no recompile, mode ignored
+            # note: the ref stack is length 0, and the stack we are checking against has length 2
+            # we want to check both ref stack len > runtime stack, and ref stack len < runtime stack
+            with IgnoredMode(), IgnoredMode():
+                fn(inp)
+
+            self.assertEqual(cnt.frame_count, 1)
+
+            # recompile due to new mode on the stack
+            with BaseTorchFunctionMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode():
+                fn(inp)
+
+            self.assertEqual(cnt.frame_count, 2)
+
+            # recompile
+            # tests both ref stack len > runtime stack len for the above guard check
+            # and ref stack len < runtime stack len for the initial zero mode case
+            with BaseTorchFunctionMode(), IgnoredMode(), BaseTorchFunctionMode():
+                fn(inp)
+
+            self.assertEqual(cnt.frame_count, 3)
+
+            # no recompile
+            with IgnoredMode(), IgnoredMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode():
+                fn(inp)
+
+            self.assertEqual(cnt.frame_count, 3)
+
+            # This is tricky, basically the ignored modes are baked into the guard
+            # IgnoredMode will be ignored forever by that guard.
+            # This is okay since we don't expect to be modifying IGNORED_MODES
+            # in the middle of execution except for the purposes of testing.
+            torch._dynamo.reset()
+
+            with IgnoredMode():
+                fn(inp)
+
+            self.assertEqual(cnt.frame_count, 4)
+
+    @torch._dynamo.config.patch("enable_cpp_guard_manager", False)
+    def test_torch_function_mode_guards_ignored_types_py(self):
+        self._run_ignored_mode_types_test()
+
+    def test_torch_function_mode_guards_ignored_types_cpp(self):
+        self._run_ignored_mode_types_test()
+
     @torch._dynamo.config.patch("enable_cpp_guard_manager", False)
     def test_torch_function_mode_guards_py(self):
         self._run_torch_function_mode_guard_test()
@@ -396,94 +461,6 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(expected, actual)

-    def test_torch_function_mode_enter_exit(self):
-        def fn(x, y):
-            with TestMode():
-                o = torch.add(x, 3)
-
-            return torch.add(o, y)
-
-        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
-        fn_opt = torch.compile(fn, fullgraph=True)
-
-        expected = fn(*inp)
-        actual = fn_opt(*inp)
-
-        self.assertEqual(expected, actual)
-
-    def test_torch_function_mode_graph_break(self):
-        def fn(x, y):
-            with TestMode():
-                torch._dynamo.graph_break()
-                o = torch.add(x, 3)
-
-            return torch.add(o, y)
-
-        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
-        fn_opt = torch.compile(fn)
-
-        expected = fn(*inp)
-        actual = fn_opt(*inp)
-
-        self.assertEqual(expected, actual)
-
-    def test_torch_function_mode_and_pop_graph_break(self):
-        def fn(x, y):
-            with TestMode():
-                z = _pop_torch_function_stack()
-                torch._dynamo.graph_break()
-                _push_on_torch_function_stack(z)
-                o = torch.add(x, 3)
-
-            return torch.add(o, y)
-
-        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
-        fn_opt = torch.compile(fn)
-
-        expected = fn(*inp)
-        actual = fn_opt(*inp)
-
-        self.assertEqual(expected, actual)
-
-    def test_torch_function_mode_restore_on_exc(self):
-        @torch._dynamo.disable()
-        def err():
-            raise RuntimeError("test")
-
-        @torch.compile()
-        def fn(x):
-            with TestMode():
-                x += 1
-                err()
-                x += 2
-                return x
-
-        try:
-            fn(torch.ones(2, 2))
-        except RuntimeError:
-            pass
-        self.assertEqual(_len_torch_function_stack(), 0)
-
-    def test_torch_function_mode_and_pop_graph_break_mutation(self):
-        def fn(x, y):
-            with TestMode():
-                z = _pop_torch_function_stack()
-                z.y = 5
-                torch._dynamo.graph_break()
-                _push_on_torch_function_stack(z)
-                o = torch.add(x, 3)
-                o = torch.mul(o, z.y)
-
-            return torch.add(o, y)
-
-        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
-        fn_opt = torch.compile(fn)
-
-        expected = fn(*inp)
-        actual = fn_opt(*inp)
-
-        self.assertEqual(expected, actual)
-

 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
@@ -67,7 +67,7 @@ class GuardManager:
     ) -> None: ...
     def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ...
     def add_torch_function_mode_stack_guard(
-        self, initial_stack, verbose_code_parts: list[str]
+        self, initial_stack, ignored_types, verbose_code_parts: list[str]
     ) -> None: ...

class RootGuardManager(GuardManager):
@@ -112,7 +112,6 @@ from .utils import (
     troubleshooting_url,
     write_record_to_file,
 )
-from .variables.torch_function import torch_function_mode_stack_state_mgr


 np: Optional[ModuleType]
@@ -211,18 +210,15 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
         prior_fwd_from_src = torch.fx.graph_module._forward_from_src
         torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
         cleanup = setup_compile_debug()

         exit_stack = contextlib.ExitStack()
         exit_stack.enter_context(
             torch.fx._symbolic_trace._maybe_revert_all_patches()
         )
-        exit_stack.enter_context(torch_function_mode_stack_state_mgr)
         try:
             return fn(*args, **kwargs)
         finally:
             cleanup.close()
-            assert (
-                torch._C._len_torch_function_stack() == 0
-            ), "Torch function mode stack state changed while dynamo tracing, please report a bug"
             exit_stack.close()
             torch._C._set_grad_enabled(prior_grad_mode)
             torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)
@@ -2344,12 +2344,15 @@ class CheckFunctionManager:
         )

         if config.enable_cpp_guard_manager:
+            from .variables.torch_function import IGNORED_MODES
+
             # Insert the global_state guard
             assert self.guard_manager  # to make mypy happy
             self.guard_manager.root.add_global_state_guard(["___check_global_state()"])

             self.guard_manager.root.add_torch_function_mode_stack_guard(
                 self.torch_function_mode_stack,
+                list(IGNORED_MODES),
                 ["___check_torch_function_mode_stack()"],
             )
             # Clear references to torch_function modes held in the list
@@ -2656,14 +2659,18 @@ def is_recompiles_verbose_enabled():
 # this will only be used if cpp guards are disabled
 def make_torch_function_mode_stack_guard(intial_stack):
     types = [type(x) for x in intial_stack]
+    from .variables.torch_function import IGNORED_MODES

     def check_torch_function_mode_stack():
         cur_stack = get_torch_function_mode_stack()

-        if len(cur_stack) != len(types):
+        types_ = [ty for ty in types if ty not in IGNORED_MODES]
+        cur_stack_ = [mode for mode in cur_stack if type(mode) not in IGNORED_MODES]
+
+        if len(cur_stack_) != len(types_):
             return False

-        for ty, mode in zip(types, cur_stack):
+        for ty, mode in zip(types_, cur_stack_):
             if ty != type(mode):
                 return False
@@ -78,6 +78,7 @@ from .utils import (
     get_instruction_source_311,
     get_locals_to_steal,
     get_static_address_type,
+    get_torch_function_mode_stack,
     graph_break_reasons,
     increment_op_count,
     lazy_format_graph_code,
@@ -249,7 +250,6 @@ class OutputGraph:
         local_scope: Scope,
         global_scope: Scope,
         f_code,
-        torch_function_mode_stack,
     ):
         super().__init__()
         self.tracers = [SubgraphTracer(self, export_root=export)]
@@ -368,7 +368,7 @@ class OutputGraph:
         # This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty
         self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled()
         # This records the initial torch function mode stack for guarding
-        self.torch_function_mode_stack = torch_function_mode_stack
+        self.torch_function_mode_stack = get_torch_function_mode_stack()

         # Tracks if the output graph has a user defined allowed function in the
         # graph. This is used later to determine if we should fallback to eager
@@ -1020,7 +1020,7 @@ class OutputGraph:
             prefix_insts.clear()

         for block in reversed(tx.block_stack):
-            block.exit(tx, is_graph_break=reason.graph_break)
+            block.exit(tx)

         self.cleanup_graph()
         tx.prune_dead_locals()
@@ -25,26 +25,6 @@ if TYPE_CHECKING:
         sys as sys,
     )

-from torch.overrides import BaseTorchFunctionMode
-
-# These classes handle support for TorchFunctionModes across
-# graph breaks
-# Today the TorchFunctionMode enter (for the classes we support)
-# simply pushes the mode onto the stack. Since after this occurs
-# the stack is mutated, and we replay these mutations, we don't need
-# any cleanup logic to be run once the graph break occurs, we simply replay
-# these mutations to ensure at the graph break the torch function mode stack is correct
-# and reconstruct the torch function mode stack normally
-# when we compile the resume function on the other side of the break.
-# However, to ensure we exit properly
-# in the resume function, we need to re-enter the contexts as we do other contexts.
-# These contexts do nothing on enter, but provide the correct exit logic to ensure
-# the stack state is correct.
-class NoEnterTorchFunctionMode(BaseTorchFunctionMode):
-    def __enter__(self):
-        pass
-
-
 def index(iterator, item, start=0, end=None):
     from itertools import islice
@@ -90,26 +90,27 @@ class ReenterWith:
     stack_index: int
     target_values: Optional[Tuple[Any, ...]] = None

-    def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]):
-        """
-        Codegen based off of:
-        try:
-            (rest)
-        except:
-            (restore previous tf mode stack)
-            raise
-
-        """
-        from .variables.torch_function import get_prev_stack_var_name
-
-        setup_try_except, epilogue = _bytecode_from_template_with_split(
-            _try_except_tf_mode_template,
-            self.stack_index,
-            varname_map={"stack_var_name": get_prev_stack_var_name()},
-        )
-        cleanup[:] = epilogue + cleanup
-
-        return setup_try_except
+    # TODO(mlazos) - Uncomment with the reland of torch function mode support
+    # def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]):
+    #     """
+    #     Codegen based off of:
+    #     try:
+    #         (rest)
+    #     except:
+    #         (restore previous tf mode stack)
+    #         raise
+    #
+    #     """
+    #     from .variables.torch_function import get_prev_stack_var_name
+    #
+    #     setup_try_except, epilogue = _bytecode_from_template_with_split(
+    #         _try_except_tf_mode_template,
+    #         self.stack_index,
+    #         varname_map={"stack_var_name": get_prev_stack_var_name()},
+    #     )
+    #     cleanup[:] = epilogue + cleanup
+    #
+    #     return setup_try_except

     # If we do not want to destroy the stack, we can do the same thing as a
     # `SETUP_WITH` block, only that we store the context manager in a local_symbol
@@ -623,22 +623,11 @@ class SideEffects:
             elif isinstance(
                 var, variables.torch_function.TorchFunctionModeStackVariable
             ):
-                # Needed in the finally block for stack restoration
-                cg.add_push_null(
-                    lambda: cg.load_import_from(
-                        utils.__name__, "get_torch_function_mode_stack"
-                    )
-                )
-                cg.call_function(0, False)
-                name = variables.torch_function.get_prev_stack_var_name()
-                cg.code_options["co_varnames"] += (name,)
-                cg.append_output(create_instruction("STORE_FAST", argval=name))
                 cg.add_push_null(
                     lambda: cg.load_import_from(
                         utils.__name__, "set_torch_function_mode_stack"
                     )
                 )

                 cg.foreach(var.symbolic_stack)
                 cg.append_output(
                     create_instruction("BUILD_LIST", arg=len(var.symbolic_stack))
@@ -267,12 +267,13 @@ class BlockStackEntry:
         else:
             return ReenterWith(self.stack_index)

-    def exit(self, tx, is_graph_break):
+    def exit(self, tx):
+        if hasattr(self, "graph_break") and isinstance(
+            self.with_context, TorchFunctionModeVariable
+        ):
+            return
         assert self.with_context is not None
-        if (
-            is_graph_break and self.with_context.exit_on_graph_break()
-        ) or not is_graph_break:
-            return self.with_context.exit(tx)
+        return self.with_context.exit(tx)


 class ReturnValueOp(Exception):
@@ -638,17 +639,10 @@ def break_graph_if_unsupported(*, push):
             cleanup: List[Instruction] = []
             # Reconstruct the context variable CLASS in the block stack
             for b in self.block_stack:
-                # Don't exit any modes we have entered,
-                # output bytecode will mutate the tf mode stack accordingly
-                if isinstance(b.with_context, TorchFunctionModeVariable):
-                    cg.extend_output(
-                        b.resume_fn().try_except_torch_function_mode(
-                            cg.code_options, cleanup
-                        )
-                    )
-                    continue
                 assert b.with_context is not None
-                assert isinstance(b.with_context, (ContextWrappingVariable))
+                assert isinstance(
+                    b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable)
+                )
                 b.with_context.reconstruct_type(cg)
                 cg.extend_output(b.resume_fn().try_finally(cg.code_options, cleanup))
             self.output.add_output_instructions(cg.get_instructions())
@@ -2301,10 +2295,7 @@ class InstructionTranslatorBase(
         ):
             unimplemented(f"{inst.opname} {ctx}")

-        if (
-            isinstance(ctx, GenericContextWrappingVariable)
-            and not ctx.supports_graph_breaks()
-        ):
+        if isinstance(ctx, GenericContextWrappingVariable):
             self.generic_context_manager_depth += 1

         # Need this redundant check for mypy
@@ -2677,7 +2668,6 @@ class InstructionTranslator(InstructionTranslatorBase):
                 local_scope=f_locals,
                 global_scope=f_globals,
                 f_code=f_code,
-                torch_function_mode_stack=torch_function_mode_stack,
             ),
             instructions=instructions,
             f_locals=f_locals,
@@ -187,7 +187,6 @@ def debug_insert_nops(
         local_scope=locals(),
         global_scope=globals(),
         f_code=frame.f_code,
-        torch_function_mode_stack=[],
     )

     return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0))
@@ -304,7 +304,6 @@ manual_torch_name_rule_map = {
     "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable,
     "torch.cuda._get_device_properties": TorchInGraphFunctionVariable,
     "torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable,
-    "torch.set_default_device": UserFunctionVariable,
     "torch.sparse_bsc_tensor": SkipFunctionVariable,
     "torch.sparse_bsr_tensor": SkipFunctionVariable,
     "torch.sparse_csc_tensor": SkipFunctionVariable,
@@ -2802,6 +2801,7 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
         "torch.random.initial_seed",
         "torch.random.seed",
         "torch.return_types.pytree_register_structseq",
+        "torch.set_default_device",
         "torch.set_default_dtype",
         "torch.set_default_tensor_type",
         "torch.set_deterministic_debug_mode",
@@ -3097,10 +3097,16 @@ def is_parameter_freezing():
     return torch._inductor.config.freezing and not torch.is_grad_enabled()


-def get_torch_function_mode_stack():
-    return [
+def get_torch_function_mode_stack(filter_ignored=True):
+    from .variables.torch_function import IGNORED_MODES
+
+    stack = [
         get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack())
     ]
+    if filter_ignored:
+        stack = [mode for mode in stack if type(mode) not in IGNORED_MODES]
+
+    return stack


 def get_torch_function_mode_stack_at(ind):
@@ -204,7 +204,6 @@ from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
 from .torch_function import (
     build_torch_function_fn,
     TensorWithTFOverrideVariable,
-    torch_function_mode_stack_state_mgr,
     TorchFunctionModeVariable,
 )
 from .user_defined import (
@@ -1671,16 +1670,15 @@ class VariableBuilder:
             # but warning is not the end of the world
             assert isinstance(value.base, np.nditer)

-            with torch_function_mode_stack_state_mgr.temp_restore_stack():
-                try:
-                    tensor_value = _util._try_convert_to_tensor(value)
-                    if readonly:
-                        from torch._prims_common import clone_preserve_strides
-
-                        tensor_value = clone_preserve_strides(tensor_value)
-                except NotImplementedError as e:
-                    # failed to convert to tensor, graph break
-                    unimplemented(str(e))
+            try:
+                tensor_value = _util._try_convert_to_tensor(value)
+                if readonly:
+                    from torch._prims_common import clone_preserve_strides
+
+                    tensor_value = clone_preserve_strides(tensor_value)
+            except NotImplementedError as e:
+                # failed to convert to tensor, graph break
+                unimplemented(str(e))

         # We do this because we want the full behavior of guarding the numpy ndarray as if it were
         # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here
@@ -125,12 +125,6 @@ class ContextWrappingVariable(VariableTracker):
         if isinstance(args[0], UserFunctionVariable):
             return WrappedUserFunctionVariable(args[0], self)

-    def supports_graph_breaks(self):
-        return True
-
-    def exit_on_graph_break(self):
-        return True
-

 class GenericContextWrappingVariable(UserDefinedObjectVariable):
     # Some methods in ContextWrappingVariable assumes the arguments are
@@ -189,12 +183,6 @@ class GenericContextWrappingVariable(UserDefinedObjectVariable):
         tx.generic_context_manager_depth -= 1
         return x

-    def supports_graph_breaks(self):
-        return False
-
-    def exit_on_graph_break(self):
-        return True
-

 class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
     """represents torch grad requries grad"""
@@ -160,17 +160,7 @@ def get_overridable_functions():

     from torch.overrides import get_overridable_functions as get_overridable_functions_

-    funcs = set(chain(*get_overridable_functions_().values()))
-    more = {
-        torch.ones,
-        torch.ones_like,
-        torch.zeros,
-        torch.zeros_like,
-        torch.empty,
-        torch.full,
-    }
-    funcs.update(more)
-    return funcs
+    return set(chain(*get_overridable_functions_().values()))


 class BaseTorchVariable(VariableTracker):
@@ -846,13 +836,6 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
                 len(tx.symbolic_torch_function_state.mode_stack)
             )

-        @register(torch._C._get_function_stack_at)
-        def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs):
-            assert len(args) == 1 and not kwargs
-            ind = args[0].as_python_constant()
-            assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack)
-            return tx.symbolic_torch_function_state.mode_stack[ind]
-
         @register(torch.set_default_device)
         def handle_set_default_device(
             self, tx: "InstructionTranslator", *args, **kwargs
@@ -870,7 +853,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
             else:
                 TorchFunctionModeStackVariable.register_device_context_insertion(tx)

-            return ConstantVariable.create(None)
+            return None

         return handlers
@@ -2,35 +2,22 @@

 import collections
 import contextlib
-import functools
 import inspect
 from typing import Deque, Dict, List, TYPE_CHECKING

 import torch._C
 import torch.utils._pytree as pytree
 from torch._guards import Source
-from torch.overrides import (
-    _get_overloaded_args,
-    get_default_nowrap_functions,
-    TorchFunctionMode,
-)
+from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
 from torch.utils._device import DeviceContext

 from ..exc import unimplemented
 from ..guards import GuardBuilder, install_guard
-from ..polyfills import NoEnterTorchFunctionMode
 from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
-from ..utils import (
-    class_has_getattribute,
-    clear_torch_function_mode_stack,
-    get_safe_global_name,
-    has_torch_function,
-    is_tensor_base_attr_getter,
-    set_torch_function_mode_stack,
-)
+from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter
 from .base import VariableTracker
 from .constant import ConstantVariable
-from .ctx_manager import GenericContextWrappingVariable
+from .ctx_manager import ContextWrappingVariable
 from .lazy import LazyVariableTracker
 from .lists import TupleVariable
 from .tensor import TensorSubclassVariable, TensorVariable
@@ -69,38 +56,11 @@ banned_attrs = [
     if is_tensor_base_attr_getter(fn)
 ]

-@functools.lru_cache(None)
-def get_prev_stack_var_name():
-    from ..bytecode_transformation import unique_id
-
-    return unique_id("___prev_torch_function_mode_stack")
-
-
-# Used to clear/restore the python torch function mode stack and temporarily restore it as needed
-class TorchFunctionModeStackStateManager:
-    def __init__(self):
-        self.stack = []
-
-    def __enter__(self):
-        self.stack = torch.overrides._get_current_function_mode_stack()
-        clear_torch_function_mode_stack()
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        set_torch_function_mode_stack(self.stack)
-        self.stack = []
-
-    @contextlib.contextmanager
-    def temp_restore_stack(self):
-        prev = torch.overrides._get_current_function_mode_stack()
-        set_torch_function_mode_stack(self.stack)
-        try:
-            yield
-        finally:
-            set_torch_function_mode_stack(prev)
-
-
-torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager()
+# Today set default device is placed in the graph and guarded on separately
+# so we should not trace through it. In the future we can trace it once
+# mode tracing is implemented and not put in the graph, but this is more
+# of a BE project and can be evaluated later
+IGNORED_MODES = {DeviceContext}


 class SymbolicTorchFunctionState:
@@ -229,26 +189,9 @@ class TorchFunctionModeStackVariable(VariableTracker):
         return ind + cls.offset


-class TorchFunctionModeVariable(GenericContextWrappingVariable):
-    @staticmethod
-    def is_supported_torch_function_mode(ty):
-        # Supported in this sense means we can support graph breaks under the
-        # context.
-        # We are able to trace custom modes but if there are graph breaks under them
-        # and they have a custom __enter__/__exit__ we don't handle this for the
-        # same reason we don't handle generic context managers: there may be side effects
-        # that are now affected by executing the funtion across two frames instead of one
-        # Today we support the enter/exit of the default TorchFunctionMode as well as
-        # DeviceContext (which is used for set_default_device)
-        return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or (
-            not class_has_getattribute(ty)
-            and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__
-            and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__
-        )
-
+class TorchFunctionModeVariable(ContextWrappingVariable):
     def __init__(self, value, source=None, **kwargs):
-        if value is not None:
-            super().__init__(value, **kwargs)
+        super().__init__(value, **kwargs)
         self.value = value
         self.cm_obj = value  # needed for BC with calling enter from CM code
         self.source = source
@@ -278,39 +221,8 @@ class TorchFunctionModeVariable(GenericContextWrappingVariable):
             kwargs,
         )

-    def enter(self, tx):
-        from .torch import TorchInGraphFunctionVariable
-
-        if isinstance(self.value, NoEnterTorchFunctionMode):
-            return ConstantVariable.create(None)
-
-        TorchInGraphFunctionVariable(
-            torch._C._push_on_torch_function_stack
-        ).call_function(tx, [self], {})
-        return ConstantVariable.create(None)
-
-    def exit(self, tx: "InstructionTranslator", *args):
-        from .torch import TorchInGraphFunctionVariable
-
-        TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function(
-            tx, [], {}
-        )
-        return ConstantVariable.create(None)
-
-    def reconstruct_type(self, codegen):
-        ty = NoEnterTorchFunctionMode
-        codegen(
-            AttrSource(
-                codegen.tx.import_source(ty.__module__),
-                ty.__name__,
-            )
-        )
-
-    def supports_graph_breaks(self):
-        return True
-
-    def exit_on_graph_break(self):
-        return False
+    def _call_func(self, tx: "InstructionTranslator", values):
+        unimplemented("enter/exit for torch function mode NYI")


 def _get_all_args(args, kwargs):
@@ -417,22 +417,10 @@ class UserDefinedClassVariable(UserDefinedVariable):
             and self.source
             and not is_forbidden_context_manager(self.value)
         ):
-            from torch.overrides import TorchFunctionMode
-
             from .ctx_manager import GenericContextWrappingVariable
-            from .torch_function import TorchFunctionModeVariable
-
-            if issubclass(
-                self.value, TorchFunctionMode
-            ) and TorchFunctionModeVariable.is_supported_torch_function_mode(
-                self.value
-            ):
-                var_cls = TorchFunctionModeVariable
-            else:
-                var_cls = GenericContextWrappingVariable

             cm_obj = tx.output.side_effects.track_object_new(
-                self.source, self.value, var_cls, {}
+                self.source, self.value, GenericContextWrappingVariable, {}
             )
             cm_obj.call_method(tx, "__init__", args, kwargs)
             return cm_obj
@@ -2537,40 +2537,90 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard {
  public:
   TORCH_FUNCTION_MODE_STACK(
       const py::list& initial_stack,
+      const py::list& ignored_types,
       py::object verbose_code_parts)
-      : LeafGuard(std::move(verbose_code_parts)), _ref_stack() {
+      : LeafGuard(std::move(verbose_code_parts)),
+        _ref_stack(),
+        _ignored_types() {
     Py_ssize_t len = PyList_Size(initial_stack.ptr());
     for (Py_ssize_t idx = 0; idx < len; idx++) {
       PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref
       auto type = Py_TYPE(mode);
       this->_ref_stack.push_back(type);
     }
+
+    len = PyList_Size(ignored_types.ptr());
+    for (Py_ssize_t idx = 0; idx < len; idx++) {
+      PyObject* type_obj =
+          PyList_GetItem(ignored_types.ptr(), idx); // borrowed ref
+      if (PyType_Check(type_obj) == 0) {
+        PyErr_SetString(
+            PyExc_TypeError, "ignored_types should contain a list of types");
+        return;
+      }
+      PyTypeObject* type = (PyTypeObject*)type_obj;
+      this->_ignored_types.insert(type);
+    }
   }

   bool check_nopybind(PyObject* value) override {
     // Ignore value arg, only used to satisfy the interface
-    const size_t len = (size_t)at::impl::PythonTorchFunctionTLS::stack_len();
+    size_t ref_ind = 0;
+    const int64_t len = at::impl::PythonTorchFunctionTLS::stack_len();
     const size_t ref_stack_size = this->_ref_stack.size();

-    if (len != ref_stack_size) {
-      return false;
-    }
-
-    for (int64_t idx = 0; (size_t)idx < len; idx++) {
+    int64_t idx = 0;
+    while ((idx < len) && (ref_ind < ref_stack_size)) {
       std::shared_ptr<c10::SafePyObject> mode =
           at::impl::PythonTorchFunctionTLS::get_stack_at(idx);

       PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter()));
-      if (mode_type != _ref_stack.at(idx)) {
+      bool act_ignored = this->_ignored_types.count(mode_type) > 0;
+      bool ref_ignored =
+          this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0;
+      // skip ignored types
+      if (act_ignored && ref_ignored) {
+        idx++;
+        ref_ind++;
+        continue;
+      } else if (ref_ignored) {
+        ref_ind++;
+        continue;
+      } else if (act_ignored) {
+        idx++;
+        continue;
+      }
+      // if we already have more non-ignored modes than the ref stack
+      // or if the mode doesn't match at the current index, return false
+      else if (mode_type != _ref_stack.at(ref_ind)) {
+        return false;
+      }
+      ref_ind++;
+      idx++;
+    }
+
+    for (; ref_ind < ref_stack_size; ref_ind++) {
+      if (!(this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0)) {
         return false;
       }
     }

-    return true;
+    for (; idx < len; idx++) {
+      std::shared_ptr<c10::SafePyObject> mode =
+          at::impl::PythonTorchFunctionTLS::get_stack_at(idx);
+
+      PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter()));
+      if (!(this->_ignored_types.count(mode_type) > 0)) {
+        return false;
+      }
+    }
+
+    return ref_ind == ref_stack_size && idx == len;
   }

  private:
   std::vector<PyTypeObject*> _ref_stack;
+  std::set<PyTypeObject*> _ignored_types;
 };

 class TENSOR_MATCH : public LeafGuard {
@@ -3735,7 +3785,7 @@ PyObject* torch_c_dynamo_guards_init() {
       LeafGuard,
       std::shared_ptr<TORCH_FUNCTION_MODE_STACK>>(
       py_m, "TORCH_FUNCTION_MODE_STACK")
-      .def(py::init<py::list, py::list>())
+      .def(py::init<py::list, py::list, py::list>())
       .def("__call__", &TORCH_FUNCTION_MODE_STACK::check);
   py::class_<DATA_PTR_MATCH, LeafGuard, std::shared_ptr<DATA_PTR_MATCH>>(
       py_m, "DATA_PTR_MATCH")
@@ -3972,9 +4022,10 @@ PyObject* torch_c_dynamo_guards_init() {
           "add_torch_function_mode_stack_guard",
           [](GuardManager& self,
              const py::list& initial_stack,
+             const py::list& ignored_types,
              py::object verbose_code_parts) -> void {
             self.add_leaf_guard(std::make_shared<TORCH_FUNCTION_MODE_STACK>(
-                initial_stack, std::move(verbose_code_parts)));
+                initial_stack, ignored_types, std::move(verbose_code_parts)));
           })
       .def(
           "add_data_ptr_guard",