pytorch/torch/nn/parameter.py
Shenxiu Liu ec714e33a3 [PT] Allow deepcopy of uninitialized parameter (#83809)
Summary: `UninitializedParameter` overrides the `__new__` method, so the parent class's `__deepcopy__` no longer works, and models using LazyModule cannot be instantiated through `copy.deepcopy`.
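
The failure mode is an argument-signature mismatch: the parent's `__deepcopy__` rebuilds via `type(self)(data, requires_grad)`, but the subclass's constructor takes `(requires_grad, device, dtype)`. A stdlib-only sketch of the mechanism (class names and attributes are illustrative stand-ins, not the real torch classes):

```python
import copy

class Param:
    def __init__(self, data=None, requires_grad=True):
        self.data = [] if data is None else data
        self.requires_grad = requires_grad

    def __deepcopy__(self, memo):
        # The parent assumes reconstruction takes (data, requires_grad).
        result = type(self)(copy.deepcopy(self.data, memo), self.requires_grad)
        memo[id(self)] = result
        return result

class UninitParam(Param):
    # The subclass's constructor has a different signature: no data argument.
    def __init__(self, requires_grad=True, device=None, dtype=None):
        super().__init__(None, requires_grad)
        self.device, self.dtype = device, dtype

# The inherited __deepcopy__ binds data where requires_grad is expected and
# requires_grad where device is expected -- the analogue of the
# "device=bool" TypeError raised by torch.empty() in the real code.
bad = copy.deepcopy(UninitParam())
print(bad.requires_grad, bad.device)  # [] True
```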

Test Plan:
Ran a locally copied lazy module example.

After change:
```
shenxiu@devbig1109:fbcode  (5c57dd833)$ bento console --kernel pytorch --local
Python 3.8.6 (default, Jun 10 2022, 04:32:13)
Type 'copyright', 'credits' or 'license' for more information
IPython 7.21.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import copy
   ...: import torch
   ...:
   ...: class LazyModule(torch.nn.Module):
   ...:     def __init__(self):
   ...:         super().__init__()
   ...:         self.m = torch.nn.LazyLinear(10)
   ...:
   ...:     def forward(self, input):
   ...:         x = self.m(input)
   ...:         return x
   ...:
   ...: m = LazyModule()
   ...: print(m.state_dict())
   ...: copy.deepcopy(m)
OrderedDict([('m.weight', <UninitializedParameter>), ('m.bias', <UninitializedParameter>)])

In [2]: copy.deepcopy(m)
Out[2]:
LazyModule(
  (m): LazyLinear(in_features=0, out_features=10, bias=True)
)
```

Before the change, the above code gave:
```
TypeError: empty() received an invalid combination of arguments - got (int, dtype=NoneType, device=bool), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.memory_format memory_format, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, torch.memory_format memory_format, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of SymInts size, *, torch.memory_format memory_format, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)

```
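
The patch fixes this by giving `UninitializedParameter` its own `__deepcopy__` that rebuilds with its own constructor signature. A stdlib-only sketch of that pattern (class name and attributes are illustrative, not the real torch classes):

```python
import copy

class UninitParam:
    # Constructor signature mirrors UninitializedParameter.__new__:
    # (requires_grad, device, dtype), with no data argument.
    def __init__(self, requires_grad=True, device=None, dtype=None):
        self.requires_grad = requires_grad
        self.device = device
        self.dtype = dtype

    def __deepcopy__(self, memo):
        # Same shape as the patch: honor the memo, then reconstruct
        # using the subclass's own constructor arguments.
        if id(self) in memo:
            return memo[id(self)]
        result = type(self)(self.requires_grad, self.device, self.dtype)
        memo[id(self)] = result
        return result

p = UninitParam(requires_grad=False, device="cpu")
q = copy.deepcopy(p)
print(q.requires_grad, q.device)  # False cpu
```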

Cloned n2369721 locally and it ran successfully (through the console, not the notebook, because the bento notebook doesn't work well with buck2).

Reviewed By: avilay

Differential Revision: D38866072

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83809
Approved by: https://github.com/ngimel
2022-08-30 05:16:19 +00:00


import torch
from torch._C import _disabled_torch_function_impl
from collections import OrderedDict


# Metaclass to combine _TensorMeta and the instance check override for Parameter.
class _ParameterMeta(torch._C._TensorMeta):
    # Make `isinstance(t, Parameter)` return True for custom tensor instances that have the _is_param flag.
    def __instancecheck__(self, instance):
        return super().__instancecheck__(instance) or (
            isinstance(instance, torch.Tensor) and getattr(instance, '_is_param', False))
class Parameter(torch.Tensor, metaclass=_ParameterMeta):
    r"""A kind of Tensor that is to be considered a module parameter.

    Parameters are :class:`~torch.Tensor` subclasses, that have a
    very special property when used with :class:`Module` s - when they're
    assigned as Module attributes they are automatically added to the list of
    its parameters, and will appear e.g. in :meth:`~Module.parameters` iterator.
    Assigning a Tensor doesn't have such effect. This is because one might
    want to cache some temporary state, like last hidden state of the RNN, in
    the model. If there was no such class as :class:`Parameter`, these
    temporaries would get registered too.

    Args:
        data (Tensor): parameter tensor.
        requires_grad (bool, optional): if the parameter requires gradient. See
            :ref:`locally-disable-grad-doc` for more details. Default: `True`
    """
    def __new__(cls, data=None, requires_grad=True):
        if data is None:
            data = torch.empty(0)
        if type(data) is torch.Tensor or type(data) is Parameter:
            # For ease of BC maintenance, keep this path for standard Tensor.
            # Eventually (tm), we should change the behavior for standard Tensor to match.
            return torch.Tensor._make_subclass(cls, data, requires_grad)

        # Path for custom tensors: set a flag on the instance to indicate parameter-ness.
        t = data.detach().requires_grad_(requires_grad)
        if type(t) is not type(data):
            raise RuntimeError(f"Creating a Parameter from an instance of type {type(data).__name__} "
                               "requires that detach() returns an instance of the same type, but return "
                               f"type {type(t).__name__} was found instead. To use the type as a "
                               "Parameter, please correct the detach() semantics defined by "
                               "its __torch_dispatch__() implementation.")
        t._is_param = True
        return t

    # Note: the 3 methods below only apply to standard Tensor. Parameters of custom tensor types
    # are still considered that custom tensor type and these methods will not be called for them.
    def __deepcopy__(self, memo):
        if id(self) in memo:
            return memo[id(self)]
        else:
            result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad)
            memo[id(self)] = result
            return result

    def __repr__(self):
        return 'Parameter containing:\n' + super(Parameter, self).__repr__()

    def __reduce_ex__(self, proto):
        # See Note [Don't serialize hooks]
        return (
            torch._utils._rebuild_parameter,
            (self.data, self.requires_grad, OrderedDict())
        )

    __torch_function__ = _disabled_torch_function_impl
class UninitializedTensorMixin:
    _allowed_methods = [
        torch.Tensor.__hash__,
        torch.Tensor.size,
        torch.Tensor.copy_,
        torch.Tensor.is_floating_point,
        torch.Tensor.half,
        torch.Tensor.float,
        torch.Tensor.double,
        torch.Tensor.char,
        torch.Tensor.short,
        torch.Tensor.int,
        torch.Tensor.long,
        torch.Tensor.cuda,
        torch.Tensor.cpu,
        torch.Tensor.to,
        torch.Tensor.get_device,
        torch._has_compatible_shallow_copy_type,
    ]

    def materialize(self, shape, device=None, dtype=None):
        r"""Create a Parameter or Tensor with the same properties of the uninitialized one.
        Given a shape, it materializes a parameter in the same device
        and with the same `dtype` as the current one or the specified ones in the
        arguments.

        Args:
            shape (tuple): the shape for the materialized tensor.
            device (:class:`torch.device`): the desired device of the parameters
                and buffers in this module. Optional.
            dtype (:class:`torch.dtype`): the desired floating point type of
                the floating point parameters and buffers in this module. Optional.
        """
        if device is None:
            device = self.data.device
        if dtype is None:
            dtype = self.data.dtype
        self.data = torch.empty(shape, device=device, dtype=dtype)
        self.__class__ = self.cls_to_become

    @property
    def shape(self):
        raise RuntimeError(
            'Can\'t access the shape of an uninitialized parameter or buffer. '
            'This error usually happens in `load_state_dict` when trying to load '
            'an uninitialized parameter into an initialized one. '
            'Call `forward` to initialize the parameters before accessing their attributes.')

    def share_memory_(self):
        raise RuntimeError(
            'Can\'t share memory on an uninitialized parameter or buffer. '
            'Call `forward` to initialize the parameters before calling '
            '`module.share_memory()`.')

    def __repr__(self):
        return f'<{self.__class__.__name__}>'

    def __reduce_ex__(self, proto):
        # See Note [Don't serialize hooks]
        return (
            self.__class__,
            (self.requires_grad,)
        )

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # method-wrapper is to detect access to Tensor properties that are
        # wrapped in descriptors
        if func in cls._allowed_methods or func.__class__.__name__ == 'method-wrapper':
            if kwargs is None:
                kwargs = {}
            return super().__torch_function__(func, types, args, kwargs)
        raise ValueError(
            'Attempted to use an uninitialized parameter in {}. '
            'This error happens when you are using a `LazyModule` or '
            'explicitly manipulating `torch.nn.parameter.{}` '
            'objects. When using LazyModules, call `forward` with a dummy batch '
            'to initialize the parameters before calling torch functions'.format(func, cls.__name__))
def is_lazy(param):
    return isinstance(param, UninitializedTensorMixin)
class UninitializedParameter(UninitializedTensorMixin, Parameter):
    r"""A parameter that is not initialized.

    Uninitialized Parameters are a special case of :class:`torch.nn.Parameter`
    where the shape of the data is still unknown.

    Unlike a :class:`torch.nn.Parameter`, uninitialized parameters
    hold no data and attempting to access some properties, like their shape,
    will throw a runtime error. The only operations that can be performed on an uninitialized
    parameter are changing its datatype, moving it to a different device and
    converting it to a regular :class:`torch.nn.Parameter`.

    The default device or dtype to use when the parameter is materialized can be set
    during construction using e.g. ``device='cuda'``.
    """

    cls_to_become = Parameter

    def __new__(cls, requires_grad=True, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        data = torch.empty(0, **factory_kwargs)
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def __deepcopy__(self, memo):
        if id(self) in memo:
            return memo[id(self)]
        else:
            result = type(self)(self.requires_grad, self.data.device, self.data.dtype)
            memo[id(self)] = result
            return result
class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor):
    r"""A buffer that is not initialized.

    Uninitialized Buffers are a special case of :class:`torch.Tensor`
    where the shape of the data is still unknown.

    Unlike a :class:`torch.Tensor`, uninitialized buffers
    hold no data and attempting to access some properties, like their shape,
    will throw a runtime error. The only operations that can be performed on an uninitialized
    buffer are changing its datatype, moving it to a different device and
    converting it to a regular :class:`torch.Tensor`.

    The default device or dtype to use when the buffer is materialized can be set
    during construction using e.g. ``device='cuda'``.
    """

    cls_to_become = torch.Tensor

    def __new__(cls, requires_grad=False, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        data = torch.empty(0, **factory_kwargs)
        return torch.Tensor._make_subclass(cls, data, requires_grad)