Support for tensor subclasses as parameters
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73459 Approved by: https://github.com/ezyang, https://github.com/albanD
This commit is contained in:
parent
54c75e1e8f
commit
bc34cf5fe4
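In rough terms, the change lets nn.Parameter wrap a tensor subclass without losing the subclass type: for non-plain tensors, Parameter.__new__ now returns the (detached) subclass instance tagged with an _is_param flag, and a new _ParameterMeta metaclass makes isinstance(..., nn.Parameter) honor that flag. A minimal usage sketch, assuming a trivial illustrative subclass named MyTensor (not part of this PR):

import torch
from torch import nn

# Illustrative subclass only; mirrors the NonWrapperTensor pattern added in
# torch/testing/_internal/common_subclass.py below.
class MyTensor(torch.Tensor):
    def __new__(cls, data):
        return torch.Tensor._make_subclass(cls, data)

p = nn.Parameter(MyTensor(torch.randn(3)))

assert isinstance(p, MyTensor)      # custom type is preserved, not rewrapped
assert isinstance(p, nn.Parameter)  # _ParameterMeta instance check sees the _is_param flag
assert p.requires_grad              # requires_grad defaults to True, as before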
test/test_subclass.py (new file, 210 lines)
@@ -0,0 +1,210 @@
# Owner(s): ["module: nn"]

import tempfile
import torch

from copy import deepcopy
from functools import partial
from torch import nn
from torch.nn.utils.parametrize import register_parametrization, remove_parametrizations
from torch.nn.modules.lazy import LazyModuleMixin
from torch.testing._internal.common_utils import (
    TestCase, run_tests, parametrize, subtest, instantiate_parametrized_tests)
from torch.testing._internal.common_subclass import subclass_db, DiagTensorBelow
from torch.testing._internal.logging_tensor import LoggingTensor
from unittest import expectedFailure

# The current test methodology in this file is to test a variety of real use cases
# with a set of fully-fledged tensor subclasses. In the future, this may change
# to more narrowly specify toy subclasses for each of the specific invariants under
# test, avoiding the need to maintain the set of fully-fledged tensor subclasses.


# Decorator for parametrizing tests across the various tensor classes.
parametrize_tensor_cls = parametrize("tensor_cls", [
    subtest(tensor_cls, name=info.name) for tensor_cls, info in subclass_db.items()])


class TestSubclass(TestCase):
    def _create_tensor(self, tensor_cls):
        return subclass_db[tensor_cls].create_fn(3)

    @parametrize_tensor_cls
    @parametrize("tensor_requires_grad", [False, True])
    def test_param_invariants(self, tensor_cls, tensor_requires_grad):
        x = self._create_tensor(tensor_cls).requires_grad_(tensor_requires_grad)
        param = nn.Parameter(x, requires_grad=(not tensor_requires_grad))

        self.assertIsInstance(param, nn.Parameter)
        # Ensure requires_grad passed to Parameter's constructor takes precedence.
        self.assertEqual(param.requires_grad, not tensor_requires_grad)

        # Ensure original tensor is not mutated by Parameter construction.
        self.assertNotIsInstance(x, nn.Parameter)
        self.assertEqual(x.requires_grad, tensor_requires_grad)

    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_deepcopy(self, tensor_cls, as_param):
        x = self._create_tensor(tensor_cls)
        if as_param:
            x = nn.Parameter(x)
        x_copy = deepcopy(x)
        self.assertEqual(x, x_copy)
        self.assertEqual(x.__class__, x_copy.__class__)
        self.assertIsNot(x, x_copy)
        self.assertIsInstance(x_copy, tensor_cls)
        if as_param:
            # Deepcopy should preserve both custom type and "parameter-ness".
            self.assertIsInstance(x_copy, nn.Parameter)

    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_serialization(self, tensor_cls, as_param):
        with tempfile.TemporaryFile() as f:
            x = self._create_tensor(tensor_cls)
            if as_param:
                x = nn.Parameter(x)
            torch.save(x, f)
            f.seek(0)
            x_loaded = torch.load(f)

            self.assertEqual(x, x_loaded)
            self.assertIsNot(x, x_loaded)
            self.assertIsInstance(x_loaded, tensor_cls)
            if as_param:
                # Serialization should preserve both custom type and "parameter-ness".
                self.assertIsInstance(x_loaded, nn.Parameter)

    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_repr(self, tensor_cls, as_param):
        x = self._create_tensor(tensor_cls)
        if as_param:
            x = nn.Parameter(x)
        str_repr = x.__repr__()
        if tensor_cls is not torch.Tensor:
            self.assertEqual(str_repr.count(f"{tensor_cls.__name__}("), 1)
            self.assertEqual(str_repr.count("Parameter"), 1 if as_param else 0)

    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_type_propagation(self, tensor_cls, as_param):
        x = self._create_tensor(tensor_cls)
        if as_param:
            x = nn.Parameter(x)

        # Call the add operator to produce an output tensor.
        output = x + self._create_tensor(torch.Tensor)

        # Custom type should be propagated across operations if closed under the op, but
        # "parameter-ness" should not be.
        if subclass_db[tensor_cls].closed_under_ops:
            self.assertIsInstance(output, tensor_cls)
        else:
            self.assertIsInstance(output, torch.Tensor)
        self.assertNotIsInstance(output, nn.Parameter)

    @parametrize_tensor_cls
    def test_module_optimization(self, tensor_cls):
        create_fn = partial(self._create_tensor, tensor_cls)

        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.p1 = nn.Parameter(create_fn())

                self.p_list = nn.ParameterList([create_fn() for _ in range(3)])
                self.p_list.append(create_fn())

                self.p_dict = nn.ParameterDict({
                    'foo': create_fn(),
                    'bar': create_fn(),
                })
                self.p_dict['baz'] = create_fn()

                with torch.no_grad():
                    nn.init.normal_(self.p1)
                    for p in self.p_list:
                        nn.init.uniform_(p)
                    for _, p in self.p_dict.items():
                        nn.init.uniform_(p)

            def forward(self, x):
                out = self.p1 + x
                for p in self.p_list:
                    out = p + out

                for _, v in self.p_dict.items():
                    out = v + out

                return out

        m = MyModule()
        self.assertEqual(len(m.state_dict()), 8)

        optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
        m(create_fn()).sum().backward(torch.tensor(1))
        optimizer.step()

    @parametrize_tensor_cls
    @parametrize("leave_parametrized", [False, True])
    def test_parametrization(self, tensor_cls, leave_parametrized):
        # TODO: Either implement set_() properly for these tensor subclasses or apply a
        # more general fix to avoid the need for special set_() handling. For now, skip
        # testing these as they're expected to fail.
        if tensor_cls in [LoggingTensor, DiagTensorBelow]:
            return

        create_fn = partial(self._create_tensor, tensor_cls)

        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = nn.Parameter(create_fn())

            def forward(self, x):
                return self.weight + x

        class MyParametrization(nn.Module):
            def forward(self, X):
                return -X

        m = MyModule()
        self.assertEqual(len(m.state_dict()), 1)
        register_parametrization(m, 'weight', MyParametrization())
        self.assertIsInstance(m.weight, tensor_cls)
        output = m(self._create_tensor(torch.Tensor))
        self.assertIsInstance(output, tensor_cls)
        remove_parametrizations(m, 'weight', leave_parametrized=leave_parametrized)

    # Lazy modules with custom tensors are not supported yet.
    @expectedFailure
    @parametrize_tensor_cls
    def test_lazy_module(self, tensor_cls):
        if tensor_cls is torch.Tensor:
            self.fail('dummy fail for base tensor until the test passes for subclasses')

        class MyLazyModule(LazyModuleMixin, nn.Module):
            def __init__(self):
                super().__init__()
                self.param = nn.UninitializedParameter()

            def initialize_parameters(self, input) -> None:  # type: ignore[override]
                if self.has_uninitialized_params():
                    with torch.no_grad():
                        self.param.materialize(input.shape)
                        nn.init.uniform_(self.param)

            def forward(self, x):
                return self.param + x

        m = MyLazyModule()
        self.assertTrue(m.has_uninitialized_params())
        output = m(self._create_tensor(tensor_cls))
        self.assertFalse(m.has_uninitialized_params())
        self.assertIsInstance(m.param, tensor_cls)


instantiate_parametrized_tests(TestSubclass)

if __name__ == '__main__':
    run_tests()
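For reference, a short sketch of how the subclass_db entries drive these tests (the shape argument 3 matches _create_tensor above): each SubclassInfo.create_fn builds an instance of its subclass from a shape, so one test body can be reused across every registered subclass.

from torch.testing._internal.common_subclass import subclass_db

# Sketch only: iterate the test database the same way _create_tensor does.
for tensor_cls, info in subclass_db.items():
    t = info.create_fn(3)  # same shape used by TestSubclass._create_tensor
    print(info.name, type(t).__name__, tuple(t.shape))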
torch/_tensor.py
@@ -333,11 +333,12 @@ class Tensor(torch._C._TensorBase):
         # See Note [Don't serialize hooks]
         self.requires_grad, _, self._backward_hooks = state
 
-    def __repr__(self):
+    def __repr__(self, *, tensor_contents=None):
         if has_torch_function_unary(self):
-            return handle_torch_function(Tensor.__repr__, (self,), self)
+            return handle_torch_function(Tensor.__repr__, (self,), self,
+                                         tensor_contents=tensor_contents)
         # All strings are unicode in Python 3.
-        return torch._tensor_str._str(self)
+        return torch._tensor_str._str(self, tensor_contents=tensor_contents)
 
     def backward(self, gradient=None, retain_graph=None, create_graph=False, inputs=None):
         r"""Computes the gradient of current tensor w.r.t. graph leaves.
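With the new keyword-only tensor_contents argument, a subclass can customize just the payload of its repr and let Tensor.__repr__ supply the framing (class-name prefix, suffixes, Parameter(...) wrapping); DiagTensorBelow and LoggingTensor later in this diff use exactly this hook. A minimal sketch with an illustrative subclass:

import torch

class MyTensor(torch.Tensor):
    def __new__(cls, data):
        return torch.Tensor._make_subclass(cls, data)

    def __repr__(self):
        # Supply only the contents; framing is handled by the base implementation.
        return super().__repr__(tensor_contents="<custom contents>")

print(MyTensor(torch.randn(2)))  # e.g. MyTensor(<custom contents>)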
torch/_tensor_str.py
@@ -297,15 +297,24 @@ def get_summarized_data(self):
     else:
         return torch.stack([get_summarized_data(x) for x in self])
 
-def _str_intern(inp):
-    self, tangent = torch.autograd.forward_ad.unpack_dual(inp)
-    prefix = "nested_tensor(" if self.is_nested else 'tensor('
+def _str_intern(inp, *, tensor_contents=None):
+    is_plain_tensor = type(inp) is torch.Tensor or type(inp) is torch.nn.Parameter
+    if is_plain_tensor:
+        prefix = 'tensor('
+    elif inp.is_nested:
+        prefix = "nested_tensor("
+    else:
+        prefix = f"{type(inp).__name__}("
     indent = len(prefix)
     suffixes = []
+    custom_contents_provided = tensor_contents is not None
+    if custom_contents_provided:
+        tensor_str = tensor_contents
 
     # This is used to extract the primal value and thus disable the forward AD
     # within this function.
     # TODO(albanD) This needs to be updated when more than one level is supported
+    self, tangent = torch.autograd.forward_ad.unpack_dual(inp)
 
     # Note [Print tensor device]:
     # A general logic here is we only print device when it doesn't match
@@ -332,40 +341,42 @@ def _str_intern(inp):
         suffixes.append('nnz=' + str(self._nnz()))
         if not has_default_dtype:
             suffixes.append('dtype=' + str(self.dtype))
-        indices_prefix = 'indices=tensor('
-        indices = self._indices().detach()
-        indices_str = _tensor_str(indices, indent + len(indices_prefix))
-        if indices.numel() == 0:
-            indices_str += ', size=' + str(tuple(indices.shape))
-        values_prefix = 'values=tensor('
-        values = self._values().detach()
-        values_str = _tensor_str(values, indent + len(values_prefix))
-        if values.numel() == 0:
-            values_str += ', size=' + str(tuple(values.shape))
-        tensor_str = indices_prefix + indices_str + '),\n' + ' ' * indent + values_prefix + values_str + ')'
+        if not custom_contents_provided:
+            indices_prefix = 'indices=tensor('
+            indices = self._indices().detach()
+            indices_str = _tensor_str(indices, indent + len(indices_prefix))
+            if indices.numel() == 0:
+                indices_str += ', size=' + str(tuple(indices.shape))
+            values_prefix = 'values=tensor('
+            values = self._values().detach()
+            values_str = _tensor_str(values, indent + len(values_prefix))
+            if values.numel() == 0:
+                values_str += ', size=' + str(tuple(values.shape))
+            tensor_str = indices_prefix + indices_str + '),\n' + ' ' * indent + values_prefix + values_str + ')'
     elif self.is_sparse_csr:
         suffixes.append('size=' + str(tuple(self.shape)))
         suffixes.append('nnz=' + str(self._nnz()))
         if not has_default_dtype:
             suffixes.append('dtype=' + str(self.dtype))
-        crow_indices_prefix = 'crow_indices=tensor('
-        crow_indices = self.crow_indices().detach()
-        crow_indices_str = _tensor_str(crow_indices, indent + len(crow_indices_prefix))
-        if crow_indices.numel() == 0:
-            crow_indices_str += ', size=' + str(tuple(crow_indices.shape))
-        col_indices_prefix = 'col_indices=tensor('
-        col_indices = self.col_indices().detach()
-        col_indices_str = _tensor_str(col_indices, indent + len(col_indices_prefix))
-        if col_indices.numel() == 0:
-            col_indices_str += ', size=' + str(tuple(col_indices.shape))
-        values_prefix = 'values=tensor('
-        values = self.values().detach()
-        values_str = _tensor_str(values, indent + len(values_prefix))
-        if values.numel() == 0:
-            values_str += ', size=' + str(tuple(values.shape))
-        tensor_str = crow_indices_prefix + crow_indices_str + '),\n' + ' ' * indent +\
-            col_indices_prefix + col_indices_str + '),\n' + ' ' * indent +\
-            values_prefix + values_str + ')'
+        if not custom_contents_provided:
+            crow_indices_prefix = 'crow_indices=tensor('
+            crow_indices = self.crow_indices().detach()
+            crow_indices_str = _tensor_str(crow_indices, indent + len(crow_indices_prefix))
+            if crow_indices.numel() == 0:
+                crow_indices_str += ', size=' + str(tuple(crow_indices.shape))
+            col_indices_prefix = 'col_indices=tensor('
+            col_indices = self.col_indices().detach()
+            col_indices_str = _tensor_str(col_indices, indent + len(col_indices_prefix))
+            if col_indices.numel() == 0:
+                col_indices_str += ', size=' + str(tuple(col_indices.shape))
+            values_prefix = 'values=tensor('
+            values = self.values().detach()
+            values_str = _tensor_str(values, indent + len(values_prefix))
+            if values.numel() == 0:
+                values_str += ', size=' + str(tuple(values.shape))
+            tensor_str = crow_indices_prefix + crow_indices_str + '),\n' + ' ' * indent +\
+                col_indices_prefix + col_indices_str + '),\n' + ' ' * indent +\
+                values_prefix + values_str + ')'
     elif self.is_quantized:
         suffixes.append('size=' + str(tuple(self.shape)))
         if not has_default_dtype:
@@ -379,12 +390,14 @@ def _str_intern(inp):
             suffixes.append('scale=' + str(self.q_per_channel_scales()))
             suffixes.append('zero_point=' + str(self.q_per_channel_zero_points()))
             suffixes.append('axis=' + str(self.q_per_channel_axis()))
-        tensor_str = _tensor_str(self.dequantize(), indent)
+        if not custom_contents_provided:
+            tensor_str = _tensor_str(self.dequantize(), indent)
     elif self.is_nested:
-        def indented_str(s, indent):
-            return "\n".join(f" {line}" for line in s.split("\n"))
-        strs = ",\n".join(indented_str(str(t), indent + 1) for t in torch.ops.aten.unbind.int(self, 0))
-        tensor_str = f"[\n{strs}\n]"
+        if not custom_contents_provided:
+            def indented_str(s, indent):
+                return "\n".join(f" {line}" for line in s.split("\n"))
+            strs = ",\n".join(indented_str(str(t), indent + 1) for t in torch.ops.aten.unbind.int(self, 0))
+            tensor_str = f"[\n{strs}\n]"
     else:
         if self.is_meta:
             suffixes.append('size=' + str(tuple(self.shape)))
@@ -392,7 +405,8 @@ def _str_intern(inp):
                 suffixes.append('dtype=' + str(self.dtype))
             # TODO: This implies that ellipses is valid syntax for allocating
            # a meta tensor, which it could be, but it isn't right now
-            tensor_str = '...'
+            if not custom_contents_provided:
+                tensor_str = '...'
         else:
             if self.numel() == 0 and not self.is_sparse:
                 # Explicitly print the shape if it is not (0,), to match NumPy behavior
@@ -403,15 +417,17 @@ def _str_intern(inp):
                 # should be int64, so it must be shown explicitly.
                 if self.dtype != torch.get_default_dtype():
                     suffixes.append('dtype=' + str(self.dtype))
-                tensor_str = '[]'
+                if not custom_contents_provided:
+                    tensor_str = '[]'
             else:
                 if not has_default_dtype:
                     suffixes.append('dtype=' + str(self.dtype))
 
-                if self.layout != torch.strided:
-                    tensor_str = _tensor_str(self.to_dense(), indent)
-                else:
-                    tensor_str = _tensor_str(self, indent)
+                if not custom_contents_provided:
+                    if self.layout != torch.strided:
+                        tensor_str = _tensor_str(self.to_dense(), indent)
+                    else:
+                        tensor_str = _tensor_str(self, indent)
 
     if self.layout != torch.strided:
         suffixes.append('layout=' + str(self.layout))
@@ -432,8 +448,17 @@ def _str_intern(inp):
     if tangent is not None:
         suffixes.append('tangent={}'.format(tangent))
 
-    return _add_suffixes(prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse)
+    string_repr = _add_suffixes(prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse)
+
+    # Check if this instance is flagged as a parameter and change the repr accordingly.
+    # Unfortunately, this function has to be aware of this detail.
+    # NB: This is currently skipped for plain tensor parameters to maintain BC. In the future,
+    # this should be done for those as well to produce a valid repr.
+    if isinstance(self, torch.nn.Parameter) and not is_plain_tensor:
+        string_repr = f"Parameter({string_repr})"
+
+    return string_repr
 
-def _str(self):
+def _str(self, *, tensor_contents=None):
     with torch.no_grad():
-        return _str_intern(self)
+        return _str_intern(self, tensor_contents=tensor_contents)
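Putting the _tensor_str.py changes together, a rough sketch of the resulting strings (inner tensor values are illustrative); plain tensor parameters intentionally keep the old "Parameter containing:" form for backward compatibility:

import torch
from torch import nn
from torch.testing._internal.common_subclass import DiagTensorBelow

print(nn.Parameter(torch.randn(3)))
# Parameter containing:
# tensor([...], requires_grad=True)

print(nn.Parameter(DiagTensorBelow(torch.randn(3))))
# Roughly: Parameter(DiagTensorBelow(diag=tensor([...]), requires_grad=True))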
torch/nn/parameter.py
@@ -3,7 +3,15 @@ from torch._C import _disabled_torch_function_impl
 from collections import OrderedDict
 
 
-class Parameter(torch.Tensor):
+# Metaclass to combine _TensorMeta and the instance check override for Parameter.
+class _ParameterMeta(torch._C._TensorMeta):
+    # Make `isinstance(t, Parameter)` return True for custom tensor instances that have the _is_param flag.
+    def __instancecheck__(self, instance):
+        return super().__instancecheck__(instance) or (
+            isinstance(instance, torch.Tensor) and getattr(instance, '_is_param', False))
+
+
+class Parameter(torch.Tensor, metaclass=_ParameterMeta):
     r"""A kind of Tensor that is to be considered a module parameter.
 
     Parameters are :class:`~torch.Tensor` subclasses, that have a
@@ -23,8 +31,18 @@ class Parameter(torch.Tensor):
     def __new__(cls, data=None, requires_grad=True):
         if data is None:
             data = torch.empty(0)
-        return torch.Tensor._make_subclass(cls, data, requires_grad)
+        if type(data) is torch.Tensor:
+            # For ease of BC maintenance, keep this path for standard Tensor.
+            # Eventually (tm), we should change the behavior for standard Tensor to match.
+            return torch.Tensor._make_subclass(cls, data, requires_grad)
+
+        # Path for custom tensors: set a flag on the instance to indicate parameter-ness.
+        t = data.detach().requires_grad_(requires_grad)
+        t._is_param = True
+        return t
 
+    # Note: the 3 methods below only apply to standard Tensor. Parameters of custom tensor types
+    # are still considered that custom tensor type and these methods will not be called for them.
     def __deepcopy__(self, memo):
         if id(self) in memo:
             return memo[id(self)]
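A hedged sketch of what the metaclass change does in isolation: any Tensor carrying the _is_param flag now passes the Parameter instance check, which is how subclass instances returned by Parameter.__new__ keep reporting as parameters.

import torch
from torch import nn

t = torch.randn(2)
assert not isinstance(t, nn.Parameter)

# Parameter.__new__ sets this flag for subclass inputs; setting it by hand here
# only demonstrates the _ParameterMeta.__instancecheck__ override.
t._is_param = True
assert isinstance(t, nn.Parameter)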
torch/nn/utils/parametrize.py
@@ -614,7 +614,19 @@ def remove_parametrizations(
             # We do this so that the parameter does not to change the id()
             # This way the user does not need to update the optimizer
             with torch.no_grad():
-                original.set_(t)
+                if type(original) is torch.Tensor:
+                    original.set_(t)
+                else:
+                    try:
+                        original.set_(t)
+                    except RuntimeError as e:
+                        # TODO: Fix this for tensor subclasses that are parameters:
+                        # RuntimeError: set_storage is not allowed on a Tensor created from .data or .detach().
+                        raise RuntimeError("Calling remove_parametrizations() with leave_parametrized=True "
+                                           "for a parameter that is an instance of a tensor subclass requires "
+                                           "set_() to be implemented correctly for the tensor subclass. Either "
+                                           "set leave_parametrized=False or provide a working implementation for "
+                                           "set_() in the tensor subclass.")
     else:
         if leave_parametrized:
             # We cannot use no_grad because we need to know whether one or more
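For orientation, a runnable sketch of the path this hunk touches using a standard parameter, where set_() succeeds as before; per the TODO above, a tensor subclass without a working set_() now gets the descriptive RuntimeError instead when leave_parametrized=True.

import torch
from torch import nn
from torch.nn.utils.parametrize import register_parametrization, remove_parametrizations

# Illustrative parametrization, analogous to MyParametrization in test_subclass.py.
class Negate(nn.Module):
    def forward(self, X):
        return -X

m = nn.Linear(3, 3)
register_parametrization(m, 'weight', Negate())
# leave_parametrized=True reaches the guarded set_() call above; for an ordinary
# parameter it behaves exactly as it did before this change.
remove_parametrizations(m, 'weight', leave_parametrized=True)
assert isinstance(m.weight, nn.Parameter)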
torch/overrides.py
@@ -1108,7 +1108,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         Tensor.__format__: lambda self, format_spec: -1,
         Tensor.__reduce_ex__: lambda self, proto: -1,
         Tensor.__reversed__: lambda self: -1,
-        Tensor.__repr__: lambda self: -1,
+        Tensor.__repr__: lambda self, *, tensor_contents=None: -1,
         Tensor.__setitem__: lambda self, k, v: -1,
         Tensor.__setstate__: lambda self, d: -1,
         Tensor.T.__get__: lambda self: -1,
torch/testing/_internal/common_subclass.py (new file, 219 lines)
@@ -0,0 +1,219 @@
import torch
from copy import deepcopy
from torch.utils._pytree import tree_map

# TODO: Move LoggingTensor here.
from torch.testing._internal.logging_tensor import LoggingTensor


# Base class for wrapper-style tensors.
class WrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, *args, **kwargs):
        t, kwargs = cls.get_wrapper_properties(*args, **kwargs)
        if "size" not in kwargs:
            size = t.size()
        else:
            size = kwargs["size"]
            del kwargs["size"]
        if "dtype" not in kwargs:
            kwargs["dtype"] = t.dtype
        if "layout" not in kwargs:
            kwargs["layout"] = t.layout
        if "device" not in kwargs:
            kwargs["device"] = t.device
        if "requires_grad" not in kwargs:
            kwargs["requires_grad"] = False
        # Ignore memory_format and pin memory for now as I don't know how to
        # safely access them on a Tensor (if possible??)

        wrapper = torch.Tensor._make_wrapper_subclass(cls, size, **kwargs)
        wrapper._validate_methods()
        return wrapper

    @classmethod
    def get_wrapper_properties(cls, *args, **kwargs):
        # Should return both an example Tensor and a dictionary of kwargs
        # to override any of that example Tensor's properties.
        # This is very similar to the `t.new_*(args)` API
        raise NotImplementedError("You need to implement get_wrapper_properties")

    def _validate_methods(self):
        # Skip this if not in debug mode?
        # Changing these on the python side is wrong as it would not be properly reflected
        # on the c++ side
        # This doesn't catch attributes set in the __init__
        forbidden_overrides = ["size", "stride", "dtype", "layout", "device", "requires_grad"]
        for el in forbidden_overrides:
            if getattr(self.__class__, el) is not getattr(torch.Tensor, el):
                raise RuntimeError(f"Subclass {self.__class__.__name__} is overwriting the "
                                   f"property {el} but this is not allowed as such change would "
                                   "not be reflected to c++ callers.")


class DiagTensorBelow(WrapperTensor):
    @classmethod
    def get_wrapper_properties(cls, diag, requires_grad=False):
        assert diag.ndim == 1
        return diag, {"size": diag.size() + diag.size(), "requires_grad": requires_grad}

    def __init__(self, diag, requires_grad=False):
        self.diag = diag

    handled_ops = {}

    # We disable torch function here to avoid any unwanted wrapping of the output
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        # For everything else, call the handler:
        fn = cls.handled_ops.get(func.__name__, None)
        if fn:
            return fn(*args, **kwargs or {})
        else:
            # Note that here, because we don't need to provide the autograd formulas
            # we can have a default "fallback" that creates a plain Tensor based
            # on the diag elements and calls the func again.

            def unwrap(e):
                return e.diag.diag() if isinstance(e, DiagTensorBelow) else e

            def wrap(e):
                if isinstance(e, torch.Tensor) and e.ndim == 1:
                    return DiagTensorBelow(e)
                if isinstance(e, torch.Tensor) and e.ndim == 2 and e.count_nonzero() == e.diag().count_nonzero():
                    return DiagTensorBelow(e.diag())
                return e

            rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
            return rs

    def __repr__(self):
        return super().__repr__(tensor_contents=f"diag={self.diag}")


class SparseTensor(WrapperTensor):
    @classmethod
    def get_wrapper_properties(cls, size, values, indices, requires_grad=False):
        assert values.device == indices.device
        return values, {"size": size, "requires_grad": requires_grad}

    def __init__(self, size, values, indices, requires_grad=False):
        self.values = values
        self.indices = indices

    def __repr__(self):
        return super().__repr__(tensor_contents=f"values={self.values}, indices={self.indices}")

    def sparse_to_dense(self):
        res = torch.zeros(self.size(), dtype=self.values.dtype)
        res[self.indices.unbind(1)] = self.values
        return res

    @staticmethod
    def from_dense(t):
        indices = t.nonzero()
        values = t[indices.unbind(1)]
        return SparseTensor(t.size(), values, indices)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        func_name = f"{func.__module__}.{func.__name__}"

        res = cls._try_call_special_impl(func_name, args, kwargs)
        if res is not NotImplemented:
            return res

        # Otherwise, use a default implementation that constructs dense
        # tensors and uses those to compute values
        def unwrap(e):
            return e.sparse_to_dense() if isinstance(e, SparseTensor) else e

        # Wrap back all Tensors into our custom class
        def wrap(e):
            # Check for zeros and use that to get indices
            return SparseTensor.from_dense(e) if isinstance(e, torch.Tensor) else e

        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
        return rs

    # To show how things happen later
    def __rmul__(self, other):
        return super().__rmul__(other)

    _SPECIAL_IMPLS = {}

    @classmethod
    def _try_call_special_impl(cls, func, args, kwargs):
        if func not in cls._SPECIAL_IMPLS:
            return NotImplemented
        return cls._SPECIAL_IMPLS[func](args, kwargs)


# Example non-wrapper subclass that stores extra state.
class NonWrapperTensor(torch.Tensor):
    def __new__(cls, data):
        t = torch.Tensor._make_subclass(cls, data)
        t.extra_state = {
            'last_func_called': None
        }
        return t

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        result = super().__torch_function__(func, types, args, kwargs)

        if isinstance(result, cls):
            # Do something with the extra state. For the example here, just store the name of the
            # last function called (skip for deepcopy so the copy has the same extra state).
            if func is torch.Tensor.__deepcopy__:
                result.extra_state = deepcopy(args[0].extra_state)
            else:
                result.extra_state = {
                    'last_func_called': func.__name__,
                }

        return result

    # new_empty() must be defined for deepcopy to work
    def new_empty(self, shape):
        return type(self)(torch.empty(shape))


# Class used to store info about subclass tensors used in testing.
class SubclassInfo:

    __slots__ = ['name', 'create_fn', 'closed_under_ops']

    def __init__(self, name, create_fn, closed_under_ops=True):
        self.name = name
        self.create_fn = create_fn  # create_fn(shape) -> tensor instance
        self.closed_under_ops = closed_under_ops


subclass_db = {
    torch.Tensor: SubclassInfo(
        'base_tensor', create_fn=lambda shape: torch.randn(shape)
    ),
    NonWrapperTensor: SubclassInfo(
        'non_wrapper_tensor',
        create_fn=lambda shape: NonWrapperTensor(torch.randn(shape))
    ),
    LoggingTensor: SubclassInfo(
        'logging_tensor',
        create_fn=lambda shape: LoggingTensor(torch.randn(shape))
    ),
    SparseTensor: SubclassInfo(
        'sparse_tensor',
        create_fn=lambda shape: SparseTensor.from_dense(torch.randn(shape).relu())
    ),
    DiagTensorBelow: SubclassInfo(
        'diag_tensor_below',
        create_fn=lambda shape: DiagTensorBelow(torch.randn(shape)),
        closed_under_ops=False  # sparse semantics
    ),
}
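A quick usage sketch for the SparseTensor helper defined above, round-tripping a dense tensor through from_dense() and sparse_to_dense():

import torch
from torch.testing._internal.common_subclass import SparseTensor

dense = torch.randn(3, 3).relu()  # relu() zeroes out negatives so the sparse form is non-trivial
sp = SparseTensor.from_dense(dense)
assert isinstance(sp, SparseTensor)
assert torch.equal(sp.sparse_to_dense(), dense)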
torch/testing/_internal/logging_tensor.py
@@ -57,7 +57,7 @@ class LoggingTensor(torch.Tensor):
         return r
 
     def __repr__(self):
-        return f"{self.__class__.__name__}({self.elem})"
+        return super().__repr__(tensor_contents=f"{self.elem}")
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
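The LoggingTensor change mirrors DiagTensorBelow above: instead of formatting its whole repr itself, it hands the wrapped elem to the base class as tensor_contents, so the standard framing (including Parameter(...) wrapping for parameters) is applied uniformly. A small sketch; the exact output is approximate:

import torch
from torch.testing._internal.logging_tensor import LoggingTensor

lt = LoggingTensor(torch.randn(2))
print(lt)  # e.g. LoggingTensor(tensor([...]))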