mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Do not wrap Tensor.{grad,_base} by default (#60464)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/60464 Fixes https://github.com/szagoruyko/pytorchviz/issues/65 An alternate implementation of this PR would be to remove the __torch_function__ interposition points for these accessors entirely. In the end, I decided to opt for extra expressivity. See torch.overrides for the criterion on how I decided which accessors should get the nowrap treatment. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D29302835 Pulled By: ezyang fbshipit-source-id: fbe0ac4530a6cc9d6759a3fdf5514d4d7b1f7690
This commit is contained in:
parent
f42140cb8a
commit
82c52fd417
|
|
@ -540,6 +540,16 @@ class TestTorchFunctionOverride(TestCase):
|
|||
with self.assertRaises(TypeError):
|
||||
sn1 + s2
|
||||
|
||||
def test_base(self):
|
||||
# https://github.com/szagoruyko/pytorchviz/issues/65
|
||||
class DummyTensor(torch.Tensor):
|
||||
pass
|
||||
|
||||
a = torch.ones(1)
|
||||
c = DummyTensor(a)
|
||||
self.assertTrue(c._is_view())
|
||||
self.assertTrue(c._base is a)
|
||||
|
||||
|
||||
def generate_tensor_like_override_tests(cls):
|
||||
from torch.testing._internal.generated.annotated_fn_args import annotated_args
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from torch._namedtensor_internals import (
|
|||
unzip_namedshape, single_ellipsis_index, is_ellipsis)
|
||||
from torch.overrides import (
|
||||
has_torch_function, has_torch_function_unary, has_torch_function_variadic,
|
||||
handle_torch_function)
|
||||
handle_torch_function, get_default_nowrap_functions)
|
||||
import torch.utils.hooks as hooks
|
||||
|
||||
|
||||
|
|
@ -1007,7 +1007,10 @@ class Tensor(torch._C._TensorBase):
|
|||
|
||||
with _C.DisableTorchFunction():
|
||||
ret = func(*args, **kwargs)
|
||||
return _convert(ret, cls)
|
||||
if func in get_default_nowrap_functions():
|
||||
return ret
|
||||
else:
|
||||
return _convert(ret, cls)
|
||||
|
||||
__module__ = 'torch'
|
||||
|
||||
|
|
|
|||
|
|
@ -233,6 +233,32 @@ def get_ignored_functions() -> Set[Callable]:
|
|||
}
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_default_nowrap_functions() -> Set[Callable]:
|
||||
"""
|
||||
Return public functions that do not wrap in a subclass when invoked by
|
||||
the default ``Tensor.__torch_function__`` that preserves subclasses. Typically,
|
||||
these functions represent field accesses (i.e., retrieving a Tensor that
|
||||
is stored somewhere on the Tensor) as opposed to computation. Users of
|
||||
these functions expect object identity to be preserved over multiple accesses
|
||||
(e.g., ``a.grad is a.grad``) which cannot be upheld if we're wrapping on
|
||||
the fly every time (furthermore, the tensor stored here might already be
|
||||
the subclass, in which case wrapping really ought not to happen).
|
||||
|
||||
Not ALL property accessors have this property; for example ``Tensor.T`` actually
|
||||
just creates a new transposed tensor on the fly, and so we SHOULD interpose on
|
||||
these calls (you need to check the implementation of the function to see if
|
||||
this is the case or not). Additionally, if a property accessor doesn't return a Tensor,
|
||||
it doesn't have to be on this list (though it is harmless if it is).
|
||||
"""
|
||||
Tensor = torch.Tensor
|
||||
return {
|
||||
Tensor._base.__get__,
|
||||
Tensor.grad.__get__,
|
||||
Tensor._grad.__get__,
|
||||
}
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||
"""Return a dict containing dummy overrides for all overridable functions
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user