import contextlib
from typing import Optional
import warnings

import torch
from torch._C import (
    _len_torch_dispatch_stack,
    _get_dispatch_stack_at,
    _pop_torch_dispatch_stack,
    _push_on_torch_dispatch_stack,
    DispatchKey,
)


# TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
# - We need a better user-facing api for _DisableTorchDispatch that
#   is able to selectively disable __torch_dispatch__ of a particular class.
# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor)
# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)

class TorchDispatchMode:
    """
    A ``TorchDispatchMode`` allows you to override the meaning of all
    ``__torch_dispatch__`` overrideable functions within a dynamic scope,
    without having to actually create a tensor subclass or manually
    monkey-patch functions in the PyTorch API.  Some common situations
    where you should use a mode:

        * You want to override the meaning of factory functions, or other
          functions that do not otherwise take a tensor as an argument
          (these cannot be overridden with tensor subclasses).

        * You want to override the behavior of all functions without needing
          to wrap your inputs in tensor subclasses; e.g., if you are just
          interested in logging intermediate computations.

        * You want to control the order of execution of various tensor
          subclasses explicitly, rather than implicitly via the return of
          ``NotImplemented``.

    Independent subclasses of :class:`TorchDispatchMode` are compositional:
    modes can be pushed onto a stack using ``with MyMode():``.
    When you call functions in the PyTorch API inside your
    ``__torch_dispatch__`` implementation, by default, they will forward on to
    the next mode on the mode stack.  If you want to recursively call back
    into your current ``__torch_dispatch__`` implementation, either explicitly
    invoke ``self.__torch_dispatch__(...)``, or use the context manager
    ``__torch_dispatch__(self)`` to make PyTorch API self-referential (beware
    of infinite loops, in this case!)
    """

    def __init__(self, _dispatch_key=None):
        """
        Create a mode, optionally bound to a specific dispatch key.

        Args:
            _dispatch_key: ``None`` (the default) or a
                ``torch._C.DispatchKey``.  When given, it is remembered on the
                instance and passed to ``_push_mode`` / ``_pop_mode`` when the
                mode is entered / exited.

        Raises:
            AssertionError: if ``_dispatch_key`` is neither ``None`` nor a
                ``torch._C.DispatchKey``.
        """
        if _dispatch_key is not None:
            assert isinstance(_dispatch_key, torch._C.DispatchKey), (
                f"_dispatch_key must be a torch._C.DispatchKey, got {type(_dispatch_key)!r}"
            )
            # Stored through ``__dict__`` directly rather than via normal
            # attribute assignment; ``__enter__`` / ``__exit__`` read it back
            # the same way.  NOTE(review): presumably this sidesteps custom
            # ``__setattr__`` / ``__getattr__`` in subclasses -- confirm.
            self.__dict__['_dispatch_key'] = _dispatch_key

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Handle an intercepted call; subclasses must override this."""
        raise NotImplementedError()

    def __enter__(self):
        """Push this mode (with its dispatch key, if any) and return ``self``."""
        _push_mode(self, self.__dict__.get("_dispatch_key", None))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Pop the current mode; returns ``None`` so exceptions propagate."""
        _pop_mode(self.__dict__.get("_dispatch_key", None))

    @classmethod
    def push(cls, *args, **kwargs):
        """
        Legacy constructor: warn, then return ``cls(*args, **kwargs)``.

        Despite the name, nothing is pushed here; the returned instance has to
        be entered with ``with`` like any other mode.
        """
        # stacklevel=2 attributes the warning to the caller of ``push()``,
        # which is the code that actually needs to be updated.
        warnings.warn(
            "`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`",
            stacklevel=2,
        )
        instance = cls(*args, **kwargs)
        return instance


def _get_current_dispatch_mode():
    """Return the mode at the last index of the dispatch stack, or ``None`` if it is empty."""
    stack_len = _len_torch_dispatch_stack()
    return _get_dispatch_stack_at(stack_len - 1) if stack_len > 0 else None


def _get_current_dispatch_mode_stack():
    """Return every mode on the dispatch stack as a list, ordered by stack index."""
    stack_len = _len_torch_dispatch_stack()
    return [_get_dispatch_stack_at(i) for i in range(stack_len)]


def _push_mode(mode, k: Optional[DispatchKey] = None):
    """
    Push ``mode`` onto the torch dispatch stack.

    Args:
        mode: the mode object to push.
        k: optional dispatch key.  When given, the dispatch cache for ``k`` is
            cleared on every op returned by ``get_cached_ops()`` and ``mode``
            is additionally pushed onto that key's own mode stack.
    """
    if k is not None:
        from torch._ops import push_mode_for_key, get_cached_ops
        # See Note [Not Caching Per-Dispatch-Key Mode Handlers]
        # Clear the cache of every op that has been used so far, for this particular key.
        for op in get_cached_ops():
            op._uncache_dispatch(k)
        push_mode_for_key(k, mode)
    # Note [Per-Dispatch-Key Modes Must Be Reentrant]
    # The idea here is that we are allowed to push modes onto any dispatch key's mode stack, but:
    # (1) We **always** push the mode onto the python mode stack. Operators can have fallthrough
    #     kernels registered to any dispatch key, so we use the Python mode stack as a catchall,
    #     to guarantee that every op will be seen by our mode.
    # (2) We expect the mode that you push to handle being re-entrant: If we end up invoking the mode
    #     at both the Autograd key and the Python key, nothing bad should happen.
    #     The main use case for this is pre-autograd tracing with TorchProxyDispatchMode.
    _push_on_torch_dispatch_stack(mode)


def _pop_mode(k: Optional[DispatchKey] = None):
    """
    Pop and return the mode popped from the torch dispatch stack.

    Args:
        k: optional dispatch key.  When given, a mode is also popped from that
            key's mode stack and must be the very same object.

    Raises:
        AssertionError: if the two popped modes are not the same object.
    """
    m = _pop_torch_dispatch_stack()
    if k is not None:
        from torch._ops import pop_mode_for_key
        tmp = pop_mode_for_key(k)
        assert m is tmp, "per-key mode stack is out of sync with the torch dispatch stack"
    return m


@contextlib.contextmanager
def _pop_mode_temporarily(k: Optional[DispatchKey] = None):
    """
    Context manager that pops the current mode for the duration of the block.

    The popped mode is yielded and is pushed back (with the same ``k``) on
    exit, even if the block raises.
    """
    old = _pop_mode(k)
    try:
        yield old
    finally:
        _push_mode(old, k)


@contextlib.contextmanager
def _disable_current_modes():
    """
    Context manager that empties the torch dispatch stack for the block.

    Yields the list of popped modes in pop order (last stack index first) and
    pushes them back in their original order on exit, even if the block raises.
    """
    # NOTE(review): modes are popped and re-pushed without a dispatch key, so
    # per-key mode stacks are not touched here -- confirm this is intended.
    mode_len = _len_torch_dispatch_stack()
    old_modes = [_pop_mode() for _ in range(mode_len)]

    try:
        yield old_modes
    finally:
        # Popping reversed the order; reverse again to rebuild the same stack.
        for mode in reversed(old_modes):
            _push_mode(mode)


class BaseTorchDispatchMode(TorchDispatchMode):
    """A pass-through mode: every intercepted call is simply executed as-is."""

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Call ``func(*args, **kwargs)`` and return its result."""
        if kwargs is None:
            kwargs = {}
        return func(*args, **kwargs)