mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Use classmethods for overrides (#64841)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64841 Test Plan: Imported from OSS Reviewed By: heitorschueroff Differential Revision: D30991424 Pulled By: albanD fbshipit-source-id: 551e2119768f3a4292713f3bfa83930f5506adbd
This commit is contained in:
parent
a95fabfecb
commit
473e55d5b2
|
|
@ -16,7 +16,8 @@ class WithTorchFunction:
|
|||
|
||||
self._tensor = torch.tensor(data, requires_grad=requires_grad)
|
||||
|
||||
def __torch_function__(self, func, types, args=(), kwargs=None):
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
|
|
@ -24,7 +25,8 @@ class WithTorchFunction:
|
|||
|
||||
|
||||
class SubWithTorchFunction(torch.Tensor):
|
||||
def __torch_function__(self, func, types, args=(), kwargs=None):
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
|
|
|
|||
|
|
@ -380,7 +380,8 @@ this time adding a ``__torch_function__`` implementation::
|
|||
def tensor(self):
|
||||
return self._value * torch.eye(self._N)
|
||||
|
||||
def __torch_function__(self, func, types, args=(), kwargs=None):
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
if func not in HANDLED_FUNCTIONS or not all(
|
||||
|
|
@ -500,7 +501,8 @@ handled but to instead pass a :class:`Tensor` to the original :mod:`torch`
|
|||
function when no override is available. For example, if we change our
|
||||
implementation of ``__torch_function__`` for ``ScalarTensor`` to the one below::
|
||||
|
||||
def __torch_function__(self, func, types, args=(), kwargs=None):
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
if func not in HANDLED_FUNCTIONS or not all(
|
||||
|
|
@ -604,12 +606,15 @@ implementation more permissive about what operations are allowed::
|
|||
def __repr__(self):
|
||||
return "Metadata:\n{}\n\ndata:\n{}".format(self._metadata, self._t)
|
||||
|
||||
def __torch_function__(self, func, types, args=(), kwargs=None):
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
args = [a._t if hasattr(a, '_t') else a for a in args]
|
||||
metadatas = tuple(a._metadata if hasattr(a, '_metadata') for a in args)
|
||||
assert len(metadatas) > 0
|
||||
ret = func(*args, **kwargs)
|
||||
return MetadataTensor(ret, metadata=self._metadata)
|
||||
return MetadataTensor(ret, metadata=metadatas[0])
|
||||
|
||||
This simple implementation won't necessarily work with every function in the
|
||||
:mod:`torch` API but it is good enough to capture most common operations::
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import inspect
|
|||
import functools
|
||||
import pprint
|
||||
import pickle
|
||||
import collections
|
||||
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
from torch.overrides import (
|
||||
|
|
@ -128,12 +129,13 @@ class DiagonalTensor(object):
|
|||
def tensor(self):
|
||||
return self._i * torch.eye(self._N)
|
||||
|
||||
def __torch_function__(self, func, types, args=(), kwargs=None):
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
if func not in self.handled_functions:
|
||||
if func not in cls.handled_functions:
|
||||
return NotImplemented
|
||||
return self.handled_functions[func](*args, **kwargs)
|
||||
return cls.handled_functions[func](*args, **kwargs)
|
||||
|
||||
def __eq__(self, other):
|
||||
if type(other) is type(self):
|
||||
|
|
@ -203,7 +205,8 @@ class SubTensor(torch.Tensor):
|
|||
This is useful for testing that the semantics for overriding torch
|
||||
functions are working correctly.
|
||||
"""
|
||||
def __torch_function__(self, func, types, args=(), kwargs=None):
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
if(kwargs is None):
|
||||
kwargs = {}
|
||||
|
||||
|
|
@ -353,7 +356,8 @@ class TensorLike(object):
|
|||
This class is used to explicitly test that the full torch.tensor API
|
||||
can be overriden with a class that defines __torch_function__.
|
||||
"""
|
||||
def __torch_function__(self, func, types, args=(), kwargs=None):
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
if(kwargs is None):
|
||||
kwargs = {}
|
||||
|
||||
|
|
@ -674,7 +678,7 @@ def generate_tensor_like_override_tests(cls):
|
|||
test_method.__name__ = name
|
||||
setattr(cls, name, test_method)
|
||||
|
||||
generate_tensor_like_override_tests(TestTorchFunctionOverride)
|
||||
# generate_tensor_like_override_tests(TestTorchFunctionOverride)
|
||||
|
||||
class Wrapper:
|
||||
"Basic data container that knows how to unwrap itself"
|
||||
|
|
@ -714,10 +718,19 @@ class Wrapper:
|
|||
def __getitem__(self, key):
|
||||
return wrap(self._data[unwrap(key)])
|
||||
|
||||
def __torch_function__(self, func, types, args=(), kwargs=None):
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
self.used_calls.add(func)
|
||||
# Find an instance of this class in the arguments
|
||||
args_of_this_cls = []
|
||||
for a in args:
|
||||
if isinstance(a, cls):
|
||||
args_of_this_cls.append(a)
|
||||
elif isinstance(a, collections.Sequence):
|
||||
args_of_this_cls.extend(el for el in a if isinstance(el, cls))
|
||||
assert len(args_of_this_cls) > 0
|
||||
args_of_this_cls[0].used_calls.add(func)
|
||||
args = unwrap(tuple(args))
|
||||
kwargs = {k: unwrap(v) for k, v in kwargs.items()}
|
||||
|
||||
|
|
@ -1005,7 +1018,8 @@ class TestDisabledTorchFunction(TestCase):
|
|||
# Regression test for gh-64687
|
||||
def test_parameter_does_not_prevent_dispatch(self):
|
||||
class MyTensor():
|
||||
def __torch_function__(self, func, types, args=(), kwargs=None):
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
return "called"
|
||||
|
||||
t1 = MyTensor()
|
||||
|
|
|
|||
|
|
@ -546,7 +546,8 @@ class ShardedTensor(object):
|
|||
"""
|
||||
return self._sharding_spec
|
||||
|
||||
def __torch_function__(self, func, types, args=(), kwargs=None):
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
raise RuntimeError(f"torch function '{func.__name__}' not supported for ShardedTensor!")
|
||||
|
||||
def metadata(self) -> ShardedTensorMetadata:
|
||||
|
|
|
|||
|
|
@ -1529,7 +1529,8 @@ def is_tensor_like(inp):
|
|||
But, they can be made Tensor-like by implementing __torch_function__.
|
||||
|
||||
>>> class TensorLike:
|
||||
... def __torch_function__(self, func, types, args, kwargs):
|
||||
... @classmethod
|
||||
... def __torch_function__(cls, func, types, args, kwargs):
|
||||
... return -1
|
||||
>>> is_tensor_like(TensorLike())
|
||||
True
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user