Do not wrap Tensor.{grad,_base} by default (#60464)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60464

Fixes https://github.com/szagoruyko/pytorchviz/issues/65

An alternative implementation of this PR would be to remove the
__torch_function__ interposition points for these accessors entirely.
In the end, I decided to opt for extra expressivity.  See
torch.overrides for the criterion used to decide which accessors
should get the nowrap treatment.
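
A quick sketch of what that criterion means in practice, using the new
helper added in this commit (the Tensor.T contrast follows the reasoning
in its docstring below):

    import torch
    from torch.overrides import get_default_nowrap_functions

    nowrap = get_default_nowrap_functions()

    # Field accessors that return a tensor stored on the Tensor are in
    # the nowrap set...
    assert torch.Tensor.grad.__get__ in nowrap
    assert torch.Tensor._base.__get__ in nowrap

    # ...but computed properties are not: Tensor.T builds a fresh tensor
    # on every access, so the default __torch_function__ should still
    # wrap its result in the subclass.
    assert torch.Tensor.T.__get__ not in nowrap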

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D29302835

Pulled By: ezyang

fbshipit-source-id: fbe0ac4530a6cc9d6759a3fdf5514d4d7b1f7690
Author: Edward Yang (committed 2021-06-22 12:48:07 -07:00 by Facebook GitHub Bot)
Commit: 82c52fd417 (parent f42140cb8a)
3 changed files with 41 additions and 2 deletions


@@ -540,6 +540,16 @@ class TestTorchFunctionOverride(TestCase):
         with self.assertRaises(TypeError):
             sn1 + s2
 
+    def test_base(self):
+        # https://github.com/szagoruyko/pytorchviz/issues/65
+        class DummyTensor(torch.Tensor):
+            pass
+
+        a = torch.ones(1)
+        c = DummyTensor(a)
+        self.assertTrue(c._is_view())
+        self.assertTrue(c._base is a)
+
 
 def generate_tensor_like_override_tests(cls):
     from torch.testing._internal.generated.annotated_fn_args import annotated_args


@@ -11,7 +11,7 @@ from torch._namedtensor_internals import (
     unzip_namedshape, single_ellipsis_index, is_ellipsis)
 from torch.overrides import (
     has_torch_function, has_torch_function_unary, has_torch_function_variadic,
-    handle_torch_function)
+    handle_torch_function, get_default_nowrap_functions)
 import torch.utils.hooks as hooks
@@ -1007,6 +1007,9 @@ class Tensor(torch._C._TensorBase):
         with _C.DisableTorchFunction():
             ret = func(*args, **kwargs)
-            return _convert(ret, cls)
+            if func in get_default_nowrap_functions():
+                return ret
+            else:
+                return _convert(ret, cls)
 
     __module__ = 'torch'
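
The observable effect of this hunk, as a minimal hypothetical repro
(MyTensor is a stand-in subclass; before this change, _convert
re-wrapped the stored gradient in the subclass on every access, so the
identity assertion below failed):

    import torch

    class MyTensor(torch.Tensor):
        pass

    w = MyTensor(torch.randn(2))
    w.requires_grad_()
    (w * 2).sum().backward()

    # Tensor.grad.__get__ is in the nowrap set, so the stored gradient
    # tensor is returned as-is and repeated accesses see the same object.
    assert w.grad is w.grad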


@@ -233,6 +233,32 @@ def get_ignored_functions() -> Set[Callable]:
     }
 
 
+@functools.lru_cache(None)
+def get_default_nowrap_functions() -> Set[Callable]:
+    """
+    Return public functions that do not wrap in a subclass when invoked by
+    the default ``Tensor.__torch_function__`` that preserves subclasses.  Typically,
+    these functions represent field accesses (i.e., retrieving a Tensor that
+    is stored somewhere on the Tensor) as opposed to computation.  Users of
+    these functions expect object identity to be preserved over multiple accesses
+    (e.g., ``a.grad is a.grad``) which cannot be upheld if we're wrapping on
+    the fly every time (furthermore, the tensor stored here might already be
+    the subclass, in which case wrapping really ought not to happen).
+
+    Not ALL property accessors have this property; for example ``Tensor.T`` actually
+    just creates a new transposed tensor on the fly, and so we SHOULD interpose on
+    these calls (you need to check the implementation of the function to see if
+    this is the case or not).  Additionally, if a property accessor doesn't return
+    a Tensor, it doesn't have to be on this list (though it is harmless if it is).
+    """
+    Tensor = torch.Tensor
+    return {
+        Tensor._base.__get__,
+        Tensor.grad.__get__,
+        Tensor._grad.__get__,
+    }
+
+
 @functools.lru_cache(None)
 def get_testing_overrides() -> Dict[Callable, Callable]:
     """Return a dict containing dummy overrides for all overridable functions