mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
The bug was that if you want to move a mode to the autograd key, we need to use the "functionality" key for it (AutogradFunctionality). But when we do that, we need to clear any PythonDispatcher caches for every op for **every** autograd key (since you could run autograd ops with both CPU and CUDA tensors underneath the mode, and both may have been cached). I didn't add a test, since this ends up getting tested indirectly by export in the PR. If someone would prefer a direct test, I can add one. Pull Request resolved: https://github.com/pytorch/pytorch/pull/98030 Approved by: https://github.com/ezyang
131 lines
5.3 KiB
Python
131 lines
5.3 KiB
Python
import contextlib
|
|
from typing import Optional
|
|
|
|
import warnings
|
|
import torch
|
|
from torch._C import _len_torch_dispatch_stack, _get_dispatch_stack_at,\
|
|
_pop_torch_dispatch_stack, _push_on_torch_dispatch_stack, DispatchKey
|
|
|
|
|
|
# TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
|
|
# - We need a better user-facing api for _DisableTorchDispatch that
|
|
# is able to selectively disable __torch_dispatch__ of a particular class.
|
|
# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor)
|
|
# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)
|
|
|
|
class TorchDispatchMode:
    """
    A ``TorchDispatchMode`` allows you to override the meaning of all
    ``__torch_dispatch__`` overrideable functions within a dynamic scope,
    without having to actually create a tensor subclass or manually
    monkey-patch functions in the PyTorch API. Some common situations
    where you should use a mode:

        * You want to override the meaning of factory functions, or other
          functions that do not otherwise take a tensor as an argument
          (these cannot be overridden with tensor subclasses).

        * You want to override the behavior of all functions without needing
          to wrap your inputs in tensor subclasses; e.g., if you are just
          interested in logging intermediate computations.

        * You want to control the order of execution of various tensor
          subclasses explicitly, rather than implicitly via the return of
          ``NotImplemented``.

    Independent subclasses of :class:`TorchDispatchMode` are compositional:
    modes can be pushed onto a stack using ``with MyMode():``.
    When you call functions in the PyTorch API inside your
    ``__torch_dispatch__`` implementation, by default, they will forward on to
    the next mode on the mode stack. If you want to recursively call back into
    your current ``__torch_dispatch__`` implementation, either explicitly
    invoke ``self.__torch_dispatch__(...)``, or re-enter the mode as a context
    manager (``with self:``) to make the PyTorch API self-referential (beware
    of infinite loops, in this case!)

    Args:
        _dispatch_key (torch._C.DispatchKey, optional): private. When given,
            entering the mode also pushes it onto the mode stack for this
            dispatch key (see ``_push_mode``), in addition to the Python
            dispatch mode stack; exiting pops it from both.
    """

    def __init__(self, _dispatch_key=None):
        if _dispatch_key is not None:
            assert isinstance(_dispatch_key, torch._C.DispatchKey)
            # Stored (and later read by __enter__/__exit__) through __dict__
            # directly. NOTE(review): presumably so subclasses that override
            # __setattr__/__getattr__ cannot interfere — confirm.
            self.__dict__['_dispatch_key'] = _dispatch_key

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Handle one intercepted operator call; subclasses must override."""
        raise NotImplementedError()

    def __enter__(self):
        # Activate this mode; also onto a per-key stack if a key was given.
        _push_mode(self, self.__dict__.get("_dispatch_key", None))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Deactivate with the same key that __enter__ used.
        _pop_mode(self.__dict__.get("_dispatch_key", None))

    @classmethod
    def push(cls, *args, **kwargs):
        """Deprecated: construct and return a mode instance.

        The instance is NOT activated by this call; use ``with Mode(...):``
        instead. Emits a ``UserWarning`` attributed to the caller.
        """
        # stacklevel=2 points the warning at the code that called push(),
        # which is the line the user actually needs to change.
        warnings.warn(
            "`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`",
            stacklevel=2,
        )
        instance = cls(*args, **kwargs)
        return instance
|
|
|
|
def _get_current_dispatch_mode():
    """Return the innermost (most recently pushed) Python dispatch mode.

    Returns ``None`` when the Python ``__torch_dispatch__`` mode stack is empty.
    """
    depth = _len_torch_dispatch_stack()
    if depth <= 0:
        return None
    # The top of the stack lives at the highest index.
    return _get_dispatch_stack_at(depth - 1)
|
|
|
|
|
|
def _get_current_dispatch_mode_stack():
    """Return all active Python dispatch modes as a new list.

    Index 0 is the bottom of the stack (pushed first); the last element is the
    innermost mode.
    """
    return list(map(_get_dispatch_stack_at, range(_len_torch_dispatch_stack())))
|
|
|
|
def _push_mode(mode, k: Optional[DispatchKey] = None):
    """Activate ``mode`` on the Python dispatch mode stack.

    Args:
        mode: the mode object to push.
        k: optional functionality ``DispatchKey``. When given, ``mode`` is
            additionally pushed onto that key's mode stack via
            ``torch._ops.push_mode_for_key``. Before that, the dispatch cache
            of every cached op is invalidated for every backend key that
            ``torch._C._functionality_to_backend_keys(k)`` reports.
    """
    if k is not None:
        from torch._ops import get_cached_ops, push_mode_for_key

        # See Note [Not Caching Per-Dispatch-Key Mode Handlers]
        # One functionality key (e.g. AutogradFunctionality) covers several
        # backend keys (e.g. autograd for CPU and for CUDA), and an op may have
        # cached a handler under any of them. Drop every such entry, for every
        # op seen so far, so that the new mode is consulted.
        backend_keys = torch._C._functionality_to_backend_keys(k)
        stale_entries = (
            (cached_op, backend_key)
            for cached_op in get_cached_ops()
            for backend_key in backend_keys
        )
        for cached_op, backend_key in stale_entries:
            cached_op._uncache_dispatch(backend_key)
        push_mode_for_key(k, mode)

    # Note [Per-Dispatch-Key Modes Must Be Reentrant]
    # The idea here is that we are allowed to push modes onto any dispatch key's mode stack, but:
    # (1) We **always** push the mode onto the python mode stack. Operators can have fallthrough
    #     kernels registered to any dispatch key, so we use the Python mode stack as a catchall,
    #     to guarantee that every op will be seen by our mode.
    # (2) We expect the mode that you push to handle being re-entrant: If we end up invoking the mode
    #     at both the Autograd key and the Python key, nothing bad should happen.
    #     The main use case for this is pre-autograd tracing with TorchProxyDispatchMode.
    _push_on_torch_dispatch_stack(mode)
|
|
|
|
|
|
def _pop_mode(k: Optional[DispatchKey] = None):
    """Deactivate and return the innermost mode on the Python dispatch stack.

    Args:
        k: optional ``DispatchKey``. When given, the top of that key's mode
            stack is popped too (``torch._ops.pop_mode_for_key``) and must be
            the very same object, since ``_push_mode`` pushes onto both.

    Returns:
        The mode removed from the Python dispatch mode stack.
    """
    popped = _pop_torch_dispatch_stack()
    if k is None:
        return popped
    from torch._ops import pop_mode_for_key
    # Both stacks were pushed together, so their tops must agree.
    key_popped = pop_mode_for_key(k)
    assert popped is key_popped
    return popped
|
|
|
|
|
|
@contextlib.contextmanager
def _pop_mode_temporarily(k: Optional[DispatchKey] = None):
    """Context manager that suspends the innermost mode for its body.

    The mode is popped with ``_pop_mode(k)`` and yielded. On exit, including
    when the body raises, it is pushed back with ``_push_mode(..., k)``.
    """
    suspended = _pop_mode(k)
    try:
        yield suspended
    finally:
        # Restore with the same key so per-key stacks stay in sync.
        _push_mode(suspended, k)
|
|
|
|
|
|
@contextlib.contextmanager
def _disable_current_modes():
    """Context manager that empties the Python dispatch mode stack for its body.

    Yields the removed modes innermost-first, which is the order they were
    popped. On exit they are pushed back outermost-first, which restores the
    original stack order.

    Modes are popped and re-pushed without a dispatch key, so only the Python
    dispatch mode stack is modified. NOTE(review): a mode that was pushed with
    a dispatch key stays on that key's stack while "disabled"; confirm this is
    intended.
    """
    saved = []
    for _ in range(_len_torch_dispatch_stack()):
        saved.append(_pop_mode())
    try:
        yield saved
    finally:
        for restored in saved[::-1]:
            _push_mode(restored)
|
|
|
|
|
|
class BaseTorchDispatchMode(TorchDispatchMode):
    """A no-op mode: every intercepted call is forwarded to ``func`` unchanged."""

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # A missing kwargs mapping means "call with no keyword arguments".
        return func(*args, **({} if kwargs is None else kwargs))
|