Consolidate SymDispatchMode into ProxyTensorMode (#132674)
Instead of having a separate context variable for SymDispatchMode, we now simply delegate to the currently active proxy tensor mode when we need to trace a SymInt. We maintain a separate `__sym_dispatch__` magic method because the calling convention is different from `__torch_dispatch__`. Consolidating the modes in this way means that we can consistently disable both of these modes in tandem simply by removing the mode from the proxy mode infra slot.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132674
Approved by: https://github.com/zou3519, https://github.com/bdhirsh
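As a quick illustration of the new calling convention, here is a minimal sketch built on the `get_proxy_mode`, `disable_proxy_modes_tracing` and `__sym_dispatch__` pieces touched in this diff. It is illustrative only: `dispatch_sym_op` and its parameter names are made up for the example, and the real `handle_sym_dispatch` added to proxy_tensor.py below is the authoritative version.

    from typing import Callable, Dict, List, Tuple, Type

    from torch.fx.experimental.proxy_tensor import (
        disable_proxy_modes_tracing,
        get_proxy_mode,
    )


    def dispatch_sym_op(func: Callable, args: Tuple[object, ...], kwargs: Dict[str, object]) -> object:
        # Before this change, a dedicated SYM_FUNCTION_MODE global decided whether a
        # SymInt operation should be traced; now we just ask for the active
        # ProxyTorchDispatchMode sitting in the proxy infra slot.
        mode = get_proxy_mode()
        if mode is None:
            # Not tracing: run the operation for real.
            return func(*args, **kwargs)
        # Sym dispatch does not go through the normal torch_dispatch machinery
        # (which would disable the mode for us), so disable it by hand before
        # calling back into the mode.
        with disable_proxy_modes_tracing():
            types: List[Type] = []  # the real code still has a TODO to compute these properly
            return mode.__sym_dispatch__(func, types, args, kwargs)

Because the proxy mode now lives in the ordinary dispatch-mode infra slot, dynamo's existing `_disable_current_modes()` already suppresses sym dispatch as well, which is why the separate `disable_sym_dispatch()` context manager can be dropped from convert_frame.py in the diff below.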
This commit is contained in:
parent 0f19d4150b
commit 361db32d47
.github/labeler.yml
@@ -29,7 +29,6 @@
 - torch/fx/experimental/recording.py
 - torch/fx/experimental/sym_node.py
 - torch/fx/experimental/validator.py
-- torch/fx/experimental/_sym_dispatch_mode.py
 - torch/fx/experimental/proxy_tensor.py
 - test/distributed/_tensor/test_dtensor_compile.py
 - test/distributed/tensor/parallel/test_fsdp_2d_parallel.py
@@ -850,7 +850,6 @@ coverage_ignore_functions = [
     "get_torch_dispatch_modes",
     "has_proxy_slot",
     "is_sym_node",
-    "make_fx",
     "maybe_disable_fake_tensor_mode",
     "maybe_handle_decomp",
     "proxy_call",
@@ -51,3 +51,17 @@ torch.fx.experimental.symbolic_shapes
     compute_unbacked_bindings
     rebind_unbacked
     resolve_unbacked_bindings
+
+torch.fx.experimental.proxy_tensor
+-------------------------------------
+
+.. currentmodule:: torch.fx.experimental.proxy_tensor
+.. automodule:: torch.fx.experimental.proxy_tensor
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    make_fx
+    handle_sym_dispatch
+    get_proxy_mode
@@ -1143,7 +1143,6 @@ API Reference
 .. py:module:: torch.fx.experimental.normalize
 .. py:module:: torch.fx.experimental.optimization
 .. py:module:: torch.fx.experimental.partitioner_utils
-.. py:module:: torch.fx.experimental.proxy_tensor
 .. py:module:: torch.fx.experimental.recording
 .. py:module:: torch.fx.experimental.refinement_types
 .. py:module:: torch.fx.experimental.rewriter
@@ -112,7 +112,6 @@ class AutogradCompilerInstance:
         # TODO(jansel): are all these modes needed?
         self.stack.enter_context(decompose({}))
         self.stack.enter_context(self.fake_tensor_mode)
-        self.stack.enter_context(self.proxy_mode.sym_mode)
         self.stack.enter_context(self.proxy_mode)
         self.stack.enter_context(disable_autocast_cache())
         self.stack.enter_context(preserve_node_meta())
@@ -25,7 +25,6 @@ from weakref import ReferenceType
 
 import torch
 import torch._logging
-import torch.fx.experimental._sym_dispatch_mode
 from torch._C._dynamo.guards import GlobalStateGuard
 from torch._dynamo.distributed import get_compile_pg
 from torch._guards import compile_context, CompileContext, CompileId, tracing
@@ -1234,9 +1233,7 @@ class CatchErrorsWrapper:
                     frame, cache_entry, self.hooks, frame_state
                 )
 
-        with (
-            compile_lock
-        ), _disable_current_modes(), torch.fx.experimental._sym_dispatch_mode.disable_sym_dispatch():
+        with compile_lock, _disable_current_modes():
             # skip=1: skip this frame
             return self._torchdynamo_orig_callable(
                 frame, cache_entry, self.hooks, frame_state, skip=1
@@ -1,72 +0,0 @@
-# mypy: allow-untyped-defs
-import contextlib
-from typing import List, Optional, Type
-
-
-__all__ = ["SymDispatchMode", "handle_sym_dispatch", "sym_function_mode"]
-
-SYM_FUNCTION_MODE: Optional["SymDispatchMode"] = None
-
-
-# SymDispatchMode gets invoked whenever an operation is processed on
-# a PySymInt. When this occurs, you get called at __sym_dispatch__
-# with the operation in question. This is symmetric to TorchDispatchMode
-# but with some caveats:
-#
-# - In TorchDispatchMode, you get the same arguments as what a user
-#   invoked your API with; e.g., if you call torch.ops.aten.foo(a, b),
-#   you get (a, b) as args to your call. In SymDispatchMode, if
-#   you call a + b (where a and b are SymInts), you will get
-#   (a.node, b.node) as your args (these are PySymInts)
-#
-# - SymInt/PySymInt don't have FX proxy support (unlike, e.g., Tensor).
-#   So you have to manually call Tracer/create_node to write into
-#   the graph. See ProxySymDispatchMode for an example
-#
-class SymDispatchMode:
-    def __sym_dispatch__(self, func, types, args, kwargs):
-        raise NotImplementedError
-
-    def __enter__(self):
-        global SYM_FUNCTION_MODE
-        old = SYM_FUNCTION_MODE
-        if hasattr(self, "inner"):
-            raise RuntimeError(
-                f"{self} has already been used as a mode. Please use a fresh version"
-            )
-        else:
-            self.inner = old
-        SYM_FUNCTION_MODE = self
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        global SYM_FUNCTION_MODE
-        SYM_FUNCTION_MODE = self.inner
-
-
-def handle_sym_dispatch(func, args, kwargs):
-    global SYM_FUNCTION_MODE
-    mode = sym_function_mode()
-    assert mode
-    SYM_FUNCTION_MODE = mode.inner
-    try:
-        # TODO: properly compute types
-        types: List[Type] = []
-        return mode.__sym_dispatch__(func, types, args, kwargs)
-    finally:
-        SYM_FUNCTION_MODE = mode
-
-
-def sym_function_mode():
-    return SYM_FUNCTION_MODE
-
-
-@contextlib.contextmanager
-def disable_sym_dispatch():
-    global SYM_FUNCTION_MODE
-    old = SYM_FUNCTION_MODE
-    SYM_FUNCTION_MODE = None
-    try:
-        yield
-    finally:
-        SYM_FUNCTION_MODE = old
@@ -22,13 +22,13 @@ import warnings
 import weakref
 
 from ._backward_state import BackwardState
-from ._sym_dispatch_mode import SymDispatchMode
 from .sym_node import SymNode
+from torch.utils._thunk import Thunk
 from collections import defaultdict
 from contextlib import contextmanager, nullcontext, AbstractContextManager, ExitStack
 from dataclasses import dataclass
 from torch import SymInt, SymBool, Tensor
 import torch._ops
 from torch._dispatch.python import enable_python_dispatcher
 from torch._library.fake_class_registry import FakeScriptObject
 from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, unset_fake_temporarily, is_fake
@@ -59,7 +59,10 @@ if TYPE_CHECKING:
     from torch.fx._symbolic_trace import PHBase
     from torch.types import IntLikeType
 
-__all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "py_sym_types", "get_innermost_proxy_mode"]
+__all__ = [
+    "PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter",
+    "py_sym_types", "get_innermost_proxy_mode", "get_proxy_mode", "handle_sym_dispatch"
+]
 
 _ProxyTracer = Union["PythonKeyTracer", "_GraphAppendingTracerEx"]
 
@@ -1006,7 +1009,10 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
 
 
 class ProxyTorchDispatchMode(TorchDispatchMode):
-    _managers: List[AbstractContextManager]
+    # Ensure this is read-only; this exists only for legacy reasons
+    @property
+    def enable_tracing(self) -> bool:
+        return True
 
     def __init__(
         self,
@@ -1020,12 +1026,9 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
         super().__init__(dk)
         self.tracer = tracer
         self.tracing_mode = tracing_mode
-        self.enable_tracing = True
         self.pre_dispatch = pre_dispatch
         self._allow_fake_constant = _allow_fake_constant
         self._error_on_data_dependent_ops = _error_on_data_dependent_ops
-        self.sym_mode = ProxySymDispatchMode(tracer)
-        self._managers = []
         # Indicates to our torch_dispatch dispatching infra that
         # this is an "infra" mode with lower dispatching precedence.
         self._mode_key = torch._C._TorchDispatchModeKey.PROXY
@@ -1045,14 +1048,10 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
         args: Tuple[object, ...] = (),
         kwargs: Optional[Dict[str, object]] = None
     ) -> object:
-        with self.sym_mode.enable(False), set_original_aten_op(func):
+        with set_original_aten_op(func):
             return self.inner_torch_dispatch(func, types, args, kwargs)
 
     def __enter__(self) -> Self:
-        # sym mode first, then us...
-        m = self.sym_mode.enable(True)
-        self._managers.append(m)
-        m.__enter__()
         # Stash and store the previous proxy mode (there may or may not be one)
         maybe_prev_proxy_mode = _unset_infra_mode(torch._C._TorchDispatchModeKey.PROXY)
         self.enter_stack.append(maybe_prev_proxy_mode)
@@ -1064,8 +1063,6 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
         exc_value: Optional[BaseException],
         traceback: Optional[types.TracebackType]
     ) -> Optional[bool]:
-        m = self._managers.pop()
-        # ...exit us first, then sym mode
         b = super().__exit__(exc_type, exc_value, traceback)
 
         # Re-enable the previous proxy mode, if there was one.
@@ -1073,11 +1070,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
         if mb_previous_proxy_mode is not None:
             _push_mode(mb_previous_proxy_mode)
 
-        if not b:
-            return m.__exit__(exc_type, exc_value, traceback)
-        else:
-            return m.__exit__(None, None, None)
-
+        return b
 
     def inner_torch_dispatch(
         self,
@@ -1088,9 +1081,6 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
     ) -> object:
         kwargs = kwargs or {}
 
-        if not self.enable_tracing:
-            return func(*args, **kwargs)
-
         if func in (prim.device.default,):
             return func(*args, **kwargs)
 
@@ -1100,25 +1090,6 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
     def is_infra_mode(cls) -> bool:
         return True
 
-
-class ProxySymDispatchMode(SymDispatchMode):
-    def __init__(self, tracer: _ProxyTracer) -> None:
-        super().__init__()
-        self.tracer = tracer
-        # When false, we don't trace operations. If you do this, you MUST
-        # call track_tensor/track_tensor_tree on all results of the operation
-        # to ensure we can adequately track the results
-        self.enable_tracing = True
-
-    @contextmanager
-    def enable(self, b: bool) -> Generator[None, None, None]:
-        old = self.enable_tracing
-        self.enable_tracing = b
-        try:
-            yield
-        finally:
-            self.enable_tracing = old
-
     def _compute_proxy(self, func: OpOverload, args: Tuple[object, ...], out: PySymType) -> Proxy:
         n_args = tuple(
             get_proxy_slot(a, self.tracer).force().node if isinstance(a, py_sym_types) else a
@@ -1139,9 +1110,6 @@ class ProxySymDispatchMode(SymDispatchMode):
         args: Tuple[object, ...],
         kwargs: Dict[str, object]
     ) -> object:
-        if not self.enable_tracing:
-            return func(*args, **kwargs)
-
         # Peephole optimize multiply by one
         # NB: be careful not to trigger guards here!
         if func == operator.mul:
@@ -1727,7 +1695,6 @@ class _MakefxTracer:
             stack.enter_context(self.fake_tensor_mode)
             stack.enter_context(self.python_dispatcher_mode)
             stack.enter_context(self.proxy_function_mode)
-            stack.enter_context(proxy_mode.sym_mode)
             stack.enter_context(self.torch_fn_metadata_mode)
             stack.enter_context(proxy_mode)
             stack.enter_context(disable_autocast_cache())
@@ -1787,8 +1754,13 @@ def make_fx(
         _allow_fake_constant: bool = False,
         _error_on_data_dependent_ops: bool = True) -> Callable[..., GraphModule]:
 
-    assert tracing_mode in ["real", "fake", "symbolic"]
+    """
+    Given a function f, return a new function which when executed with valid
+    arguments to f, returns an FX GraphModule representing the set of operations that
+    were executed during the course of execution.
+    """
 
+    assert tracing_mode in ["real", "fake", "symbolic"]
 
     make_fx_tracer = _MakefxTracer(
         decomposition_table,
@@ -1810,8 +1782,38 @@ def get_torch_dispatch_modes() -> List[TorchDispatchMode]:
     return torch.utils._python_dispatch._get_current_dispatch_mode_stack()
 
 
-def get_innermost_proxy_mode() -> ProxyTorchDispatchMode:
-    return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
+# TODO: this is a legacy name, there is only ever one proxy mode as it's an
+# infra mode
+def get_innermost_proxy_mode() -> Optional[ProxyTorchDispatchMode]:
+    return get_proxy_mode()
+
+
+def get_proxy_mode() -> Optional[ProxyTorchDispatchMode]:
+    """
+    Current the currently active proxy tracing mode, or None if
+    we are not currently tracing. This includes pre-dispatch proxy
+    tracing.
+    """
+    pre_dispatch_mode = torch._ops._get_dispatch_mode_pre_dispatch(torch._C._TorchDispatchModeKey.PROXY)
+    mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
+    assert pre_dispatch_mode is None or mode is None, f"pre_dispatch_mode={pre_dispatch_mode}, mode={mode}"
+    return pre_dispatch_mode or mode
+
+
+def handle_sym_dispatch(func: Callable[_P, R], args: _P.args, kwargs: _P.kwargs) -> R:
+    """
+    Call into the currently active proxy tracing mode to do a
+    SymInt/SymFloat/SymBool dispatch trace on a function that operates on
+    these arguments.
+    """
+    mode = get_proxy_mode()
+    assert mode
+    # Have to do it manually, because we're not doing the normal torch
+    # dispatch machinery which disables it for us
+    with disable_proxy_modes_tracing():
+        # TODO: properly compute types
+        types: List[Type] = []
+        return mode.__sym_dispatch__(func, types, args, kwargs)  # type: ignore[arg-type, return-value]
 
 
 @contextmanager
@@ -31,10 +31,6 @@ from torch import ( # noqa: F401
     SymFloat,
     SymInt,
 )
-from torch.fx.experimental._sym_dispatch_mode import (
-    handle_sym_dispatch,
-    sym_function_mode,
-)
 
 
 if TYPE_CHECKING:
@@ -1055,6 +1051,10 @@ def _make_node_magic(method, func):
     method_attr = method
 
     def binary_magic_impl(self, other):
+        from torch.fx.experimental.proxy_tensor import (
+            get_proxy_mode,
+            handle_sym_dispatch,
+        )
         from torch.fx.experimental.symbolic_shapes import safe_expand
 
         op = method_to_operator(method)
@@ -1067,7 +1067,7 @@ def _make_node_magic(method, func):
         if alternate_impl and out_hint is not None:
             return to_node(self, alternate_impl(wrap_node(self), wrap_node(other)))
 
-        if sym_function_mode():
+        if get_proxy_mode():
             return to_node(
                 self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
             )
@@ -1129,10 +1129,14 @@ def _make_node_magic(method, func):
         return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)
 
     def unary_magic_impl(self):
+        from torch.fx.experimental.proxy_tensor import (
+            get_proxy_mode,
+            handle_sym_dispatch,
+        )
         from torch.fx.experimental.symbolic_shapes import safe_expand
 
         op = method_to_operator(method)
-        if sym_function_mode():
+        if get_proxy_mode():
             return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {}))
         # TODO: consider constant prop here
         expr = self.expr
@@ -1167,10 +1171,14 @@ def _make_node_magic(method, func):
     elif method == "sym_ite":
 
         def sym_ite_impl(pred_node, then_node, else_node):
+            from torch.fx.experimental.proxy_tensor import (
+                get_proxy_mode,
+                handle_sym_dispatch,
+            )
             from torch.fx.experimental.symbolic_shapes import safe_expand
 
             out_hint = then_node.hint if pred_node.hint else else_node.hint
-            if sym_function_mode():
+            if get_proxy_mode():
                 return to_node(
                     pred_node,
                     handle_sym_dispatch(
@@ -1208,10 +1216,14 @@ def _make_node_magic(method, func):
     elif method == "round":
 
        def round_impl(self, ndigits=None):
+            from torch.fx.experimental.proxy_tensor import (
+                get_proxy_mode,
+                handle_sym_dispatch,
+            )
             from torch.fx.experimental.symbolic_shapes import safe_expand
 
             op = builtins.round
-            if sym_function_mode():
+            if get_proxy_mode():
                 return to_node(
                     self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {})
                 )
@@ -1256,8 +1268,13 @@ def _make_node_sizes_strides(method, func):
     # NB: don't LRU cache, lots of arguments
 
     def sizes_strides_impl(self, sizes, strides):
+        from torch.fx.experimental.proxy_tensor import (
+            get_proxy_mode,
+            handle_sym_dispatch,
+        )
+
         op = getattr(sys.modules[__name__], method)
-        if sym_function_mode():
+        if get_proxy_mode():
             return to_node(
                 self,
                 handle_sym_dispatch(