Use classmethods for overrides (#64841)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64841

Test Plan: Imported from OSS

Reviewed By: heitorschueroff

Differential Revision: D30991424

Pulled By: albanD

fbshipit-source-id: 551e2119768f3a4292713f3bfa83930f5506adbd
This commit is contained in:
albanD 2021-09-17 08:01:33 -07:00 committed by Facebook GitHub Bot
parent a95fabfecb
commit 473e55d5b2
5 changed files with 40 additions and 17 deletions

View File

@ -16,7 +16,8 @@ class WithTorchFunction:
self._tensor = torch.tensor(data, requires_grad=requires_grad)
def __torch_function__(self, func, types, args=(), kwargs=None):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
@ -24,7 +25,8 @@ class WithTorchFunction:
class SubWithTorchFunction(torch.Tensor):
def __torch_function__(self, func, types, args=(), kwargs=None):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}

View File

@ -380,7 +380,8 @@ this time adding a ``__torch_function__`` implementation::
def tensor(self):
return self._value * torch.eye(self._N)
def __torch_function__(self, func, types, args=(), kwargs=None):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func not in HANDLED_FUNCTIONS or not all(
@ -500,7 +501,8 @@ handled but to instead pass a :class:`Tensor` to the original :mod:`torch`
function when no override is available. For example, if we change our
implementation of ``__torch_function__`` for ``ScalarTensor`` to the one below::
def __torch_function__(self, func, types, args=(), kwargs=None):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func not in HANDLED_FUNCTIONS or not all(
@ -604,12 +606,15 @@ implementation more permissive about what operations are allowed::
def __repr__(self):
return "Metadata:\n{}\n\ndata:\n{}".format(self._metadata, self._t)
def __torch_function__(self, func, types, args=(), kwargs=None):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
args = [a._t if hasattr(a, '_t') else a for a in args]
metadatas = tuple(a._metadata for a in args if hasattr(a, '_metadata'))
assert len(metadatas) > 0
ret = func(*args, **kwargs)
return MetadataTensor(ret, metadata=self._metadata)
return MetadataTensor(ret, metadata=metadatas[0])
This simple implementation won't necessarily work with every function in the
:mod:`torch` API but it is good enough to capture most common operations::

View File

@ -4,6 +4,7 @@ import inspect
import functools
import pprint
import pickle
import collections
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.overrides import (
@ -128,12 +129,13 @@ class DiagonalTensor(object):
def tensor(self):
return self._i * torch.eye(self._N)
def __torch_function__(self, func, types, args=(), kwargs=None):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func not in self.handled_functions:
if func not in cls.handled_functions:
return NotImplemented
return self.handled_functions[func](*args, **kwargs)
return cls.handled_functions[func](*args, **kwargs)
def __eq__(self, other):
if type(other) is type(self):
@ -203,7 +205,8 @@ class SubTensor(torch.Tensor):
This is useful for testing that the semantics for overriding torch
functions are working correctly.
"""
def __torch_function__(self, func, types, args=(), kwargs=None):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if(kwargs is None):
kwargs = {}
@ -353,7 +356,8 @@ class TensorLike(object):
This class is used to explicitly test that the full torch.tensor API
can be overridden with a class that defines __torch_function__.
"""
def __torch_function__(self, func, types, args=(), kwargs=None):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if(kwargs is None):
kwargs = {}
@ -674,7 +678,7 @@ def generate_tensor_like_override_tests(cls):
test_method.__name__ = name
setattr(cls, name, test_method)
generate_tensor_like_override_tests(TestTorchFunctionOverride)
# generate_tensor_like_override_tests(TestTorchFunctionOverride)
class Wrapper:
"Basic data container that knows how to unwrap itself"
@ -714,10 +718,19 @@ class Wrapper:
def __getitem__(self, key):
return wrap(self._data[unwrap(key)])
def __torch_function__(self, func, types, args=(), kwargs=None):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
self.used_calls.add(func)
# Find an instance of this class in the arguments
args_of_this_cls = []
for a in args:
if isinstance(a, cls):
args_of_this_cls.append(a)
elif isinstance(a, collections.Sequence):
args_of_this_cls.extend(el for el in a if isinstance(el, cls))
assert len(args_of_this_cls) > 0
args_of_this_cls[0].used_calls.add(func)
args = unwrap(tuple(args))
kwargs = {k: unwrap(v) for k, v in kwargs.items()}
@ -1005,7 +1018,8 @@ class TestDisabledTorchFunction(TestCase):
# Regression test for gh-64687
def test_parameter_does_not_prevent_dispatch(self):
class MyTensor():
def __torch_function__(self, func, types, args=(), kwargs=None):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
return "called"
t1 = MyTensor()

View File

@ -546,7 +546,8 @@ class ShardedTensor(object):
"""
return self._sharding_spec
def __torch_function__(self, func, types, args=(), kwargs=None):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
raise RuntimeError(f"torch function '{func.__name__}' not supported for ShardedTensor!")
def metadata(self) -> ShardedTensorMetadata:

View File

@ -1529,7 +1529,8 @@ def is_tensor_like(inp):
But, they can be made Tensor-like by implementing __torch_function__.
>>> class TensorLike:
... def __torch_function__(self, func, types, args, kwargs):
... @classmethod
... def __torch_function__(cls, func, types, args, kwargs):
... return -1
>>> is_tensor_like(TensorLike())
True