I was working on an explanation of how to call into the "super" implementation of some given ATen operation inside of __torch_dispatch__ (https://github.com/albanD/subclass_zoo/blob/main/trivial_tensors.py) and I kept thinking to myself, "Why doesn't just calling super() on __torch_dispatch__ work?" Well, after this patch, it does!

The idea is that if you don't actually unwrap the input tensors, you can call super().__torch_dispatch__ to get at the original behavior. Internally, this is implemented by disabling PythonKey and then redispatching. This implementation of disabled_torch_dispatch is not /quite/ right, and some of the reasons why are commented in the code.

There is then some extra work I have to do to make sure we recognize disabled_torch_dispatch as the "default" implementation (so we don't start slapping PythonKey on all tensors, including base Tensors); this is modeled the same way as disabled_torch_function.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73684
Approved by: albanD
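For concreteness, here is a minimal sketch of the pattern this patch enables, modeled on the trivial_tensors.py example linked above. It assumes a PyTorch build that includes this change; the LoggingTensor name is invented purely for illustration.

import torch

# Hypothetical subclass (name made up for this sketch). Its instances ARE
# tensors -- nothing is wrapped or unwrapped -- so deferring to the default
# __torch_dispatch__ simply disables PythonKey and redispatches.
class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        print(f"intercepted {func}")
        # After this patch, super().__torch_dispatch__ is the "disabled"
        # default implementation, i.e. the original ATen behavior.
        return super().__torch_dispatch__(func, types, args, kwargs)

x = torch.randn(3).as_subclass(LoggingTensor)
y = x + x  # prints the intercepted op, then computes as usual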
35 lines · 1.6 KiB · Python
import torch
import contextlib
from typing import Iterator


# Context manager that causes all pytorch operators to dispatch to the passed-in
# type's __torch_dispatch__ function, including any operation that accepts no
# tensors but returns a tensor (e.g. factory functions).
#
# enable_python_mode is affected by torch._C._DisableTorchDispatch.
#
# NB: Calling an operator inside __torch_dispatch__ does go through
# __torch_dispatch__ again. Please use _DisableTorchDispatch inside
# __torch_dispatch__ to prevent infinite recursion.
#
# TODO: Limitations and things about enable_python_mode we should fix before exposing it:
# - it currently cannot be nested. This should be simple to implement; we need a
#   stack of TorchDispatchTypeObjects and the next bullet point.
# - We need a better user-facing api for torch._C._DisableTorchDispatch that
#   is able to selectively disable __torch_dispatch__ of a particular class.
# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor)
# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)
@contextlib.contextmanager
def enable_python_mode(cls) -> Iterator[None]:
    if cls.__torch_dispatch__ is torch.Tensor.__torch_dispatch__:
        raise ValueError('The class passed to enable_python_mode '
                         'must have a non-default __torch_dispatch__ classmethod')
    if not isinstance(cls, type) or not issubclass(cls, (torch.Tensor,)):
        raise ValueError('The argument passed to enable_python_mode '
                         'must be the type of a Tensor subclass')
    torch._C._enter_python_mode(cls)
    try:
        yield
    finally:
        torch._C._exit_python_mode()
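As a usage sketch for the context manager above: assuming this module is importable as torch.utils._python_dispatch (its usual location in the tree), a hypothetical PrintingTensor class can be installed as the python mode so that even factory functions, which take no tensor arguments, are routed through its __torch_dispatch__. The no_dispatch helper follows the _DisableTorchDispatch advice in the NB comment above; all names other than enable_python_mode and torch._C._DisableTorchDispatch are invented for illustration.

import contextlib
import torch
from torch.utils._python_dispatch import enable_python_mode

# Per the NB comment above: temporarily turn __torch_dispatch__ off so the
# real kernel can run without recursing back into our handler.
@contextlib.contextmanager
def no_dispatch():
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        del guard

# Hypothetical subclass, named only for this sketch.
class PrintingTensor(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"dispatching {func}")
        with no_dispatch():
            return func(*args, **kwargs)

# Inside the context manager, torch.ones has no tensor inputs but still
# dispatches to PrintingTensor.__torch_dispatch__.
with enable_python_mode(PrintingTensor):
    x = torch.ones(2, 2)  # prints the op, then runs the normal kernel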