diff --git a/test/test_overrides.py b/test/test_overrides.py
index cf18420f884..f2baec461a5 100644
--- a/test/test_overrides.py
+++ b/test/test_overrides.py
@@ -540,6 +540,16 @@ class TestTorchFunctionOverride(TestCase):
         with self.assertRaises(TypeError):
             sn1 + s2
 
+    def test_base(self):
+        # https://github.com/szagoruyko/pytorchviz/issues/65
+        class DummyTensor(torch.Tensor):
+            pass
+
+        a = torch.ones(1)
+        c = DummyTensor(a)
+        self.assertTrue(c._is_view())
+        self.assertTrue(c._base is a)
+
 
 def generate_tensor_like_override_tests(cls):
     from torch.testing._internal.generated.annotated_fn_args import annotated_args
diff --git a/torch/_tensor.py b/torch/_tensor.py
index 7cb99a3c070..5b67ca03f08 100644
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -11,7 +11,7 @@ from torch._namedtensor_internals import (
     unzip_namedshape, single_ellipsis_index, is_ellipsis)
 from torch.overrides import (
     has_torch_function, has_torch_function_unary, has_torch_function_variadic,
-    handle_torch_function)
+    handle_torch_function, get_default_nowrap_functions)
 import torch.utils.hooks as hooks
 
 
@@ -1007,7 +1007,10 @@ class Tensor(torch._C._TensorBase):
 
         with _C.DisableTorchFunction():
             ret = func(*args, **kwargs)
-            return _convert(ret, cls)
+            if func in get_default_nowrap_functions():
+                return ret
+            else:
+                return _convert(ret, cls)
 
     __module__ = 'torch'
 
diff --git a/torch/overrides.py b/torch/overrides.py
index 03027f47eef..4c894d8e687 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -233,6 +233,32 @@ def get_ignored_functions() -> Set[Callable]:
     }
 
 
+@functools.lru_cache(None)
+def get_default_nowrap_functions() -> Set[Callable]:
+    """
+    Return public functions that do not wrap in a subclass when invoked by
+    the default ``Tensor.__torch_function__`` that preserves subclasses. Typically,
+    these functions represent field accesses (i.e., retrieving a Tensor that
+    is stored somewhere on the Tensor) as opposed to computation. Users of
+    these functions expect object identity to be preserved over multiple accesses
+    (e.g., ``a.grad is a.grad``) which cannot be upheld if we're wrapping on
+    the fly every time (furthermore, the tensor stored here might already be
+    the subclass, in which case wrapping really ought not to happen).
+
+    Not ALL property accessors have this property; for example ``Tensor.T`` actually
+    just creates a new transposed tensor on the fly, and so we SHOULD interpose on
+    these calls (you need to check the implementation of the function to see if
+    this is the case or not). Additionally, if a property accessor doesn't return a Tensor,
+    it doesn't have to be on this list (though it is harmless if it is).
+    """
+    Tensor = torch.Tensor
+    return {
+        Tensor._base.__get__,
+        Tensor.grad.__get__,
+        Tensor._grad.__get__,
+    }
+
+
 @functools.lru_cache(None)
 def get_testing_overrides() -> Dict[Callable, Callable]:
     """Return a dict containing dummy overrides for all overridable functions