There is already some support for plumbing `__torch_dispatch__` tensor subclasses through dynamo, but this PR beefs it up a bit and adds a test. In particular:

(1) Fakeifying tensor subclasses didn't properly set autograd metadata (requires_grad, is_leaf) on the newly fakeified wrapper subclass. I don't actually have a test for this in this PR, but it's tested pretty heavily later in my aot autograd tests.

(2) Fakeifying tensor subclasses didn't properly track source information for dynamic shapes on the inner tensors. I added a new `WrapperSubclassFieldSource` subclass, which represents a source coming from a tensor field on a wrapper subclass; I use it in the fakeifying logic, and again in symbolic_shapes.py to generate proper guards.

(3) `_make_wrapper_subclass()`: marginally updated this code to work better with dynamic shapes. One thing that's a bit weird about `_make_wrapper_subclass` is that it has two overloads: the first explicitly does not support dynamic shapes, and the second does not support kwargs. I think that later we probably want to consolidate them, or at least make the first overload work with dynamic shapes, but I didn't want to handle that in this PR (so these smaller changes seemed like a strict improvement).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107415
Approved by: https://github.com/ezyang
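For context, the "traceable" wrapper subclasses this PR plumbs through dynamo are ones that hold inner tensors and expose `__tensor_flatten__` / `__tensor_unflatten__` (see `is_traceable_wrapper_subclass` and `transform_subclass` in the file below). The following is a minimal sketch of such a subclass, using the two-argument `__tensor_unflatten__` convention that `transform_subclass` expects in this file; the class name and its `elem` field are illustrative, not part of this PR:

```python
import torch
import torch.utils._pytree as pytree

class ExampleWrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        # The outer wrapper advertises the same metadata as the inner tensor.
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.shape, dtype=elem.dtype, device=elem.device,
            requires_grad=elem.requires_grad,
        )

    def __init__(self, elem):
        self.elem = elem

    # These two methods are what is_traceable_wrapper_subclass() checks for.
    def __tensor_flatten__(self):
        # (names of inner-tensor attributes, opaque context)
        return ["elem"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx):
        return ExampleWrapperTensor(inner_tensors["elem"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Unwrap the arguments, run the underlying op, and rewrap tensor outputs.
        unwrap = lambda t: t.elem if isinstance(t, ExampleWrapperTensor) else t
        out = func(*pytree.tree_map(unwrap, args), **pytree.tree_map(unwrap, kwargs))
        wrap = lambda t: ExampleWrapperTensor(t) if isinstance(t, torch.Tensor) else t
        return pytree.tree_map(wrap, out)
```

With the flatten/unflatten pair in place, `transform_subclass` (at the bottom of the file) can rebuild a fresh wrapper from a per-inner-tensor callback, which is the path this PR fixes up for autograd metadata and dynamic-shape source tracking.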
import contextlib
import warnings
from typing import Optional

import torch
from torch._C import (
    DispatchKey,
    _get_dispatch_stack_at,
    _len_torch_dispatch_stack,
    _pop_torch_dispatch_stack,
    _push_on_torch_dispatch_stack,
)


# TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
# - We need a better user-facing api for _DisableTorchDispatch that
#   is able to selectively disable __torch_dispatch__ of a particular class.
# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor)
# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)


class TorchDispatchMode:
    """
    A ``TorchDispatchMode`` allows you to override the meaning of all
    ``__torch_dispatch__`` overrideable functions within a dynamic scope,
    without having to actually create a tensor subclass or manually
    monkey-patch functions in the PyTorch API.  Some common situations
    where you should use a mode:

        * You want to override the meaning of factory functions, or other
          functions that do not otherwise take a tensor as an argument
          (these cannot be overridden with tensor subclasses).

        * You want to override the behavior of all functions without needing
          to wrap your inputs in tensor subclasses; e.g., if you are just
          interested in logging intermediate computations.

        * You want to control the order of execution of various tensor
          subclasses explicitly, rather than implicitly via the return of
          ``NotImplemented``.

    Independent subclasses of :class:`TorchDispatchMode` are compositional:
    modes can be pushed onto a stack using ``with MyMode():``.
    When you call functions in the PyTorch API inside your
    ``__torch_dispatch__`` implementation, by default, they will forward on to
    the next mode on the mode stack.  If you want to recursively call back into
    your current ``__torch_dispatch__`` implementation, either explicitly
    invoke ``self.__torch_dispatch__(...)``, or use the context manager
    ``__torch_dispatch__(self)`` to make the PyTorch
    API self-referential (beware of infinite loops, in this case!)
    """
    def __init__(self, _dispatch_key=None):
        if _dispatch_key is not None:
            assert isinstance(_dispatch_key, torch._C.DispatchKey)
            self.__dict__['_dispatch_key'] = _dispatch_key

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        raise NotImplementedError()

    def __enter__(self):
        _push_mode(self, self.__dict__.get("_dispatch_key", None))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        _pop_mode(self.__dict__.get("_dispatch_key", None))

    @classmethod
    def push(cls, *args, **kwargs):
        warnings.warn("`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`")
        instance = cls(*args, **kwargs)
        return instance


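# Illustrative sketch (not part of the original file): the simplest useful
# TorchDispatchMode overrides __torch_dispatch__ to observe every operator
# call inside its dynamic scope and then forward to the default behavior.
class _LoggingModeExample(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # `func` is the operator overload being dispatched (e.g. aten.add.Tensor).
        print(f"dispatching {func}")
        return func(*args, **kwargs)

# Usage: entering the mode pushes it onto the dispatch stack, so every tensor
# op run inside the `with` block is routed through __torch_dispatch__ above:
#
#     with _LoggingModeExample():
#         torch.ones(2) + 1

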
def _get_current_dispatch_mode():
    stack_len = _len_torch_dispatch_stack()
    return _get_dispatch_stack_at(stack_len - 1) if stack_len > 0 else None


def _get_current_dispatch_mode_stack():
    stack_len = _len_torch_dispatch_stack()
    return [_get_dispatch_stack_at(i) for i in range(stack_len)]


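# Illustrative sketch (not part of the original file): the relation between the
# `with` statement on a mode and the introspection helpers above. The function
# is a hypothetical example, only meaningful if called after this module loads.
def _mode_stack_invariants_example():
    depth_before = _len_torch_dispatch_stack()
    with BaseTorchDispatchMode() as m:  # BaseTorchDispatchMode is defined further down
        assert _get_current_dispatch_mode() is m
        assert _get_current_dispatch_mode_stack()[-1] is m
        assert _len_torch_dispatch_stack() == depth_before + 1
    # Exiting the `with` block pops the mode again.
    assert _len_torch_dispatch_stack() == depth_before

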
def _push_mode(mode, k: Optional[DispatchKey] = None):
    if k is not None:
        from torch._ops import push_mode_for_key, get_cached_ops
        # See Note [Not Caching Per-Dispatch-Key Mode Handlers]
        # Clear the cache of every op that has been used so far, for this particular key.
        ks = torch._C._functionality_to_backend_keys(k)
        for op in get_cached_ops():
            for key in ks:
                op._uncache_dispatch(key)
        push_mode_for_key(k, mode)
    else:
        _push_on_torch_dispatch_stack(mode)


def _pop_mode(k: Optional[DispatchKey] = None):
    if k is not None:
        from torch._ops import pop_mode_for_key
        return pop_mode_for_key(k)
    else:
        return _pop_torch_dispatch_stack()


@contextlib.contextmanager
def _pop_mode_temporarily(k: Optional[DispatchKey] = None):
    old = _pop_mode(k)
    try:
        yield old
    finally:
        _push_mode(old, k)


@contextlib.contextmanager
def _disable_current_modes():
    mode_len = _len_torch_dispatch_stack()
    old_modes = [_pop_mode() for _ in range(mode_len)]
    try:
        yield old_modes
    finally:
        for mode in reversed(old_modes):
            _push_mode(mode)


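# Illustrative sketch (not part of the original file): framework code typically
# uses _disable_current_modes() when it needs to run an operation "for real",
# without any ambient __torch_dispatch__ modes intercepting it. The helper
# below is a hypothetical convenience wrapper showing that pattern.
def _run_without_modes_example(fn, *args, **kwargs):
    with _disable_current_modes():
        # Inside this block the mode stack is empty; it is restored on exit,
        # even if `fn` raises.
        return fn(*args, **kwargs)

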
class BaseTorchDispatchMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return func(*args, **kwargs)


def is_traceable_wrapper_subclass(t):
    """
    Returns whether or not a tensor subclass that implements __torch_dispatch__
    is 'traceable' with torch.compile.
    In order for a tensor subclass to support TorchDispatchMode-style tracing in PT2,
    it must implement two magic methods: __tensor_flatten__ and __tensor_unflatten__.
    It is also expected to obey some restrictions around traceability and aliasing
    (TODO: add clear documentation around this).
    """
    is_subclass = isinstance(t, torch.Tensor) and type(t) != torch.Tensor
    return is_subclass and hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__")


def transform_subclass(t, callback):
    """
    Given a traceable wrapper tensor subclass ``t`` that implements
    ``__torch_dispatch__`` and holds some inner tensors,
    and a callback of type ``Callable[[str, torch.Tensor], torch.Tensor]``,
    ``transform_subclass`` will construct a fresh instance of the wrapper tensor subclass.
    It does so by grabbing each inner tensor attribute from the wrapper,
    passing it into ``callback`` to get a transformed tensor,
    and putting each transformed tensor into the fresh tensor subclass instance.

    Note: this function will not handle ensuring that the fresh subclass
    gets the same (autograd and aliasing) metadata as the original tensor.
    This is generally handled in other subsystems like AOTAutograd.
    """
    attrs, ctx = t.__tensor_flatten__()
    transformed_tensors_dict = {}
    for attr in attrs:
        transformed_tensors_dict[attr] = callback(attr, getattr(t, attr))
    return type(t).__tensor_unflatten__(transformed_tensors_dict, ctx)


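# Illustrative usage (not part of the original file): rebuild a wrapper
# subclass with each inner tensor replaced by a detached clone. The callback
# receives the attribute name and the inner tensor, matching the
# Callable[[str, torch.Tensor], torch.Tensor] contract documented above.
def _detach_subclass_example(t):
    assert is_traceable_wrapper_subclass(t)
    return transform_subclass(t, lambda attr, inner: inner.detach().clone())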