diff --git a/benchmarks/overrides_benchmark/common.py b/benchmarks/overrides_benchmark/common.py
index 00e40786d12..fe594bad214 100644
--- a/benchmarks/overrides_benchmark/common.py
+++ b/benchmarks/overrides_benchmark/common.py
@@ -16,7 +16,8 @@ class WithTorchFunction:
 
         self._tensor = torch.tensor(data, requires_grad=requires_grad)
 
-    def __torch_function__(self, func, types, args=(), kwargs=None):
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
         if kwargs is None:
             kwargs = {}
 
@@ -24,7 +25,8 @@ class WithTorchFunction:
 
 
 class SubWithTorchFunction(torch.Tensor):
-    def __torch_function__(self, func, types, args=(), kwargs=None):
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
         if kwargs is None:
             kwargs = {}
 
diff --git a/docs/source/notes/extending.rst b/docs/source/notes/extending.rst
index a8d3983f9f0..9101ba885a9 100644
--- a/docs/source/notes/extending.rst
+++ b/docs/source/notes/extending.rst
@@ -380,7 +380,8 @@ this time adding a ``__torch_function__`` implementation::
         def tensor(self):
             return self._value * torch.eye(self._N)
 
-        def __torch_function__(self, func, types, args=(), kwargs=None):
+        @classmethod
+        def __torch_function__(cls, func, types, args=(), kwargs=None):
             if kwargs is None:
                 kwargs = {}
             if func not in HANDLED_FUNCTIONS or not all(
@@ -500,7 +501,8 @@ handled but to instead pass a :class:`Tensor` to the original :mod:`torch`
 function when no override is available. For example, if we change our
 implementation of ``__torch_function__`` for ``ScalarTensor`` to the one below::
 
-    def __torch_function__(self, func, types, args=(), kwargs=None):
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
         if kwargs is None:
             kwargs = {}
         if func not in HANDLED_FUNCTIONS or not all(
@@ -604,12 +606,15 @@ implementation more permissive about what operations are allowed::
         def __repr__(self):
             return "Metadata:\n{}\n\ndata:\n{}".format(self._metadata, self._t)
 
-        def __torch_function__(self, func, types, args=(), kwargs=None):
+        @classmethod
+        def __torch_function__(cls, func, types, args=(), kwargs=None):
             if kwargs is None:
                 kwargs = {}
+            metadatas = tuple(a._metadata for a in args if hasattr(a, '_metadata'))
+            assert len(metadatas) > 0
             args = [a._t if hasattr(a, '_t') else a for a in args]
             ret = func(*args, **kwargs)
-            return MetadataTensor(ret, metadata=self._metadata)
+            return MetadataTensor(ret, metadata=metadatas[0])
 
 This simple implementation won't necessarily work with every function in the
 :mod:`torch` API but it is good enough to capture most common operations::
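# ---------------------------------------------------------------------------
# Aside (not part of the patch): the corrected MetadataTensor hunk above is
# easiest to check as a runnable snippet. A minimal sketch follows; the
# __init__ below is an assumption filled in for self-containedness (the patch
# only touches __torch_function__). Note the metadata must be gathered
# *before* args is rebound to the unwrapped tensors, or nothing would still
# carry a _metadata attribute.
# ---------------------------------------------------------------------------
import torch

class MetadataTensor(object):
    def __init__(self, data, metadata=None):
        # Assumed constructor, reconstructed so the example runs standalone.
        self._t = torch.as_tensor(data)
        self._metadata = metadata

    def __repr__(self):
        return "Metadata:\n{}\n\ndata:\n{}".format(self._metadata, self._t)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Gather metadata before unwrapping, while args still hold wrappers.
        metadatas = tuple(a._metadata for a in args if hasattr(a, '_metadata'))
        assert len(metadatas) > 0
        args = [a._t if hasattr(a, '_t') else a for a in args]
        ret = func(*args, **kwargs)
        return MetadataTensor(ret, metadata=metadatas[0])

m = MetadataTensor([1, 2, 3], metadata={'source': 'example'})
print(torch.add(m, m))  # the result carries the first argument's metadata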
diff --git a/test/test_overrides.py b/test/test_overrides.py
index 4fc1477f579..7e30a318ae2 100644
--- a/test/test_overrides.py
+++ b/test/test_overrides.py
@@ -4,6 +4,7 @@
 import inspect
 import functools
 import pprint
 import pickle
+import collections.abc
 from torch.testing._internal.common_utils import TestCase, run_tests
 from torch.overrides import (
@@ -128,12 +129,13 @@ class DiagonalTensor(object):
     def tensor(self):
         return self._i * torch.eye(self._N)
 
-    def __torch_function__(self, func, types, args=(), kwargs=None):
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
         if kwargs is None:
             kwargs = {}
-        if func not in self.handled_functions:
+        if func not in cls.handled_functions:
             return NotImplemented
-        return self.handled_functions[func](*args, **kwargs)
+        return cls.handled_functions[func](*args, **kwargs)
 
     def __eq__(self, other):
         if type(other) is type(self):
@@ -203,7 +205,8 @@ class SubTensor(torch.Tensor):
     This is useful for testing that the semantics for overriding torch
     functions are working correctly.
     """
-    def __torch_function__(self, func, types, args=(), kwargs=None):
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
         if(kwargs is None):
             kwargs = {}
 
@@ -353,7 +356,8 @@ class TensorLike(object):
     This class is used to explicitly test that the full torch.tensor API
     can be overriden with a class that defines __torch_function__.
     """
-    def __torch_function__(self, func, types, args=(), kwargs=None):
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
         if(kwargs is None):
             kwargs = {}
 
@@ -714,10 +718,19 @@ class Wrapper:
     def __getitem__(self, key):
         return wrap(self._data[unwrap(key)])
 
-    def __torch_function__(self, func, types, args=(), kwargs=None):
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
         if kwargs is None:
             kwargs = {}
-        self.used_calls.add(func)
+        # Find an instance of this class in the arguments
+        args_of_this_cls = []
+        for a in args:
+            if isinstance(a, cls):
+                args_of_this_cls.append(a)
+            elif isinstance(a, collections.abc.Sequence):
+                args_of_this_cls.extend(el for el in a if isinstance(el, cls))
+        assert len(args_of_this_cls) > 0
+        args_of_this_cls[0].used_calls.add(func)
         args = unwrap(tuple(args))
         kwargs = {k: unwrap(v) for k, v in kwargs.items()}
 
@@ -1005,7 +1018,8 @@ class TestDisabledTorchFunction(TestCase):
     # Regression test for gh-64687
    def test_parameter_does_not_prevent_dispatch(self):
         class MyTensor():
-            def __torch_function__(self, func, types, args=(), kwargs=None):
+            @classmethod
+            def __torch_function__(cls, func, types, args=(), kwargs=None):
                 return "called"
 
         t1 = MyTensor()
diff --git a/torch/distributed/_sharded_tensor/api.py b/torch/distributed/_sharded_tensor/api.py
index d6b7a547324..6ff1007a999 100644
--- a/torch/distributed/_sharded_tensor/api.py
+++ b/torch/distributed/_sharded_tensor/api.py
@@ -546,7 +546,8 @@ class ShardedTensor(object):
         """
         return self._sharding_spec
 
-    def __torch_function__(self, func, types, args=(), kwargs=None):
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
         raise RuntimeError(f"torch function '{func.__name__}' not supported for ShardedTensor!")
 
     def metadata(self) -> ShardedTensorMetadata:
diff --git a/torch/overrides.py b/torch/overrides.py
index f2807469590..0dfcdf1fdc2 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -1529,7 +1529,8 @@ def is_tensor_like(inp):
     But, they can be made Tensor-like by implementing __torch_function__.
 
     >>> class TensorLike:
-    ...     def __torch_function__(self, func, types, args, kwargs):
+    ...     @classmethod
+    ...     def __torch_function__(cls, func, types, args, kwargs):
     ...         return -1
     >>> is_tensor_like(TensorLike())
     True
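# ---------------------------------------------------------------------------
# Aside (not part of the patch): a minimal sketch of the classmethod-style
# protocol every file above migrates to. LoggingScalar and seen_funcs are
# illustrative names, not identifiers from the patch. Because dispatch now
# resolves __torch_function__ on the type and passes cls, class-level state
# is a natural fit; defining the override on an individual instance no
# longer affects dispatch.
# ---------------------------------------------------------------------------
import torch

class LoggingScalar:
    seen_funcs = set()  # shared across instances; cls is what dispatch gets

    def __init__(self, value):
        self._value = torch.as_tensor(value)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        cls.seen_funcs.add(func)  # record on the class, not on an instance
        args = [a._value if isinstance(a, cls) else a for a in args]
        return func(*args, **kwargs)  # defer to the un-overridden function

x = LoggingScalar([1.0, 1.0])
y = LoggingScalar([0.0, 1.0])
print(torch.add(x, y))                        # tensor([1., 2.])
print(torch.add in LoggingScalar.seen_funcs)  # True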