pytorch/torch/utils/_python_dispatch.py
Brian Hirsh 5efd63b1b8 better support for fakeifying and dynamoing through torch_dispatch subclasses (with dynamic shapes) (#107415)
There is already some support for plumbing `__torch_dispatch__` tensor subclasses through dynamo, but this PR beefs it up a bit and adds a test. In particular:

(1) Fakeifying tensor subclasses didn't properly set autograd metadata (requires_grad, is_leaf) on the newly fakeified wrapper subclass. I don't actually have a test for this in this PR, but it's tested pretty heavily later in my aot autograd tests

(2) Fakeifying tensor subclasses didn't properly track source information for dynamic shapes on the inner tensors. I added a new `WrapperSubclassFieldSource` subclass that represents a source coming from a tensor field on a wrapper subclass; I use it in the fakeifying logic, and again in symbolic_shapes.py to generate proper guards.

(3) `_make_wrapper_subclass()`: I marginally updated this code to work better with dynamic shapes. One thing that's a bit weird about `_make_wrapper_subclass`: it has two overloads, and the first explicitly does not support dynamic shapes (and the second... does not support kwargs). I think that later we probably want to consolidate them, or at least make the first overload work with dynamic shapes, but I didn't want to handle that in this PR (so these smaller changes seemed like a strict improvement).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107415
Approved by: https://github.com/ezyang
2023-08-29 02:36:48 +00:00

156 lines
6.2 KiB
Python

import contextlib
from typing import Optional
import warnings
import torch
from torch._C import _len_torch_dispatch_stack, _get_dispatch_stack_at,\
_pop_torch_dispatch_stack, _push_on_torch_dispatch_stack, DispatchKey
# TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
# - We need a better user-facing api for _DisableTorchDispatch that
# is able to selectively disable __torch_dispatch__ of a particular class.
# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor)
# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)
class TorchDispatchMode:
    """
    Scoped override for every ``__torch_dispatch__`` overrideable function.

    A ``TorchDispatchMode`` changes the meaning of those functions within a
    dynamic scope, so you neither have to create a tensor subclass nor
    monkey-patch functions in the PyTorch API by hand. Typical reasons to
    reach for a mode:

        * You need to override factory functions, or other functions that
          do not take a tensor argument (tensor subclasses cannot override
          these).

        * You want to change the behavior of all functions without wrapping
          your inputs in tensor subclasses; e.g., you only want to log
          intermediate computations.

        * You want explicit control over the order in which various tensor
          subclasses execute, instead of relying on returning
          ``NotImplemented``.

    Independent subclasses of :class:`TorchDispatchMode` are compositional:
    modes are pushed onto a stack with ``with MyMode():``.
    By default, PyTorch API calls made from inside your
    ``__torch_dispatch__`` implementation forward on to the next mode on the
    mode stack. If you want to recursively call back into your current
    ``__torch_dispatch__`` implementation, either explicitly invoke
    ``self.__torch_dispatch__(...)``, or use the context manager
    ``__torch_dispatch__(self)`` to make the PyTorch API self-referential
    (beware of infinite loops, in this case!)
    """

    def __init__(self, _dispatch_key=None):
        # The key is only stored when one is supplied; __enter__/__exit__
        # read it back with ``.get`` so its absence is fine.
        if _dispatch_key is None:
            return
        assert isinstance(_dispatch_key, torch._C.DispatchKey)
        self.__dict__['_dispatch_key'] = _dispatch_key

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Subclasses are expected to supply the actual handler.
        raise NotImplementedError()

    def __enter__(self):
        dispatch_key = self.__dict__.get("_dispatch_key", None)
        _push_mode(self, dispatch_key)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        dispatch_key = self.__dict__.get("_dispatch_key", None)
        _pop_mode(dispatch_key)

    @classmethod
    def push(cls, *args, **kwargs):
        """Warn that ``push`` is unnecessary, then return ``cls(*args, **kwargs)``."""
        warnings.warn("`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`")
        return cls(*args, **kwargs)
def _get_current_dispatch_mode():
    """
    Return the entry at the last index of the torch dispatch mode stack,
    or ``None`` when the stack is empty.
    """
    depth = _len_torch_dispatch_stack()
    if depth <= 0:
        return None
    return _get_dispatch_stack_at(depth - 1)
def _get_current_dispatch_mode_stack():
    """
    Return a list with every entry of the torch dispatch mode stack,
    ordered by stack index (index 0 first).
    """
    depth = _len_torch_dispatch_stack()
    return list(map(_get_dispatch_stack_at, range(depth)))
def _push_mode(mode, k: Optional[DispatchKey] = None):
    """
    Push ``mode``, either onto the torch dispatch mode stack (``k is None``)
    or as the mode for dispatch key ``k``.

    For the keyed case, every op returned by ``get_cached_ops()`` first has
    its dispatch uncached for each backend key that ``k`` maps to.
    """
    if k is None:
        _push_on_torch_dispatch_stack(mode)
        return

    from torch._ops import push_mode_for_key, get_cached_ops
    # See Note [Not Caching Per-Dispatch-Key Mode Handlers]
    # Clear the cache of every op that has been used so far, for this particular key.
    backend_keys = torch._C._functionality_to_backend_keys(k)
    for cached_op in get_cached_ops():
        for backend_key in backend_keys:
            cached_op._uncache_dispatch(backend_key)
    push_mode_for_key(k, mode)
def _pop_mode(k: Optional[DispatchKey] = None):
    """
    Pop and return a mode: from the torch dispatch mode stack when ``k`` is
    ``None``, otherwise via ``pop_mode_for_key(k)``.
    """
    if k is None:
        return _pop_torch_dispatch_stack()

    from torch._ops import pop_mode_for_key
    return pop_mode_for_key(k)
@contextlib.contextmanager
def _pop_mode_temporarily(k: Optional[DispatchKey] = None):
    """
    Context manager that pops a mode via ``_pop_mode(k)``, yields it, and
    pushes it back with ``_push_mode`` on exit, including when the body raises.
    """
    popped_mode = _pop_mode(k)
    try:
        yield popped_mode
    finally:
        # Restore under the same key the mode was popped with.
        _push_mode(popped_mode, k)
@contextlib.contextmanager
def _disable_current_modes():
    """
    Context manager that pops every mode off the torch dispatch mode stack,
    yields the popped modes as a list (in pop order), and pushes them back
    in reverse pop order on exit.
    """
    popped_modes = []
    for _ in range(_len_torch_dispatch_stack()):
        popped_modes.append(_pop_mode())
    try:
        yield popped_modes
    finally:
        # The last mode popped goes back first, recreating the original order.
        for restored in popped_modes[::-1]:
            _push_mode(restored)
class BaseTorchDispatchMode(TorchDispatchMode):
    """A mode whose handler just calls ``func`` with the arguments it was given."""

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # A missing kwargs mapping is treated as "no keyword arguments".
        return func(*args, **({} if kwargs is None else kwargs))
def is_traceable_wrapper_subclass(t):
    """
    Report whether ``t`` is a tensor subclass implementing __torch_dispatch__
    that is 'traceable' with torch.compile.

    To support TorchDispatchMode-style tracing in PT2, a tensor subclass
    must implement two magic methods: __tensor_flatten__ and __tensor_unflatten__.
    It is also expected to obey some restrictions around traceability and aliasing
    (TODO: add clear documentation around this.)

    The check performed here: ``t`` is a ``torch.Tensor`` whose exact type is
    not ``torch.Tensor``, and it has both magic attributes.
    """
    if not isinstance(t, torch.Tensor) or type(t) == torch.Tensor:
        return False
    required_methods = ("__tensor_flatten__", "__tensor_unflatten__")
    return all(hasattr(t, method) for method in required_methods)
def transform_subclass(t, callback):
    """
    Build a fresh instance of ``t``'s wrapper tensor subclass from
    transformed inner tensors.

    ``t`` is a traceable, wrapper tensor subclass that implements
    ``__torch_dispatch__`` and holds some inner tensors; ``callback`` has
    type ``Callable[[str, torch.Tensor], torch.Tensor]``.

    Each attribute name reported by ``t.__tensor_flatten__()`` is read off
    ``t`` and passed, together with its name, to ``callback``. The results
    are collected in a dict keyed by attribute name and handed, along with
    the flatten context, to ``type(t).__tensor_unflatten__``.

    Note: this function will not handle ensuring that the fresh subclass
    gets the same (autograd, and aliasing) metadata as the original tensor.
    This is generally handled in other subsystems like AOTAutograd.
    """
    inner_names, flatten_ctx = t.__tensor_flatten__()
    rebuilt_inner = {
        name: callback(name, getattr(t, name))
        for name in inner_names
    }
    return type(t).__tensor_unflatten__(rebuilt_inner, flatten_ctx)