Support for tensor subclasses as parameters

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73459

Approved by: https://github.com/ezyang, https://github.com/albanD
Joel Benjamin Schlosser 2022-04-27 10:38:59 -04:00 committed by PyTorch MergeBot
parent 54c75e1e8f
commit bc34cf5fe4
8 changed files with 539 additions and 54 deletions

test/test_subclass.py (new file)

@@ -0,0 +1,210 @@
# Owner(s): ["module: nn"]

import tempfile
import torch

from copy import deepcopy
from functools import partial
from torch import nn
from torch.nn.utils.parametrize import register_parametrization, remove_parametrizations
from torch.nn.modules.lazy import LazyModuleMixin
from torch.testing._internal.common_utils import (
    TestCase, run_tests, parametrize, subtest, instantiate_parametrized_tests)
from torch.testing._internal.common_subclass import subclass_db, DiagTensorBelow
from torch.testing._internal.logging_tensor import LoggingTensor
from unittest import expectedFailure

# The current test methodology in this file is to test a variety of real use cases
# with a set of fully-fledged tensor subclasses. In the future, this may change
# to more narrowly specify toy subclasses for each of the specific invariants under
# test, avoiding the need to maintain the set of fully-fledged tensor subclasses.

# Decorator for parametrizing tests across the various tensor classes.
parametrize_tensor_cls = parametrize("tensor_cls", [
    subtest(tensor_cls, name=info.name) for tensor_cls, info in subclass_db.items()])


class TestSubclass(TestCase):
    def _create_tensor(self, tensor_cls):
        return subclass_db[tensor_cls].create_fn(3)

    @parametrize_tensor_cls
    @parametrize("tensor_requires_grad", [False, True])
    def test_param_invariants(self, tensor_cls, tensor_requires_grad):
        x = self._create_tensor(tensor_cls).requires_grad_(tensor_requires_grad)
        param = nn.Parameter(x, requires_grad=(not tensor_requires_grad))

        self.assertIsInstance(param, nn.Parameter)
        # Ensure requires_grad passed to Parameter's constructor takes precedence.
        self.assertEqual(param.requires_grad, not tensor_requires_grad)

        # Ensure original tensor is not mutated by Parameter construction.
        self.assertNotIsInstance(x, nn.Parameter)
        self.assertEqual(x.requires_grad, tensor_requires_grad)

    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_deepcopy(self, tensor_cls, as_param):
        x = self._create_tensor(tensor_cls)
        if as_param:
            x = nn.Parameter(x)
        x_copy = deepcopy(x)
        self.assertEqual(x, x_copy)
        self.assertEqual(x.__class__, x_copy.__class__)
        self.assertIsNot(x, x_copy)
        self.assertIsInstance(x_copy, tensor_cls)
        if as_param:
            # Deepcopy should preserve both custom type and "parameter-ness".
            self.assertIsInstance(x_copy, nn.Parameter)

    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_serialization(self, tensor_cls, as_param):
        with tempfile.TemporaryFile() as f:
            x = self._create_tensor(tensor_cls)
            if as_param:
                x = nn.Parameter(x)
            torch.save(x, f)
            f.seek(0)
            x_loaded = torch.load(f)

            self.assertEqual(x, x_loaded)
            self.assertIsNot(x, x_loaded)
            self.assertIsInstance(x_loaded, tensor_cls)
            if as_param:
                # Serialization should preserve both custom type and "parameter-ness".
                self.assertIsInstance(x_loaded, nn.Parameter)

    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_repr(self, tensor_cls, as_param):
        x = self._create_tensor(tensor_cls)
        if as_param:
            x = nn.Parameter(x)
        str_repr = x.__repr__()
        if tensor_cls is not torch.Tensor:
            self.assertEqual(str_repr.count(f"{tensor_cls.__name__}("), 1)
        self.assertEqual(str_repr.count("Parameter"), 1 if as_param else 0)

    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_type_propagation(self, tensor_cls, as_param):
        x = self._create_tensor(tensor_cls)
        if as_param:
            x = nn.Parameter(x)

        # Call the add operator to produce an output tensor.
        output = x + self._create_tensor(torch.Tensor)

        # Custom type should be propagated across operations if closed under the op, but
        # "parameter-ness" should not be.
        if subclass_db[tensor_cls].closed_under_ops:
            self.assertIsInstance(output, tensor_cls)
        else:
            self.assertIsInstance(output, torch.Tensor)
        self.assertNotIsInstance(output, nn.Parameter)

    @parametrize_tensor_cls
    def test_module_optimization(self, tensor_cls):
        create_fn = partial(self._create_tensor, tensor_cls)

        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.p1 = nn.Parameter(create_fn())

                self.p_list = nn.ParameterList([create_fn() for _ in range(3)])
                self.p_list.append(create_fn())

                self.p_dict = nn.ParameterDict({
                    'foo': create_fn(),
                    'bar': create_fn(),
                })
                self.p_dict['baz'] = create_fn()

                with torch.no_grad():
                    nn.init.normal_(self.p1)
                    for p in self.p_list:
                        nn.init.uniform_(p)
                    for _, p in self.p_dict.items():
                        nn.init.uniform_(p)

            def forward(self, x):
                out = self.p1 + x
                for p in self.p_list:
                    out = p + out
                for _, v in self.p_dict.items():
                    out = v + out
                return out

        m = MyModule()
        self.assertEqual(len(m.state_dict()), 8)

        optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
        m(create_fn()).sum().backward(torch.tensor(1))
        optimizer.step()

    @parametrize_tensor_cls
    @parametrize("leave_parametrized", [False, True])
    def test_parametrization(self, tensor_cls, leave_parametrized):
        # TODO: Either implement set_() properly for these tensor subclasses or apply a
        # more general fix to avoid the need for special set_() handling. For now, skip
        # testing these as they're expected to fail.
        if tensor_cls in [LoggingTensor, DiagTensorBelow]:
            return

        create_fn = partial(self._create_tensor, tensor_cls)

        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = nn.Parameter(create_fn())

            def forward(self, x):
                return self.weight + x

        class MyParametrization(nn.Module):
            def forward(self, X):
                return -X

        m = MyModule()
        self.assertEqual(len(m.state_dict()), 1)
        register_parametrization(m, 'weight', MyParametrization())
        self.assertIsInstance(m.weight, tensor_cls)
        output = m(self._create_tensor(torch.Tensor))
        self.assertIsInstance(output, tensor_cls)
        remove_parametrizations(m, 'weight', leave_parametrized=leave_parametrized)

    # Lazy modules with custom tensors are not supported yet.
    @expectedFailure
    @parametrize_tensor_cls
    def test_lazy_module(self, tensor_cls):
        if tensor_cls is torch.Tensor:
            self.fail('dummy fail for base tensor until the test passes for subclasses')

        class MyLazyModule(LazyModuleMixin, nn.Module):
            def __init__(self):
                super().__init__()
                self.param = nn.UninitializedParameter()

            def initialize_parameters(self, input) -> None:  # type: ignore[override]
                if self.has_uninitialized_params():
                    with torch.no_grad():
                        self.param.materialize(input.shape)
                        nn.init.uniform_(self.param)

            def forward(self, x):
                return self.param + x

        m = MyLazyModule()
        self.assertTrue(m.has_uninitialized_params())
        output = m(self._create_tensor(tensor_cls))
        self.assertFalse(m.has_uninitialized_params())
        self.assertIsInstance(m.param, tensor_cls)


instantiate_parametrized_tests(TestSubclass)

if __name__ == '__main__':
    run_tests()
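For a concrete sense of the core invariant under test, a minimal sketch; TaggedTensor here is hypothetical (not part of the PR) and mirrors NonWrapperTensor from common_subclass.py further down:

    import torch
    from torch import nn

    # Hypothetical minimal non-wrapper subclass.
    class TaggedTensor(torch.Tensor):
        def __new__(cls, data):
            return torch.Tensor._make_subclass(cls, data)

    x = TaggedTensor(torch.randn(3))
    param = nn.Parameter(x, requires_grad=True)

    assert isinstance(param, nn.Parameter)  # satisfied via the _is_param flag
    assert isinstance(param, TaggedTensor)  # custom type is preserved
    assert param.requires_grad and not x.requires_grad  # original tensor untouched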

torch/_tensor.py

@@ -333,11 +333,12 @@ class Tensor(torch._C._TensorBase):
         # See Note [Don't serialize hooks]
         self.requires_grad, _, self._backward_hooks = state

-    def __repr__(self):
+    def __repr__(self, *, tensor_contents=None):
         if has_torch_function_unary(self):
-            return handle_torch_function(Tensor.__repr__, (self,), self)
+            return handle_torch_function(Tensor.__repr__, (self,), self,
+                                         tensor_contents=tensor_contents)
         # All strings are unicode in Python 3.
-        return torch._tensor_str._str(self)
+        return torch._tensor_str._str(self, tensor_contents=tensor_contents)

     def backward(self, gradient=None, retain_graph=None, create_graph=False, inputs=None):
         r"""Computes the gradient of current tensor w.r.t. graph leaves.

torch/_tensor_str.py

@@ -297,15 +297,24 @@ def get_summarized_data(self):
     else:
         return torch.stack([get_summarized_data(x) for x in self])

-def _str_intern(inp):
-    self, tangent = torch.autograd.forward_ad.unpack_dual(inp)
-    prefix = "nested_tensor(" if self.is_nested else 'tensor('
+def _str_intern(inp, *, tensor_contents=None):
+    is_plain_tensor = type(inp) is torch.Tensor or type(inp) is torch.nn.Parameter
+    if is_plain_tensor:
+        prefix = 'tensor('
+    elif inp.is_nested:
+        prefix = "nested_tensor("
+    else:
+        prefix = f"{type(inp).__name__}("
     indent = len(prefix)
     suffixes = []
+    custom_contents_provided = tensor_contents is not None
+    if custom_contents_provided:
+        tensor_str = tensor_contents

+    # This is used to extract the primal value and thus disable the forward AD
+    # within this function.
+    # TODO(albanD) This needs to be updated when more than one level is supported
+    self, tangent = torch.autograd.forward_ad.unpack_dual(inp)

     # Note [Print tensor device]:
     # A general logic here is we only print device when it doesn't match
@@ -332,40 +341,42 @@ def _str_intern(inp):
         suffixes.append('nnz=' + str(self._nnz()))
         if not has_default_dtype:
             suffixes.append('dtype=' + str(self.dtype))
-        indices_prefix = 'indices=tensor('
-        indices = self._indices().detach()
-        indices_str = _tensor_str(indices, indent + len(indices_prefix))
-        if indices.numel() == 0:
-            indices_str += ', size=' + str(tuple(indices.shape))
-        values_prefix = 'values=tensor('
-        values = self._values().detach()
-        values_str = _tensor_str(values, indent + len(values_prefix))
-        if values.numel() == 0:
-            values_str += ', size=' + str(tuple(values.shape))
-        tensor_str = indices_prefix + indices_str + '),\n' + ' ' * indent + values_prefix + values_str + ')'
+        if not custom_contents_provided:
+            indices_prefix = 'indices=tensor('
+            indices = self._indices().detach()
+            indices_str = _tensor_str(indices, indent + len(indices_prefix))
+            if indices.numel() == 0:
+                indices_str += ', size=' + str(tuple(indices.shape))
+            values_prefix = 'values=tensor('
+            values = self._values().detach()
+            values_str = _tensor_str(values, indent + len(values_prefix))
+            if values.numel() == 0:
+                values_str += ', size=' + str(tuple(values.shape))
+            tensor_str = indices_prefix + indices_str + '),\n' + ' ' * indent + values_prefix + values_str + ')'
     elif self.is_sparse_csr:
         suffixes.append('size=' + str(tuple(self.shape)))
         suffixes.append('nnz=' + str(self._nnz()))
         if not has_default_dtype:
             suffixes.append('dtype=' + str(self.dtype))
-        crow_indices_prefix = 'crow_indices=tensor('
-        crow_indices = self.crow_indices().detach()
-        crow_indices_str = _tensor_str(crow_indices, indent + len(crow_indices_prefix))
-        if crow_indices.numel() == 0:
-            crow_indices_str += ', size=' + str(tuple(crow_indices.shape))
-        col_indices_prefix = 'col_indices=tensor('
-        col_indices = self.col_indices().detach()
-        col_indices_str = _tensor_str(col_indices, indent + len(col_indices_prefix))
-        if col_indices.numel() == 0:
-            col_indices_str += ', size=' + str(tuple(col_indices.shape))
-        values_prefix = 'values=tensor('
-        values = self.values().detach()
-        values_str = _tensor_str(values, indent + len(values_prefix))
-        if values.numel() == 0:
-            values_str += ', size=' + str(tuple(values.shape))
-        tensor_str = crow_indices_prefix + crow_indices_str + '),\n' + ' ' * indent +\
-            col_indices_prefix + col_indices_str + '),\n' + ' ' * indent +\
-            values_prefix + values_str + ')'
+        if not custom_contents_provided:
+            crow_indices_prefix = 'crow_indices=tensor('
+            crow_indices = self.crow_indices().detach()
+            crow_indices_str = _tensor_str(crow_indices, indent + len(crow_indices_prefix))
+            if crow_indices.numel() == 0:
+                crow_indices_str += ', size=' + str(tuple(crow_indices.shape))
+            col_indices_prefix = 'col_indices=tensor('
+            col_indices = self.col_indices().detach()
+            col_indices_str = _tensor_str(col_indices, indent + len(col_indices_prefix))
+            if col_indices.numel() == 0:
+                col_indices_str += ', size=' + str(tuple(col_indices.shape))
+            values_prefix = 'values=tensor('
+            values = self.values().detach()
+            values_str = _tensor_str(values, indent + len(values_prefix))
+            if values.numel() == 0:
+                values_str += ', size=' + str(tuple(values.shape))
+            tensor_str = crow_indices_prefix + crow_indices_str + '),\n' + ' ' * indent +\
+                col_indices_prefix + col_indices_str + '),\n' + ' ' * indent +\
+                values_prefix + values_str + ')'
     elif self.is_quantized:
         suffixes.append('size=' + str(tuple(self.shape)))
         if not has_default_dtype:
@@ -379,12 +390,14 @@ def _str_intern(inp):
             suffixes.append('scale=' + str(self.q_per_channel_scales()))
             suffixes.append('zero_point=' + str(self.q_per_channel_zero_points()))
             suffixes.append('axis=' + str(self.q_per_channel_axis()))
-        tensor_str = _tensor_str(self.dequantize(), indent)
+        if not custom_contents_provided:
+            tensor_str = _tensor_str(self.dequantize(), indent)
     elif self.is_nested:
-        def indented_str(s, indent):
-            return "\n".join(f" {line}" for line in s.split("\n"))
-        strs = ",\n".join(indented_str(str(t), indent + 1) for t in torch.ops.aten.unbind.int(self, 0))
-        tensor_str = f"[\n{strs}\n]"
+        if not custom_contents_provided:
+            def indented_str(s, indent):
+                return "\n".join(f" {line}" for line in s.split("\n"))
+            strs = ",\n".join(indented_str(str(t), indent + 1) for t in torch.ops.aten.unbind.int(self, 0))
+            tensor_str = f"[\n{strs}\n]"
     else:
         if self.is_meta:
             suffixes.append('size=' + str(tuple(self.shape)))
@@ -392,7 +405,8 @@ def _str_intern(inp):
                 suffixes.append('dtype=' + str(self.dtype))
            # TODO: This implies that ellipses is valid syntax for allocating
            # a meta tensor, which it could be, but it isn't right now
-            tensor_str = '...'
+            if not custom_contents_provided:
+                tensor_str = '...'
         else:
             if self.numel() == 0 and not self.is_sparse:
                 # Explicitly print the shape if it is not (0,), to match NumPy behavior
@@ -403,15 +417,17 @@ def _str_intern(inp):
                 # should be int64, so it must be shown explicitly.
                 if self.dtype != torch.get_default_dtype():
                     suffixes.append('dtype=' + str(self.dtype))
-                tensor_str = '[]'
+                if not custom_contents_provided:
+                    tensor_str = '[]'
             else:
                 if not has_default_dtype:
                     suffixes.append('dtype=' + str(self.dtype))
-                if self.layout != torch.strided:
-                    tensor_str = _tensor_str(self.to_dense(), indent)
-                else:
-                    tensor_str = _tensor_str(self, indent)
+                if not custom_contents_provided:
+                    if self.layout != torch.strided:
+                        tensor_str = _tensor_str(self.to_dense(), indent)
+                    else:
+                        tensor_str = _tensor_str(self, indent)

     if self.layout != torch.strided:
         suffixes.append('layout=' + str(self.layout))
@@ -432,8 +448,17 @@ def _str_intern(inp):
     if tangent is not None:
         suffixes.append('tangent={}'.format(tangent))
-    return _add_suffixes(prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse)
+    string_repr = _add_suffixes(prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse)
+
+    # Check if this instance is flagged as a parameter and change the repr accordingly.
+    # Unfortunately, this function has to be aware of this detail.
+    # NB: This is currently skipped for plain tensor parameters to maintain BC. In the future,
+    # this should be done for those as well to produce a valid repr.
+    if isinstance(self, torch.nn.Parameter) and not is_plain_tensor:
+        string_repr = f"Parameter({string_repr})"
+
+    return string_repr

-def _str(self):
+def _str(self, *, tensor_contents=None):
     with torch.no_grad():
-        return _str_intern(self)
+        return _str_intern(self, tensor_contents=tensor_contents)
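Taken together, these `_str_intern` changes make a subclass parameter print with both the subclass prefix and the `Parameter(...)` wrapper. An illustrative sketch using DiagTensorBelow from the new common_subclass.py below (exact output depends on the subclass and values):

    import torch
    from torch import nn
    from torch.testing._internal.common_subclass import DiagTensorBelow

    p = nn.Parameter(DiagTensorBelow(torch.randn(3)))
    print(p)
    # Roughly: Parameter(DiagTensorBelow(diag=tensor([...]), requires_grad=True))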

torch/nn/parameter.py

@@ -3,7 +3,15 @@ from torch._C import _disabled_torch_function_impl
 from collections import OrderedDict

-class Parameter(torch.Tensor):
+# Metaclass to combine _TensorMeta and the instance check override for Parameter.
+class _ParameterMeta(torch._C._TensorMeta):
+    # Make `isinstance(t, Parameter)` return True for custom tensor instances that have the _is_param flag.
+    def __instancecheck__(self, instance):
+        return super().__instancecheck__(instance) or (
+            isinstance(instance, torch.Tensor) and getattr(instance, '_is_param', False))
+
+
+class Parameter(torch.Tensor, metaclass=_ParameterMeta):
     r"""A kind of Tensor that is to be considered a module parameter.

     Parameters are :class:`~torch.Tensor` subclasses, that have a

@@ -23,8 +31,18 @@ class Parameter(torch.Tensor):
     def __new__(cls, data=None, requires_grad=True):
         if data is None:
             data = torch.empty(0)
-        return torch.Tensor._make_subclass(cls, data, requires_grad)
+        if type(data) is torch.Tensor:
+            # For ease of BC maintenance, keep this path for standard Tensor.
+            # Eventually (tm), we should change the behavior for standard Tensor to match.
+            return torch.Tensor._make_subclass(cls, data, requires_grad)
+
+        # Path for custom tensors: set a flag on the instance to indicate parameter-ness.
+        t = data.detach().requires_grad_(requires_grad)
+        t._is_param = True
+        return t
+
+    # Note: the 3 methods below only apply to standard Tensor. Parameters of custom tensor types
+    # are still considered that custom tensor type and these methods will not be called for them.
     def __deepcopy__(self, memo):
         if id(self) in memo:
             return memo[id(self)]
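The split behavior is the crux: a plain Tensor still becomes a Parameter instance via `_make_subclass`, while a custom tensor keeps its own type and only gains the `_is_param` flag, with `_ParameterMeta` making `isinstance` agree. A short sketch using LoggingTensor:

    import torch
    from torch import nn
    from torch.testing._internal.logging_tensor import LoggingTensor

    p = nn.Parameter(LoggingTensor(torch.randn(2)))

    print(type(p).__name__)             # LoggingTensor, not Parameter
    print(isinstance(p, nn.Parameter))  # True, via _ParameterMeta.__instancecheck__
    print(p._is_param)                  # True, set by Parameter.__new__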

torch/nn/utils/parametrize.py

@@ -614,7 +614,19 @@ def remove_parametrizations(
             # We do this so that the parameter does not change the id()
             # This way the user does not need to update the optimizer
             with torch.no_grad():
-                original.set_(t)
+                if type(original) is torch.Tensor:
+                    original.set_(t)
+                else:
+                    try:
+                        original.set_(t)
+                    except RuntimeError as e:
+                        # TODO: Fix this for tensor subclasses that are parameters:
+                        # RuntimeError: set_storage is not allowed on a Tensor created from .data or .detach().
+                        raise RuntimeError("Calling remove_parametrizations() with leave_parametrized=True "
+                                           "for a parameter that is an instance of a tensor subclass requires "
+                                           "set_() to be implemented correctly for the tensor subclass. Either "
+                                           "set leave_parametrized=False or provide a working implementation for "
+                                           "set_() in the tensor subclass.") from e
         else:
             if leave_parametrized:
                 # We cannot use no_grad because we need to know whether one or more
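Concretely, the new error path only triggers with `leave_parametrized=True` on a subclass whose `set_()` does not work; `leave_parametrized=False` sidesteps `set_()` entirely. A hedged sketch (module and parametrization chosen here for illustration; NonWrapperTensor comes from the new common_subclass.py):

    import torch
    from torch import nn
    from torch.nn.utils.parametrize import register_parametrization, remove_parametrizations
    from torch.testing._internal.common_subclass import NonWrapperTensor

    class Negate(nn.Module):
        def forward(self, X):
            return -X

    m = nn.Linear(2, 2)
    m.weight = nn.Parameter(NonWrapperTensor(torch.randn(2, 2)))
    register_parametrization(m, 'weight', Negate())

    # For a subclass without a working set_(), leave_parametrized=True may raise
    # the RuntimeError above; leave_parametrized=False avoids the set_() call.
    remove_parametrizations(m, 'weight', leave_parametrized=False)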

torch/overrides.py

@@ -1108,7 +1108,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         Tensor.__format__: lambda self, format_spec: -1,
         Tensor.__reduce_ex__: lambda self, proto: -1,
         Tensor.__reversed__: lambda self: -1,
-        Tensor.__repr__: lambda self: -1,
+        Tensor.__repr__: lambda self, *, tensor_contents=None: -1,
         Tensor.__setitem__: lambda self, k, v: -1,
         Tensor.__setstate__: lambda self, d: -1,
         Tensor.T.__get__: lambda self: -1,

torch/testing/_internal/common_subclass.py (new file)

@@ -0,0 +1,219 @@
import torch
from copy import deepcopy
from torch.utils._pytree import tree_map

# TODO: Move LoggingTensor here.
from torch.testing._internal.logging_tensor import LoggingTensor


# Base class for wrapper-style tensors.
class WrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, *args, **kwargs):
        t, kwargs = cls.get_wrapper_properties(*args, **kwargs)
        if "size" not in kwargs:
            size = t.size()
        else:
            size = kwargs["size"]
            del kwargs["size"]
        if "dtype" not in kwargs:
            kwargs["dtype"] = t.dtype
        if "layout" not in kwargs:
            kwargs["layout"] = t.layout
        if "device" not in kwargs:
            kwargs["device"] = t.device
        if "requires_grad" not in kwargs:
            kwargs["requires_grad"] = False
        # Ignore memory_format and pin memory for now as I don't know how to
        # safely access them on a Tensor (if possible??)
        wrapper = torch.Tensor._make_wrapper_subclass(cls, size, **kwargs)
        wrapper._validate_methods()
        return wrapper

    @classmethod
    def get_wrapper_properties(cls, *args, **kwargs):
        # Should return both an example Tensor and a dictionary of kwargs
        # to override any of that example Tensor's properties.
        # This is very similar to the `t.new_*(args)` API
        raise NotImplementedError("You need to implement get_wrapper_properties")

    def _validate_methods(self):
        # Skip this if not in debug mode?
        # Changing these on the python side is wrong as it would not be properly reflected
        # on the c++ side
        # This doesn't catch attributes set in the __init__
        forbidden_overrides = ["size", "stride", "dtype", "layout", "device", "requires_grad"]
        for el in forbidden_overrides:
            if getattr(self.__class__, el) is not getattr(torch.Tensor, el):
                raise RuntimeError(f"Subclass {self.__class__.__name__} is overwriting the "
                                   f"property {el} but this is not allowed as such change would "
                                   "not be reflected to c++ callers.")


class DiagTensorBelow(WrapperTensor):
    @classmethod
    def get_wrapper_properties(cls, diag, requires_grad=False):
        assert diag.ndim == 1
        return diag, {"size": diag.size() + diag.size(), "requires_grad": requires_grad}

    def __init__(self, diag, requires_grad=False):
        self.diag = diag

    handled_ops = {}

    # We disable torch function here to avoid any unwanted wrapping of the output
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        # For everything else, call the handler:
        fn = cls.handled_ops.get(func.__name__, None)
        if fn:
            return fn(*args, **kwargs or {})
        else:
            # Note that here, because we don't need to provide the autograd formulas
            # we can have a default "fallback" that creates a plain Tensor based
            # on the diag elements and calls the func again.

            def unwrap(e):
                return e.diag.diag() if isinstance(e, DiagTensorBelow) else e

            def wrap(e):
                if isinstance(e, torch.Tensor) and e.ndim == 1:
                    return DiagTensorBelow(e)
                if isinstance(e, torch.Tensor) and e.ndim == 2 and e.count_nonzero() == e.diag().count_nonzero():
                    return DiagTensorBelow(e.diag())
                return e

            rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
            return rs

    def __repr__(self):
        return super().__repr__(tensor_contents=f"diag={self.diag}")


class SparseTensor(WrapperTensor):
    @classmethod
    def get_wrapper_properties(cls, size, values, indices, requires_grad=False):
        assert values.device == indices.device
        return values, {"size": size, "requires_grad": requires_grad}

    def __init__(self, size, values, indices, requires_grad=False):
        self.values = values
        self.indices = indices

    def __repr__(self):
        return super().__repr__(tensor_contents=f"values={self.values}, indices={self.indices}")

    def sparse_to_dense(self):
        res = torch.zeros(self.size(), dtype=self.values.dtype)
        res[self.indices.unbind(1)] = self.values
        return res

    @staticmethod
    def from_dense(t):
        indices = t.nonzero()
        values = t[indices.unbind(1)]
        return SparseTensor(t.size(), values, indices)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        func_name = f"{func.__module__}.{func.__name__}"

        res = cls._try_call_special_impl(func_name, args, kwargs)
        if res is not NotImplemented:
            return res

        # Otherwise, use a default implementation that constructs dense
        # tensors and uses those to compute values
        def unwrap(e):
            return e.sparse_to_dense() if isinstance(e, SparseTensor) else e

        # Wrap back all Tensors into our custom class
        def wrap(e):
            # Check for zeros and use that to get indices
            return SparseTensor.from_dense(e) if isinstance(e, torch.Tensor) else e

        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
        return rs

    # To show how things happen later
    def __rmul__(self, other):
        return super().__rmul__(other)

    _SPECIAL_IMPLS = {}

    @classmethod
    def _try_call_special_impl(cls, func, args, kwargs):
        if func not in cls._SPECIAL_IMPLS:
            return NotImplemented
        return cls._SPECIAL_IMPLS[func](args, kwargs)


# Example non-wrapper subclass that stores extra state.
class NonWrapperTensor(torch.Tensor):
    def __new__(cls, data):
        t = torch.Tensor._make_subclass(cls, data)
        t.extra_state = {
            'last_func_called': None
        }
        return t

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        result = super().__torch_function__(func, types, args, kwargs)

        if isinstance(result, cls):
            # Do something with the extra state. For the example here, just store the name of the
            # last function called (skip for deepcopy so the copy has the same extra state).
            if func is torch.Tensor.__deepcopy__:
                result.extra_state = deepcopy(args[0].extra_state)
            else:
                result.extra_state = {
                    'last_func_called': func.__name__,
                }

        return result

    # new_empty() must be defined for deepcopy to work
    def new_empty(self, shape):
        return type(self)(torch.empty(shape))


# Class used to store info about subclass tensors used in testing.
class SubclassInfo:

    __slots__ = ['name', 'create_fn', 'closed_under_ops']

    def __init__(self, name, create_fn, closed_under_ops=True):
        self.name = name
        self.create_fn = create_fn  # create_fn(shape) -> tensor instance
        self.closed_under_ops = closed_under_ops


subclass_db = {
    torch.Tensor: SubclassInfo(
        'base_tensor', create_fn=lambda shape: torch.randn(shape)
    ),
    NonWrapperTensor: SubclassInfo(
        'non_wrapper_tensor',
        create_fn=lambda shape: NonWrapperTensor(torch.randn(shape))
    ),
    LoggingTensor: SubclassInfo(
        'logging_tensor',
        create_fn=lambda shape: LoggingTensor(torch.randn(shape))
    ),
    SparseTensor: SubclassInfo(
        'sparse_tensor',
        create_fn=lambda shape: SparseTensor.from_dense(torch.randn(shape).relu())
    ),
    DiagTensorBelow: SubclassInfo(
        'diag_tensor_below',
        create_fn=lambda shape: DiagTensorBelow(torch.randn(shape)),
        closed_under_ops=False  # sparse semantics
    ),
}
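How these helpers compose in practice (illustrative; values differ per run):

    import torch
    from torch.testing._internal.common_subclass import subclass_db, SparseTensor

    # Build an instance through the db, as test_subclass.py does.
    s = subclass_db[SparseTensor].create_fn(3)
    print(type(s).__name__)  # SparseTensor

    # Default dispatch: unwrap to dense, run the op, wrap the result back.
    out = s + torch.randn(3)
    print(isinstance(out, SparseTensor))  # True; SparseTensor is closed under ops

    # DiagTensorBelow, by contrast, is registered with closed_under_ops=False
    # (sparse semantics), so outputs are not guaranteed to keep the subclass type.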

torch/testing/_internal/logging_tensor.py

@@ -57,7 +57,7 @@ class LoggingTensor(torch.Tensor):
         return r

     def __repr__(self):
-        return f"{self.__class__.__name__}({self.elem})"
+        return super().__repr__(tensor_contents=f"{self.elem}")

     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
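For the simple case the printed string is unchanged, but LoggingTensor now builds it through the base `__repr__`, so it inherits the standard suffix handling and the `Parameter(...)` wrapping for free; illustrative:

    import torch
    from torch.testing._internal.logging_tensor import LoggingTensor

    t = LoggingTensor(torch.ones(2))
    print(t)
    # LoggingTensor(tensor([1., 1.])), as before; dtype/device suffixes and
    # Parameter(...) wrapping are now also applied when relevant.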