Add assign argument to torch.Tensor.module_load (#121158)

Make the `torch.__future__.get_swap_module_params_on_conversion() == True` path account for the `assign` argument to `nn.Module.load_state_dict`.

Similar to the behavior when `torch.__future__.get_swap_module_params_on_conversion()` is `False`, `assign=True` means that we do not incur a `self.copy_(other)`, so the properties of `other` are preserved.
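
For example (an illustrative sketch that assumes a build with this change; the `float64` cast is only there to make visible whose properties win):

```python
import torch

torch.__future__.set_swap_module_params_on_conversion(True)

m = torch.nn.Linear(2, 3)
opt = torch.optim.SGD(m.parameters(), lr=0.1)  # holds references to the existing Parameters

sd = {k: v.to(torch.float64) for k, v in torch.nn.Linear(2, 3).state_dict().items()}

# assign=True: no self.copy_(other), so the state dict tensors' properties
# (float64 here) are taken over; the Parameters are swapped in place, so the
# optimizer's references remain valid.
m.load_state_dict(sd, assign=True)
assert m.weight.dtype == torch.float64
assert opt.param_groups[0]["params"][0] is m.weight
```

This is also why the optimizer warning in the `load_state_dict` docs is relaxed below: with swapping enabled, the optimizer can be created before the call.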

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121158
Approved by: https://github.com/albanD
ghstack dependencies: #121157
Mikayla Gawarecki 2024-03-05 14:23:54 -08:00 committed by PyTorch MergeBot
parent 27389e03f0
commit 4b3903379a
4 changed files with 74 additions and 50 deletions

test/nn/test_load_state_dict.py

@@ -261,6 +261,7 @@ class TestLoadStateDict(NNTestCase):
             def forward(self, input):
                 return self.x + self.bn(self.fc1(input))
 
+        swap = torch.__future__.get_swap_module_params_on_conversion()
         net = MyModule()
         state_dict = net.state_dict(keep_vars=keep_vars)
         for v in state_dict.values():
@@ -276,16 +277,21 @@ class TestLoadStateDict(NNTestCase):
         net_meta_state_dict = net_meta.state_dict(keep_vars=True)
         for key in state_dict.keys():
             if key in net_meta._parameters:
-                self.assertEqual(net_meta_state_dict_old[key].requires_grad, net_meta_state_dict[key].requires_grad)
-                if keep_vars:
+                if keep_vars and not swap:
                     # state_dict[key] is an nn.Parameter
                     self.assertTrue(state_dict[key] is net_meta_state_dict[key])
                 else:
-                    # state_dict[key] is not an nn.Parameter so it will be detached when wrapping with a Parameter
-                    self.assertTrue(net_meta_state_dict[key] is not net_meta_state_dict_old[key])
-                self.assertEqual(state_dict[key], net_meta_state_dict[key])
+                    if swap:
+                        self.assertTrue(net_meta_state_dict[key] is net_meta_state_dict_old[key])
+                    else:
+                        # state_dict[key] is not an nn.Parameter so it will be detached when wrapping with a Parameter
+                        self.assertTrue(net_meta_state_dict[key] is not net_meta_state_dict_old[key])
+                        self.assertEqual(net_meta_state_dict_old[key].requires_grad, net_meta_state_dict[key].requires_grad)
+                self.assertEqual(net_meta_state_dict_old[key].requires_grad, net_meta_state_dict[key].requires_grad)
+                self.assertEqual(state_dict[key], net_meta_state_dict[key])
             elif key in net_meta._buffers and key not in net_meta._non_persistent_buffers_set:
                 self.assertTrue(state_dict[key] is net_meta_state_dict[key])
                 self.assertEqual(state_dict[key], net_meta_state_dict[key])
 
         # Make sure that ordering of parameters and buffers is preserved
         net_named_parameters = net.named_parameters()
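
The identity assertions in the swap branch above rely on `torch.utils.swap_tensors` (which the swap path in `nn.Module._load_from_state_dict` calls) exchanging the contents of two tensor objects in place, so existing references to a parameter stay valid. A minimal illustration, not taken from the PR:

```python
import torch

a = torch.zeros(3)
b = torch.ones(3)
ref = a  # stands in for an existing reference to a module parameter

torch.utils.swap_tensors(a, b)

# `a` is still the same Python object, but it now holds what `b` held.
assert ref is a
assert torch.equal(a, torch.ones(3))
assert torch.equal(b, torch.zeros(3))
```
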
@@ -391,19 +397,32 @@ class TestLoadStateDict(NNTestCase):
 def load_torch_function_handler(cls, func, types, args=(), kwargs=None):
     kwargs = {} if kwargs is None else kwargs
 
-    def module_load(dest, src):
+    # always convert src to cls
+    def module_load(dest, src, assign=False):
         if isinstance(dest, cls):
-            if type(src) is torch.Tensor:
-                return cls(src)
-            elif type(src) is cls:
+            if assign:
                 return src.detach()
             else:
-                if isinstance(src, MyWrapperLoadTensor):
-                    return cls(src._data)
-                return cls(src)
+                if type(src) is torch.Tensor:
+                    return cls(src)
+                elif type(src) is cls:
+                    return src.detach()
+                else:
+                    if isinstance(src, MyWrapperLoadTensor):
+                        return cls(src._data)
+                    return cls(src)
         else:
             assert isinstance(src, cls), f"Expected isinstance(src, {cls}) but got {type(src)}"
             assert type(dest) == torch.Tensor or type(dest) == torch.nn.Parameter or issubclass(cls, type(dest))
-            return src.detach()
+            if assign:
+                return src.detach()
+            else:
+                if isinstance(src, MyWrapperLoadTensor):
+                    if type(dest) not in {torch.Tensor, torch.nn.Parameter}:
+                        return type(dest)(src._data)
+                    else:
+                        return src._data.detach()
+                else:
+                    return torch.Tensor(src)
 
     if func is torch.Tensor.module_load:
         return module_load(*args, **kwargs)
@@ -478,7 +497,8 @@ class TestLoadStateDictSwap(TestCase):
     @skipIfCrossRef
     @skipIfTorchDynamo("Can't swap with dynamo as dynamo installs weakrefs")
     @swap([True])
-    def test_swap_subclass(self):
+    @parametrize("assign", [True, False])
+    def test_swap_subclass(self, assign):
         def _create_model(subclass=None):
             m = torch.nn.Linear(2, 3, bias=False)
@@ -491,24 +511,20 @@
         def _test(m_subclass=None, sd_subclass=None):
             m = _create_model(m_subclass)
             sd = _create_model(sd_subclass).state_dict()
-            sd = sd
-            m.load_state_dict(sd)
+            m.load_state_dict(sd, assign=assign)
             self.assertEqual(m.weight, sd['weight'])
             self.assertEqual(m.buf, sd['buf'])
             self.assertTrue(isinstance(m.weight, torch.nn.Parameter))
             self.assertTrue(not isinstance(m.buf, torch.nn.Parameter))
             weight_type, buf_type = (torch.nn.Parameter, torch.Tensor)
-            if m_subclass is not None and sd_subclass is not None:
-                # handler of subclass takes precedence over superclass
-                if issubclass(sd_subclass, m_subclass):
+            if assign:
+                if sd_subclass is not None:
                     weight_type, buf_type = (sd_subclass, sd_subclass)
-                else:
+            else:
+                if m_subclass is not None:
                     weight_type, buf_type = (m_subclass, m_subclass)
-            elif m_subclass is not None:
-                weight_type, buf_type = (m_subclass, m_subclass)
-            elif sd_subclass is not None:
-                weight_type, buf_type = (sd_subclass, sd_subclass)
             self.assertTrue(type(m.weight) is weight_type)
             self.assertTrue(type(m.buf) is buf_type)

torch/_tensor.py

@@ -711,7 +711,7 @@ class Tensor(torch._C.TensorBase):
             self._typed_storage()._share_memory_()
         return self
 
-    def module_load(self, other):
+    def module_load(self, other, assign=False):
         r"""Defines how to transform ``other`` when loading it into ``self`` in :meth:`~nn.Module.load_state_dict`.
 
         Used when :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.
@@ -723,16 +723,23 @@ class Tensor(torch._C.TensorBase):
 
         .. note::
             This method should always return a new object that is not ``self`` or ``other``.
-            For example, the default implementation returns ``self.copy_(other).detach()``.
+            For example, the default implementation returns ``self.copy_(other).detach()``
+            if ``assign`` is ``False`` or ``other.detach()`` if ``assign`` is ``True``.
 
         Args:
             other (Tensor): value in state dict with key corresponding to ``self``
+            assign (bool): the assign argument passed to :meth:`nn.Module.load_state_dict`
 
         """
         if has_torch_function_variadic(self, other):
-            return handle_torch_function(Tensor.module_load, (self, other), self, other)
-
-        # In the default case, swap_tensors becomes a no-op
-        return self.copy_(other).detach()
+            return handle_torch_function(
+                Tensor.module_load, (self, other), self, other, assign=assign
+            )
+
+        if assign:
+            return other.detach()
+        else:
+            # In the default case, swap_tensors becomes a no-op
+            return self.copy_(other).detach()
 
     def __reversed__(self):
         r"""Reverses the tensor along dimension 0."""

torch/nn/modules/module.py

@@ -2046,7 +2046,19 @@
                 try:
                     with torch.no_grad():
-                        if assign_to_params_buffers:
+                        if use_swap_tensors:
+                            new_input_param = param.module_load(input_param, assign=assign_to_params_buffers)
+                            if id(new_input_param) == id(input_param) or id(new_input_param) == id(param):
+                                raise RuntimeError("module_load returned one of self or other, please .detach() "
+                                                   "the result if returning one of the inputs in module_load")
+                            if (isinstance(param, torch.nn.Parameter)):
+                                if not isinstance(new_input_param, torch.nn.Parameter):
+                                    new_input_param = torch.nn.Parameter(new_input_param, requires_grad=param.requires_grad)
+                                else:
+                                    new_input_param.requires_grad_(param.requires_grad)
+                            torch.utils.swap_tensors(param, new_input_param)
+                            del new_input_param
+                        elif assign_to_params_buffers:
                             # Shape checks are already done above
                             if (isinstance(param, torch.nn.Parameter)):
                                 if not isinstance(input_param, torch.nn.Parameter):
@@ -2054,17 +2066,6 @@
                                 else:
                                     input_param.requires_grad_(param.requires_grad)
                             setattr(self, name, input_param)
-                        elif use_swap_tensors:
-                            param_requires_grad = param.requires_grad
-                            new_input_param = param.module_load(input_param)
-                            if id(new_input_param) == id(input_param) or id(new_input_param) == id(param):
-                                raise RuntimeError("module_load returned one of self or other, please .detach() "
-                                                   "the result if returning one of the inputs in module_load")
-                            if (isinstance(param, torch.nn.Parameter) and
-                                    not isinstance(new_input_param, torch.nn.Parameter)):
-                                new_input_param = torch.nn.Parameter(new_input_param, requires_grad=param_requires_grad)
-                            torch.utils.swap_tensors(param, new_input_param)
-                            del new_input_param
                         else:
                             param.copy_(input_param)
                 except Exception as ex:
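
The `id(...)` guard above is what rejects a `module_load` override that hands back one of its inputs, since `torch.utils.swap_tensors` needs two distinct objects. A hypothetical misbehaving subclass (purely illustrative, not part of the PR; note that `load_state_dict` surfaces the error through its aggregated error message):

```python
import torch

class BadTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        if func is torch.Tensor.module_load:
            return args[0]  # returns dest itself -- exactly what the guard forbids
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

torch.__future__.set_swap_module_params_on_conversion(True)
m = torch.nn.Linear(2, 2, bias=False)
m.weight = torch.nn.Parameter(m.weight.detach().clone().as_subclass(BadTensor))

try:
    m.load_state_dict(torch.nn.Linear(2, 2, bias=False).state_dict())
except RuntimeError as e:
    # The aggregated message includes:
    # "module_load returned one of self or other, please .detach() the result ..."
    print(e)
```
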
@@ -2104,7 +2105,8 @@
 
         .. warning::
             If :attr:`assign` is ``True`` the optimizer must be created after
-            the call to :attr:`load_state_dict`.
+            the call to :attr:`load_state_dict` unless
+            :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.
 
         Args:
             state_dict (dict): a dict containing parameters and
@@ -2112,12 +2114,11 @@
             strict (bool, optional): whether to strictly enforce that the keys
                 in :attr:`state_dict` match the keys returned by this module's
                 :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
-            assign (bool, optional): whether to assign items in the state
-                dictionary to their corresponding keys in the module instead
-                of copying them inplace into the module's current parameters and buffers.
-                When ``False``, the properties of the tensors in the current
-                module are preserved while when ``True``, the properties of the
-                Tensors in the state dict are preserved.
+            assign (bool, optional): When ``False``, the properties of the tensors
+                in the current module are preserved while when ``True``, the
+                properties of the Tensors in the state dict are preserved. The only
+                exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s
+                for which the value from the module is preserved.
                 Default: ``False``
 
         Returns:
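
A short illustration of the `requires_grad` exception described above (an illustrative sketch; it behaves the same with or without the swap-on-conversion flag):

```python
import torch

m = torch.nn.Linear(2, 2)
m.weight.requires_grad_(False)

# State dict whose tensors differ from the module's in dtype.
sd = {k: v.to(torch.float64) for k, v in torch.nn.Linear(2, 2).state_dict().items()}
m.load_state_dict(sd, assign=True)

# assign=True keeps the state dict tensors' properties (dtype here) ...
assert m.weight.dtype == torch.float64
# ... except requires_grad, which is taken from the module's Parameter.
assert m.weight.requires_grad is False
```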

torch/overrides.py

@@ -1367,7 +1367,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         Tensor.map_: lambda self, tensor, callable: -1,
         Tensor.map2_: lambda self, x, y, callable: -1,
         Tensor.mm: lambda self, mat2: -1,
-        Tensor.module_load: lambda self, other: -1,
+        Tensor.module_load: lambda self, other, assign=False: -1,
         Tensor.narrow_copy: lambda self, dimension, start, length: -1,
         Tensor.ndimension: lambda self: -1,
         Tensor.nelement: lambda self: -1,