Based on @ezyang's suggestion, the mode stack now has "one true mode", which is the _only_ mode that can ever be active at the C++ level. That mode's torch dispatch simply takes the top mode from the stack, re-enables itself (if we aren't at the end of the mode stack), and runs that top mode's torch_{dispatch|function}.

This maintains the invariant that, in the middle of a mode's torch dispatch, the mode itself is not active. It changes the function the user has to call to see the current mode (it no longer queries the C++ side; it's Python only), but it also lets the user easily see the entire mode stack (see the sketch below).

Removes `enable_torch_dispatch_mode` and `.restore()`, since neither makes sense in this new setup.
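
For example, inspecting mode state is now a pure-Python stack query rather than a C++ call. A minimal sketch using the private helpers added in this PR (this assumes the file below is importable as `torch.utils._python_dispatch`; `LoggingMode` is a hypothetical subclass):

```python
from torch.utils._python_dispatch import (
    TorchDispatchMode,
    _get_current_dispatch_mode,
    _get_current_dispatch_mode_stack,
)

class LoggingMode(TorchDispatchMode):  # hypothetical subclass, for illustration
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))

with LoggingMode() as outer:
    with LoggingMode() as inner:
        assert _get_current_dispatch_mode() is inner          # top of the stack
        assert _get_current_dispatch_mode_stack() == [outer, inner]
```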
### Background
Why do we want this? A pretty common pattern that kept coming up was that users had to write something like this:
```python
## PRE-PR UX
def f(mode):
    with mode.restore():  # user needs to understand this restore thing?
        ...

with Mode() as m:
    pass
f(m)
```
Many users were getting errors from forgetting to call `.restore`, or from forgetting the (tbh weird) "mode instantiation" step, where they use the mode as a context manager with an empty body. Really, they wanted to treat modes like ordinary context managers and just write:
```python
## FROM FEEDBACK, USER DESIRED CODE. POSSIBLE POST-PR
def f(mode):
    with mode:
        ...

f(Mode())
```
### Technical Details
With the old mode stack, we basically had a linked list: a mode could only be used once and had a fixed parent. In the new design, the mode stack is just a Python list that we push to and pop from. Only one mode is ever active at the C++ level, and it runs the next mode in the Python list. The modes themselves no longer carry state.
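
To illustrate the new semantics (a sketch; `Logger` is a hypothetical subclass, and the import path assumes the file below ships as `torch.utils._python_dispatch`): each PyTorch call made inside a mode's `__torch_dispatch__` forwards to the next mode down the stack, because the running mode has already been popped:

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class Logger(TorchDispatchMode):  # hypothetical subclass, for illustration
    def __init__(self, name):
        self.name = name

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(f"{self.name}: {func}")
        # re-dispatches to the next mode down the stack, or to the real
        # kernel once the stack is exhausted
        return func(*args, **(kwargs or {}))

with Logger("outer"):
    with Logger("inner"):
        torch.ones(2)  # prints "inner: ..." then "outer: ..."
```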
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84774
Approved by: https://github.com/ezyang, https://github.com/zou3519
108 lines · 4.3 KiB · Python
import contextlib
import warnings

from torch._C import _len_torch_dispatch_stack, _get_dispatch_stack_at,\
    _pop_torch_dispatch_stack, _push_on_torch_dispatch_stack, _set_torch_dispatch_mode

# TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
# - We need a better user-facing api for _DisableTorchDispatch that
#   is able to selectively disable __torch_dispatch__ of a particular class.
# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor)
# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)

class TorchDispatchMode:
    """
    A ``TorchDispatchMode`` allows you to override the meaning of all
    ``__torch_dispatch__`` overrideable functions within a dynamic scope,
    without having to actually create a tensor subclass or manually
    monkey-patch functions in the PyTorch API. Some common situations
    where you should use a mode:

    * You want to override the meaning of factory functions, or other
      functions that do not otherwise take a tensor as an argument
      (these cannot be overridden with tensor subclasses).

    * You want to override the behavior of all functions without needing
      to wrap your inputs in tensor subclasses; e.g., if you are just
      interested in logging intermediate computations.

    * You want to control the order of execution of various tensor
      subclasses explicitly, rather than implicitly via the return of
      ``NotImplemented``.

    Independent subclasses of :class:`TorchDispatchMode` are compositional:
    modes can be pushed onto a stack using ``with MyMode():``.
    When you call functions in the PyTorch API inside your
    ``__torch_dispatch__`` implementation, by default they will forward on to
    the next mode on the mode stack. If you want to recursively call back into
    your current ``__torch_dispatch__`` implementation, either explicitly
    invoke ``self.__torch_dispatch__(...)``, or use the context manager
    ``__torch_dispatch__(self)`` to make the PyTorch API self-referential
    (beware of infinite loops, in this case!)
    """
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        raise NotImplementedError()

    def __enter__(self):
        _push_mode(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        _pop_mode()

    @classmethod
    def push(cls, *args, **kwargs):
        warnings.warn("`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`")
        instance = cls(*args, **kwargs)
        return instance
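    # NOTE: `push()` survives only for backward compatibility; constructing the
    # mode and entering it with ``with`` (i.e. __enter__) is what actually
    # pushes it onto the stack.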

def _get_current_dispatch_mode():
    stack_len = _len_torch_dispatch_stack()
    return _get_dispatch_stack_at(stack_len - 1) if stack_len > 0 else None


def _get_current_dispatch_mode_stack():
    stack_len = _len_torch_dispatch_stack()
    return [_get_dispatch_stack_at(i) for i in range(stack_len)]

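# NOTE (per the PR description): _push_mode/_pop_mode keep exactly one mode,
# _TorchDispatchStackMode (defined below), installed at the C++ level whenever
# the user-facing mode stack is non-empty.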
def _push_mode(mode):
    # the first push also installs the single C++-level mode (see NOTE above)
    if _len_torch_dispatch_stack() == 0:
        _set_torch_dispatch_mode(_TorchDispatchStackMode())
    _push_on_torch_dispatch_stack(mode)

def _pop_mode():
    old = _pop_torch_dispatch_stack()
    # once the stack is empty, uninstall the C++-level mode as well
    if _len_torch_dispatch_stack() == 0:
        _set_torch_dispatch_mode(None)
    return old

@contextlib.contextmanager
def _pop_mode_temporarily():
    old = _pop_mode()
    try:
        yield old
    finally:
        _push_mode(old)

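# NOTE (per the PR description): because the stack mode pops the current mode
# before running it, a mode is never active during its own __torch_dispatch__.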
# a helper "mode" used by the torch dispatch push helper method. This is the only mode that will ever
# be active at the C++ level and it will run the current mode
class _TorchDispatchStackMode:
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        with _pop_mode_temporarily() as old:
            # re-enable ourselves unless we've reached the bottom of the stack
            if _len_torch_dispatch_stack() > 0:
                _set_torch_dispatch_mode(self)
            # we can't check the type of __torch_dispatch__ here but this is sufficient for checking it's a classmethod
            if old.__torch_dispatch__.__self__ is type(old):
                raise RuntimeError(f"{type(old)}'s torch_dispatch function " +
                                   "should be a normal method not a class method")
            return old.__torch_dispatch__(func, types, args, kwargs)

class BaseTorchDispatchMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return func(*args, **kwargs)
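
# Usage sketch (CountingMode is hypothetical): BaseTorchDispatchMode already
# performs the pass-through call, so a subclass only needs to add its side
# effects before deferring to super():
#
#     class CountingMode(BaseTorchDispatchMode):
#         def __init__(self):
#             self.count = 0
#
#         def __torch_dispatch__(self, func, types, args=(), kwargs=None):
#             self.count += 1
#             return super().__torch_dispatch__(func, types, args, kwargs)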