mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add assign argument to torch.Tensor.module_load (#121158)
Make `torch.__future__.get_swap_module_params_on_conversion() == True` account for the `assign` argument to `nn.Module.load_state_dict`. As in the case where swapping is disabled (`torch.__future__.set_swap_module_params_on_conversion(False)`), `assign=True` means that we do not incur a `self.copy_(other)` and the properties of `other` are preserved.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121158
Approved by: https://github.com/albanD
ghstack dependencies: #121157
This commit is contained in:
parent
27389e03f0
commit
4b3903379a
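For orientation, a minimal illustrative sketch of the behavior this commit wires up (the modules and state dict below are made up for the example, not taken from the diff): with the swap-tensors future enabled, the `assign` argument of `load_state_dict` is now forwarded to `Tensor.module_load`, so `assign=True` keeps the properties of the incoming state-dict tensors instead of copying them into the existing parameters.

import torch
import torch.nn as nn

# Hypothetical usage sketch; nn.Linear stands in for any module.
torch.__future__.set_swap_module_params_on_conversion(True)

sd = nn.Linear(2, 3).to(torch.float64).state_dict()

net = nn.Linear(2, 3)
net.load_state_dict(sd, assign=True)   # no self.copy_(other): the properties of sd win
print(net.weight.dtype)                # torch.float64

net = nn.Linear(2, 3)
net.load_state_dict(sd)                # assign=False: copy into existing params, their properties win
print(net.weight.dtype)                # torch.float32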
@@ -261,6 +261,7 @@ class TestLoadStateDict(NNTestCase):
             def forward(self, input):
                 return self.x + self.bn(self.fc1(input))
 
+        swap = torch.__future__.get_swap_module_params_on_conversion()
         net = MyModule()
         state_dict = net.state_dict(keep_vars=keep_vars)
         for v in state_dict.values():
@@ -276,16 +277,21 @@ class TestLoadStateDict(NNTestCase):
         net_meta_state_dict = net_meta.state_dict(keep_vars=True)
         for key in state_dict.keys():
             if key in net_meta._parameters:
-                self.assertEqual(net_meta_state_dict_old[key].requires_grad, net_meta_state_dict[key].requires_grad)
-                if keep_vars:
+                if keep_vars and not swap:
                     # state_dict[key] is an nn.Parameter
                     self.assertTrue(state_dict[key] is net_meta_state_dict[key])
                 else:
-                    # state_dict[key] is not an nn.Parameter so it will be detached when wrapping with a Parameter
-                    self.assertTrue(net_meta_state_dict[key] is not net_meta_state_dict_old[key])
-                self.assertEqual(state_dict[key], net_meta_state_dict[key])
+                    if swap:
+                        self.assertTrue(net_meta_state_dict[key] is net_meta_state_dict_old[key])
+                    else:
+                        # state_dict[key] is not an nn.Parameter so it will be detached when wrapping with a Parameter
+                        self.assertTrue(net_meta_state_dict[key] is not net_meta_state_dict_old[key])
+                        self.assertEqual(net_meta_state_dict_old[key].requires_grad, net_meta_state_dict[key].requires_grad)
+                self.assertEqual(net_meta_state_dict_old[key].requires_grad, net_meta_state_dict[key].requires_grad)
+                self.assertEqual(state_dict[key], net_meta_state_dict[key])
             elif key in net_meta._buffers and key not in net_meta._non_persistent_buffers_set:
                 self.assertTrue(state_dict[key] is net_meta_state_dict[key])
                 self.assertEqual(state_dict[key], net_meta_state_dict[key])
+
         # Make sure that ordering of parameters and buffers is preserved
         net_named_parameters = net.named_parameters()
@@ -391,19 +397,32 @@ class TestLoadStateDict(NNTestCase):
 def load_torch_function_handler(cls, func, types, args=(), kwargs=None):
     kwargs = {} if kwargs is None else kwargs
 
-    def module_load(dest, src):
-        # always convert src to cls
+    def module_load(dest, src, assign=False):
         if isinstance(dest, cls):
-            if type(src) is torch.Tensor:
-                return cls(src)
-            elif type(src) is cls:
+            if assign:
+                return src.detach()
+            else:
+                if type(src) is torch.Tensor:
+                    return cls(src)
+                elif type(src) is cls:
+                    return src.detach()
+                else:
+                    if isinstance(src, MyWrapperLoadTensor):
+                        return cls(src._data)
+                    return cls(src)
+        else:
+            assert isinstance(src, cls), f"Expected isinstance(src, {cls}) but got {type(src)}"
+            assert type(dest) == torch.Tensor or type(dest) == torch.nn.Parameter or issubclass(cls, type(dest))
+            if assign:
                 return src.detach()
             else:
                 if isinstance(src, MyWrapperLoadTensor):
-                    return cls(src._data)
-                return cls(src)
-        else:
-            return src.detach()
+                    if type(dest) not in {torch.Tensor, torch.nn.Parameter}:
+                        return type(dest)(src._data)
+                    else:
+                        return src._data.detach()
+                else:
+                    return torch.Tensor(src)
 
     if func is torch.Tensor.module_load:
         return module_load(*args, **kwargs)
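For illustration, a self-contained sketch of how a tensor subclass can intercept `Tensor.module_load` through `__torch_function__` and honor the new `assign` keyword. The `MyTensor` class below is hypothetical and much simpler than the test handler above:

import torch

class MyTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        if func is torch.Tensor.module_load:
            dest, src = args
            if kwargs.get("assign", False):
                # assign=True: hand back the state-dict tensor untouched (detached)
                return src.detach()
            # assign=False: convert the incoming tensor to the destination's subclass
            return cls(src)
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

t = MyTensor(torch.randn(3))
print(type(t.module_load(torch.zeros(3))))               # MyTensor
print(type(t.module_load(torch.zeros(3), assign=True)))  # torch.Tensor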
@@ -478,7 +497,8 @@ class TestLoadStateDictSwap(TestCase):
     @skipIfCrossRef
     @skipIfTorchDynamo("Can't swap with dynamo as dynamo installs weakrefs")
     @swap([True])
-    def test_swap_subclass(self):
+    @parametrize("assign", [True, False])
+    def test_swap_subclass(self, assign):
 
         def _create_model(subclass=None):
             m = torch.nn.Linear(2, 3, bias=False)
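For context, `@parametrize` comes from `torch.testing._internal.common_utils` and, together with `instantiate_parametrized_tests`, expands the test over each value of `assign`. A small illustrative sketch (the class and test names below are made up):

from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

class ExampleTests(TestCase):
    @parametrize("assign", [True, False])
    def test_demo(self, assign):
        self.assertIsInstance(assign, bool)

# Generates roughly test_demo_assign_True and test_demo_assign_False
instantiate_parametrized_tests(ExampleTests)

if __name__ == "__main__":
    run_tests()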
@@ -491,24 +511,20 @@ class TestLoadStateDictSwap(TestCase):
         def _test(m_subclass=None, sd_subclass=None):
             m = _create_model(m_subclass)
             sd = _create_model(sd_subclass).state_dict()
-            m.load_state_dict(sd)
+            m.load_state_dict(sd, assign=assign)
             self.assertEqual(m.weight, sd['weight'])
             self.assertEqual(m.buf, sd['buf'])
             self.assertTrue(isinstance(m.weight, torch.nn.Parameter))
             self.assertTrue(not isinstance(m.buf, torch.nn.Parameter))
 
             weight_type, buf_type = (torch.nn.Parameter, torch.Tensor)
-            if m_subclass is not None and sd_subclass is not None:
-                # handler of subclass takes precedence over superclass
-                if issubclass(sd_subclass, m_subclass):
+            if assign:
+                if sd_subclass is not None:
                     weight_type, buf_type = (sd_subclass, sd_subclass)
-                else:
+            else:
+                if m_subclass is not None:
                     weight_type, buf_type = (m_subclass, m_subclass)
-            elif m_subclass is not None:
-                weight_type, buf_type = (m_subclass, m_subclass)
-            elif sd_subclass is not None:
-                weight_type, buf_type = (sd_subclass, sd_subclass)
 
             self.assertTrue(type(m.weight) is weight_type)
             self.assertTrue(type(m.buf) is buf_type)
 
@@ -711,7 +711,7 @@ class Tensor(torch._C.TensorBase):
         self._typed_storage()._share_memory_()
         return self
 
-    def module_load(self, other):
+    def module_load(self, other, assign=False):
         r"""Defines how to transform ``other`` when loading it into ``self`` in :meth:`~nn.Module.load_state_dict`.
 
         Used when :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.
@@ -723,16 +723,23 @@ class Tensor(torch._C.TensorBase):
 
         .. note::
             This method should always return a new object that is not ``self`` or ``other``.
-            For example, the default implementation returns ``self.copy_(other).detach()``.
+            For example, the default implementation returns ``self.copy_(other).detach()``
+            if ``assign`` is ``False`` or ``other.detach()`` if ``assign`` is ``True``.
 
         Args:
             other (Tensor): value in state dict with key corresponding to ``self``
+            assign (bool): the assign argument passed to :meth:`nn.Module.load_state_dict`
 
         """
         if has_torch_function_variadic(self, other):
-            return handle_torch_function(Tensor.module_load, (self, other), self, other)
-        # In the default case, swap_tensors becomes a no-op
-        return self.copy_(other).detach()
+            return handle_torch_function(
+                Tensor.module_load, (self, other), self, other, assign=assign
+            )
+
+        if assign:
+            return other.detach()
+        else:
+            return self.copy_(other).detach()
 
     def __reversed__(self):
         r"""Reverses the tensor along dimension 0."""
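A minimal sketch of the default behavior the updated docstring describes (standalone example, not part of the diff):

import torch

param = torch.nn.Parameter(torch.zeros(3))
loaded = torch.full((3,), 2.0, dtype=torch.float64)

with torch.no_grad():  # load_state_dict calls module_load under no_grad
    # assign=False (default): values are copied into param, so the result keeps param's properties
    print(param.module_load(loaded).dtype)               # torch.float32
    # assign=True: the result is loaded.detach(), keeping loaded's properties
    print(param.module_load(loaded, assign=True).dtype)  # torch.float64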
@@ -2046,7 +2046,19 @@ class Module:
 
                 try:
                     with torch.no_grad():
-                        if assign_to_params_buffers:
+                        if use_swap_tensors:
+                            new_input_param = param.module_load(input_param, assign=assign_to_params_buffers)
+                            if id(new_input_param) == id(input_param) or id(new_input_param) == id(param):
+                                raise RuntimeError("module_load returned one of self or other, please .detach() "
+                                                   "the result if returning one of the inputs in module_load")
+                            if (isinstance(param, torch.nn.Parameter)):
+                                if not isinstance(new_input_param, torch.nn.Parameter):
+                                    new_input_param = torch.nn.Parameter(new_input_param, requires_grad=param.requires_grad)
+                                else:
+                                    new_input_param.requires_grad_(param.requires_grad)
+                            torch.utils.swap_tensors(param, new_input_param)
+                            del new_input_param
+                        elif assign_to_params_buffers:
                             # Shape checks are already done above
                             if (isinstance(param, torch.nn.Parameter)):
                                 if not isinstance(input_param, torch.nn.Parameter):
@@ -2054,17 +2066,6 @@ class Module:
                                 else:
                                     input_param.requires_grad_(param.requires_grad)
                                 setattr(self, name, input_param)
-                        elif use_swap_tensors:
-                            param_requires_grad = param.requires_grad
-                            new_input_param = param.module_load(input_param)
-                            if id(new_input_param) == id(input_param) or id(new_input_param) == id(param):
-                                raise RuntimeError("module_load returned one of self or other, please .detach() "
-                                                   "the result if returning one of the inputs in module_load")
-                            if (isinstance(param, torch.nn.Parameter) and
-                                    not isinstance(new_input_param, torch.nn.Parameter)):
-                                new_input_param = torch.nn.Parameter(new_input_param, requires_grad=param_requires_grad)
-                            torch.utils.swap_tensors(param, new_input_param)
-                            del new_input_param
                         else:
                             param.copy_(input_param)
                 except Exception as ex:
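For reference, the swap branch above ultimately goes through `torch.utils.swap_tensors`, which exchanges the two tensor objects in place so that existing references (for example, ones held by an optimizer) keep pointing at the live parameter. A minimal sketch:

import torch
from torch.utils import swap_tensors

a = torch.ones(3)
b = torch.zeros(3)
swap_tensors(a, b)  # the two tensor objects exchange their contents in place
print(a)  # tensor([0., 0., 0.])
print(b)  # tensor([1., 1., 1.])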
@@ -2104,7 +2105,8 @@ class Module:
 
         .. warning::
             If :attr:`assign` is ``True`` the optimizer must be created after
-            the call to :attr:`load_state_dict`.
+            the call to :attr:`load_state_dict` unless
+            :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.
 
         Args:
             state_dict (dict): a dict containing parameters and
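The extra clause matters because, without swapping, `assign=True` replaces the parameter objects, so an optimizer built beforehand keeps references to the old tensors. An illustrative sketch:

import torch

m = torch.nn.Linear(2, 2)
opt = torch.optim.SGD(m.parameters(), lr=0.1)     # holds the current Parameter objects

m.load_state_dict(torch.nn.Linear(2, 2).state_dict(), assign=True)

# Without torch.__future__.set_swap_module_params_on_conversion(True), m now holds
# brand-new Parameters, and opt would still update the old ones.
opt = torch.optim.SGD(m.parameters(), lr=0.1)     # so (re)create the optimizer afterwards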
@@ -2112,12 +2114,11 @@ class Module:
             strict (bool, optional): whether to strictly enforce that the keys
                 in :attr:`state_dict` match the keys returned by this module's
                 :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
-            assign (bool, optional): whether to assign items in the state
-                dictionary to their corresponding keys in the module instead
-                of copying them inplace into the module's current parameters and buffers.
-                When ``False``, the properties of the tensors in the current
-                module are preserved while when ``True``, the properties of the
-                Tensors in the state dict are preserved.
+            assign (bool, optional): When ``False``, the properties of the tensors
+                in the current module are preserved while when ``True``, the
+                properties of the Tensors in the state dict are preserved. The only
+                exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s
+                for which the value from the module is preserved.
                 Default: ``False``
 
         Returns:
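A short sketch of the ``requires_grad`` exception called out above (illustrative values):

import torch

m = torch.nn.Linear(2, 2)
sd = torch.nn.Linear(2, 2).state_dict()  # plain detached tensors: requires_grad == False

m.load_state_dict(sd, assign=True)
# dtype, device, etc. follow the state-dict tensors, but requires_grad
# stays as set on the module's existing Parameters:
print(m.weight.requires_grad)  # True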
@@ -1367,7 +1367,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         Tensor.map_: lambda self, tensor, callable: -1,
         Tensor.map2_: lambda self, x, y, callable: -1,
         Tensor.mm: lambda self, mat2: -1,
-        Tensor.module_load: lambda self, other: -1,
+        Tensor.module_load: lambda self, other, assign=False: -1,
         Tensor.narrow_copy: lambda self, dimension, start, length: -1,
         Tensor.ndimension: lambda self: -1,
         Tensor.nelement: lambda self: -1,