Consolidate SymDispatchMode into ProxyTensorMode (#132674)

Instead of having a separate context variable for SymDispatchMode, we
now simply delegate to the current active proxy tensor mode when we
need to trace a SymInt.  We maintain a separate `__sym_dispatch__` magic
method as the calling convention is different than `__torch_dispatch__`.

Consolidating the modes in this ways means that we can consistently
disable both of these modes in tandem simply by removing the mode
from the proxy mode infra slot.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132674
Approved by: https://github.com/zou3519, https://github.com/bdhirsh
This commit is contained in:
Edward Z. Yang 2024-08-08 04:59:11 -07:00 committed by PyTorch MergeBot
parent 0f19d4150b
commit 361db32d47
9 changed files with 90 additions and 136 deletions

1
.github/labeler.yml vendored
View File

@ -29,7 +29,6 @@
- torch/fx/experimental/recording.py
- torch/fx/experimental/sym_node.py
- torch/fx/experimental/validator.py
- torch/fx/experimental/_sym_dispatch_mode.py
- torch/fx/experimental/proxy_tensor.py
- test/distributed/_tensor/test_dtensor_compile.py
- test/distributed/tensor/parallel/test_fsdp_2d_parallel.py

View File

@ -850,7 +850,6 @@ coverage_ignore_functions = [
"get_torch_dispatch_modes",
"has_proxy_slot",
"is_sym_node",
"make_fx",
"maybe_disable_fake_tensor_mode",
"maybe_handle_decomp",
"proxy_call",

View File

@ -51,3 +51,17 @@ torch.fx.experimental.symbolic_shapes
compute_unbacked_bindings
rebind_unbacked
resolve_unbacked_bindings
torch.fx.experimental.proxy_tensor
-------------------------------------
.. currentmodule:: torch.fx.experimental.proxy_tensor
.. automodule:: torch.fx.experimental.proxy_tensor
.. autosummary::
:toctree: generated
:nosignatures:
make_fx
handle_sym_dispatch
get_proxy_mode

View File

@ -1143,7 +1143,6 @@ API Reference
.. py:module:: torch.fx.experimental.normalize
.. py:module:: torch.fx.experimental.optimization
.. py:module:: torch.fx.experimental.partitioner_utils
.. py:module:: torch.fx.experimental.proxy_tensor
.. py:module:: torch.fx.experimental.recording
.. py:module:: torch.fx.experimental.refinement_types
.. py:module:: torch.fx.experimental.rewriter

View File

@ -112,7 +112,6 @@ class AutogradCompilerInstance:
# TODO(jansel): are all these modes needed?
self.stack.enter_context(decompose({}))
self.stack.enter_context(self.fake_tensor_mode)
self.stack.enter_context(self.proxy_mode.sym_mode)
self.stack.enter_context(self.proxy_mode)
self.stack.enter_context(disable_autocast_cache())
self.stack.enter_context(preserve_node_meta())

View File

@ -25,7 +25,6 @@ from weakref import ReferenceType
import torch
import torch._logging
import torch.fx.experimental._sym_dispatch_mode
from torch._C._dynamo.guards import GlobalStateGuard
from torch._dynamo.distributed import get_compile_pg
from torch._guards import compile_context, CompileContext, CompileId, tracing
@ -1234,9 +1233,7 @@ class CatchErrorsWrapper:
frame, cache_entry, self.hooks, frame_state
)
with (
compile_lock
), _disable_current_modes(), torch.fx.experimental._sym_dispatch_mode.disable_sym_dispatch():
with compile_lock, _disable_current_modes():
# skip=1: skip this frame
return self._torchdynamo_orig_callable(
frame, cache_entry, self.hooks, frame_state, skip=1

View File

@ -1,72 +0,0 @@
# mypy: allow-untyped-defs
import contextlib
from typing import List, Optional, Type
__all__ = ["SymDispatchMode", "handle_sym_dispatch", "sym_function_mode"]
SYM_FUNCTION_MODE: Optional["SymDispatchMode"] = None
# SymDispatchMode gets invoked whenever an operation is processed on
# a PySymInt. When this occurs, you get called at __sym_dispatch__
# with the operation in question. This is symmetric to TorchDispatchMode
# but with some caveats:
#
# - In TorchDispatchMode, you get the same arguments as what a user
# invoked your API with; e.g., if you call torch.ops.aten.foo(a, b),
# you get (a, b) as args to your call. In SymDispatchMode, if
# you call a + b (where a and b are SymInts), you will get
# (a.node, b.node) as your args (these are PySymInts)
#
# - SymInt/PySymInt don't have FX proxy support (unlike, e.g., Tensor).
# So you have to manually call Tracer/create_node to write into
# the graph. See ProxySymDispatchMode for an example
#
class SymDispatchMode:
def __sym_dispatch__(self, func, types, args, kwargs):
raise NotImplementedError
def __enter__(self):
global SYM_FUNCTION_MODE
old = SYM_FUNCTION_MODE
if hasattr(self, "inner"):
raise RuntimeError(
f"{self} has already been used as a mode. Please use a fresh version"
)
else:
self.inner = old
SYM_FUNCTION_MODE = self
return self
def __exit__(self, exc_type, exc_val, exc_tb):
global SYM_FUNCTION_MODE
SYM_FUNCTION_MODE = self.inner
def handle_sym_dispatch(func, args, kwargs):
global SYM_FUNCTION_MODE
mode = sym_function_mode()
assert mode
SYM_FUNCTION_MODE = mode.inner
try:
# TODO: properly compute types
types: List[Type] = []
return mode.__sym_dispatch__(func, types, args, kwargs)
finally:
SYM_FUNCTION_MODE = mode
def sym_function_mode():
return SYM_FUNCTION_MODE
@contextlib.contextmanager
def disable_sym_dispatch():
global SYM_FUNCTION_MODE
old = SYM_FUNCTION_MODE
SYM_FUNCTION_MODE = None
try:
yield
finally:
SYM_FUNCTION_MODE = old

View File

@ -22,13 +22,13 @@ import warnings
import weakref
from ._backward_state import BackwardState
from ._sym_dispatch_mode import SymDispatchMode
from .sym_node import SymNode
from torch.utils._thunk import Thunk
from collections import defaultdict
from contextlib import contextmanager, nullcontext, AbstractContextManager, ExitStack
from dataclasses import dataclass
from torch import SymInt, SymBool, Tensor
import torch._ops
from torch._dispatch.python import enable_python_dispatcher
from torch._library.fake_class_registry import FakeScriptObject
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, unset_fake_temporarily, is_fake
@ -59,7 +59,10 @@ if TYPE_CHECKING:
from torch.fx._symbolic_trace import PHBase
from torch.types import IntLikeType
__all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "py_sym_types", "get_innermost_proxy_mode"]
__all__ = [
"PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter",
"py_sym_types", "get_innermost_proxy_mode", "get_proxy_mode", "handle_sym_dispatch"
]
_ProxyTracer = Union["PythonKeyTracer", "_GraphAppendingTracerEx"]
@ -1006,7 +1009,10 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
class ProxyTorchDispatchMode(TorchDispatchMode):
_managers: List[AbstractContextManager]
# Ensure this is read-only; this exists only for legacy reasons
@property
def enable_tracing(self) -> bool:
return True
def __init__(
self,
@ -1020,12 +1026,9 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
super().__init__(dk)
self.tracer = tracer
self.tracing_mode = tracing_mode
self.enable_tracing = True
self.pre_dispatch = pre_dispatch
self._allow_fake_constant = _allow_fake_constant
self._error_on_data_dependent_ops = _error_on_data_dependent_ops
self.sym_mode = ProxySymDispatchMode(tracer)
self._managers = []
# Indicates to our torch_dispatch dispatching infra that
# this is an "infra" mode with lower dispatching precedence.
self._mode_key = torch._C._TorchDispatchModeKey.PROXY
@ -1045,14 +1048,10 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
args: Tuple[object, ...] = (),
kwargs: Optional[Dict[str, object]] = None
) -> object:
with self.sym_mode.enable(False), set_original_aten_op(func):
with set_original_aten_op(func):
return self.inner_torch_dispatch(func, types, args, kwargs)
def __enter__(self) -> Self:
# sym mode first, then us...
m = self.sym_mode.enable(True)
self._managers.append(m)
m.__enter__()
# Stash and store the previous proxy mode (there may or may not be one)
maybe_prev_proxy_mode = _unset_infra_mode(torch._C._TorchDispatchModeKey.PROXY)
self.enter_stack.append(maybe_prev_proxy_mode)
@ -1064,8 +1063,6 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
exc_value: Optional[BaseException],
traceback: Optional[types.TracebackType]
) -> Optional[bool]:
m = self._managers.pop()
# ...exit us first, then sym mode
b = super().__exit__(exc_type, exc_value, traceback)
# Re-enable the previous proxy mode, if there was one.
@ -1073,11 +1070,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
if mb_previous_proxy_mode is not None:
_push_mode(mb_previous_proxy_mode)
if not b:
return m.__exit__(exc_type, exc_value, traceback)
else:
return m.__exit__(None, None, None)
return b
def inner_torch_dispatch(
self,
@ -1088,9 +1081,6 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
) -> object:
kwargs = kwargs or {}
if not self.enable_tracing:
return func(*args, **kwargs)
if func in (prim.device.default,):
return func(*args, **kwargs)
@ -1100,25 +1090,6 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
def is_infra_mode(cls) -> bool:
return True
class ProxySymDispatchMode(SymDispatchMode):
def __init__(self, tracer: _ProxyTracer) -> None:
super().__init__()
self.tracer = tracer
# When false, we don't trace operations. If you do this, you MUST
# call track_tensor/track_tensor_tree on all results of the operation
# to ensure we can adequately track the results
self.enable_tracing = True
@contextmanager
def enable(self, b: bool) -> Generator[None, None, None]:
old = self.enable_tracing
self.enable_tracing = b
try:
yield
finally:
self.enable_tracing = old
def _compute_proxy(self, func: OpOverload, args: Tuple[object, ...], out: PySymType) -> Proxy:
n_args = tuple(
get_proxy_slot(a, self.tracer).force().node if isinstance(a, py_sym_types) else a
@ -1139,9 +1110,6 @@ class ProxySymDispatchMode(SymDispatchMode):
args: Tuple[object, ...],
kwargs: Dict[str, object]
) -> object:
if not self.enable_tracing:
return func(*args, **kwargs)
# Peephole optimize multiply by one
# NB: be careful not to trigger guards here!
if func == operator.mul:
@ -1727,7 +1695,6 @@ class _MakefxTracer:
stack.enter_context(self.fake_tensor_mode)
stack.enter_context(self.python_dispatcher_mode)
stack.enter_context(self.proxy_function_mode)
stack.enter_context(proxy_mode.sym_mode)
stack.enter_context(self.torch_fn_metadata_mode)
stack.enter_context(proxy_mode)
stack.enter_context(disable_autocast_cache())
@ -1787,8 +1754,13 @@ def make_fx(
_allow_fake_constant: bool = False,
_error_on_data_dependent_ops: bool = True) -> Callable[..., GraphModule]:
assert tracing_mode in ["real", "fake", "symbolic"]
"""
Given a function f, return a new function which when executed with valid
arguments to f, returns an FX GraphModule representing the set of operations that
were executed during the course of execution.
"""
assert tracing_mode in ["real", "fake", "symbolic"]
make_fx_tracer = _MakefxTracer(
decomposition_table,
@ -1810,8 +1782,38 @@ def get_torch_dispatch_modes() -> List[TorchDispatchMode]:
return torch.utils._python_dispatch._get_current_dispatch_mode_stack()
def get_innermost_proxy_mode() -> ProxyTorchDispatchMode:
return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
# TODO: this is a legacy name, there is only ever one proxy mode as it's an
# infra mode
def get_innermost_proxy_mode() -> Optional[ProxyTorchDispatchMode]:
return get_proxy_mode()
def get_proxy_mode() -> Optional[ProxyTorchDispatchMode]:
"""
Current the currently active proxy tracing mode, or None if
we are not currently tracing. This includes pre-dispatch proxy
tracing.
"""
pre_dispatch_mode = torch._ops._get_dispatch_mode_pre_dispatch(torch._C._TorchDispatchModeKey.PROXY)
mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
assert pre_dispatch_mode is None or mode is None, f"pre_dispatch_mode={pre_dispatch_mode}, mode={mode}"
return pre_dispatch_mode or mode
def handle_sym_dispatch(func: Callable[_P, R], args: _P.args, kwargs: _P.kwargs) -> R:
"""
Call into the currently active proxy tracing mode to do a
SymInt/SymFloat/SymBool dispatch trace on a function that operates on
these arguments.
"""
mode = get_proxy_mode()
assert mode
# Have to do it manually, because we're not doing the normal torch
# dispatch machinery which disables it for us
with disable_proxy_modes_tracing():
# TODO: properly compute types
types: List[Type] = []
return mode.__sym_dispatch__(func, types, args, kwargs) # type: ignore[arg-type, return-value]
@contextmanager

View File

@ -31,10 +31,6 @@ from torch import ( # noqa: F401
SymFloat,
SymInt,
)
from torch.fx.experimental._sym_dispatch_mode import (
handle_sym_dispatch,
sym_function_mode,
)
if TYPE_CHECKING:
@ -1055,6 +1051,10 @@ def _make_node_magic(method, func):
method_attr = method
def binary_magic_impl(self, other):
from torch.fx.experimental.proxy_tensor import (
get_proxy_mode,
handle_sym_dispatch,
)
from torch.fx.experimental.symbolic_shapes import safe_expand
op = method_to_operator(method)
@ -1067,7 +1067,7 @@ def _make_node_magic(method, func):
if alternate_impl and out_hint is not None:
return to_node(self, alternate_impl(wrap_node(self), wrap_node(other)))
if sym_function_mode():
if get_proxy_mode():
return to_node(
self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
)
@ -1129,10 +1129,14 @@ def _make_node_magic(method, func):
return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)
def unary_magic_impl(self):
from torch.fx.experimental.proxy_tensor import (
get_proxy_mode,
handle_sym_dispatch,
)
from torch.fx.experimental.symbolic_shapes import safe_expand
op = method_to_operator(method)
if sym_function_mode():
if get_proxy_mode():
return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {}))
# TODO: consider constant prop here
expr = self.expr
@ -1167,10 +1171,14 @@ def _make_node_magic(method, func):
elif method == "sym_ite":
def sym_ite_impl(pred_node, then_node, else_node):
from torch.fx.experimental.proxy_tensor import (
get_proxy_mode,
handle_sym_dispatch,
)
from torch.fx.experimental.symbolic_shapes import safe_expand
out_hint = then_node.hint if pred_node.hint else else_node.hint
if sym_function_mode():
if get_proxy_mode():
return to_node(
pred_node,
handle_sym_dispatch(
@ -1208,10 +1216,14 @@ def _make_node_magic(method, func):
elif method == "round":
def round_impl(self, ndigits=None):
from torch.fx.experimental.proxy_tensor import (
get_proxy_mode,
handle_sym_dispatch,
)
from torch.fx.experimental.symbolic_shapes import safe_expand
op = builtins.round
if sym_function_mode():
if get_proxy_mode():
return to_node(
self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {})
)
@ -1256,8 +1268,13 @@ def _make_node_sizes_strides(method, func):
# NB: don't LRU cache, lots of arguments
def sizes_strides_impl(self, sizes, strides):
from torch.fx.experimental.proxy_tensor import (
get_proxy_mode,
handle_sym_dispatch,
)
op = getattr(sys.modules[__name__], method)
if sym_function_mode():
if get_proxy_mode():
return to_node(
self,
handle_sym_dispatch(