Add deepcopy functionality to parametrized modules (#80811)
Fixes #69413

After applying parametrization to any `nn.Module`, we lose the ability to create a deepcopy of it, which, for example, makes it impossible to wrap the module in an `AveragedModel`. Specifically, the problem is that `deepcopy` tries to invoke `__getstate__` when an object has not implemented its own `__deepcopy__` magic method, but we don't allow serialization of parametrized modules: `__getstate__` raises an error. My solution is to create a default `__deepcopy__` method when the class doesn't already define one.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80811
Approved by: https://github.com/pearu, https://github.com/albanD
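As a concrete illustration of the scenario described above (the `Double` parametrization and the `nn.Linear(4, 4)` layer below are made up for the example, not taken from this PR): before this change, `copy.deepcopy` on a parametrized module fell back to the raising `__getstate__` and failed with a RuntimeError; with the injected default `__deepcopy__` it succeeds, so wrappers that deepcopy the module, such as `torch.optim.swa_utils.AveragedModel`, work again.

```python
import copy

import torch.nn as nn
import torch.nn.utils.parametrize as parametrize
from torch.optim.swa_utils import AveragedModel


class Double(nn.Module):
    # Toy parametrization: the effective weight is twice the stored one.
    def forward(self, x):
        return 2.0 * x


model = nn.Linear(4, 4)
parametrize.register_parametrization(model, "weight", Double())

# Previously both of these hit the RuntimeError raised by the injected
# __getstate__; with this change they work.
clone = copy.deepcopy(model)
swa_model = AveragedModel(model)
```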
This commit is contained in:
parent b95ae2909e
commit 7af0200a46
@@ -3297,6 +3297,54 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
             parametrize.type_before_parametrizations(model) == original_type
         )
 
+    def test_deepcopy_after_parametrization(self):
+        r"""Test that we are able to create a deepcopy of the module when it's parametrized."""
+
+        class AddOne(nn.Module):
+            def forward(self, x):
+                return x + 1.0
+
+        class ModelWithoutDeepcopy(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.weight = nn.Parameter(torch.tensor([1., 1., 1., 1.]), requires_grad=True)
+                self.bias = nn.Parameter(torch.tensor([0., 0., 0., 0.]), requires_grad=True)
+                self.attr = [1.0, 2.0, 3.0, 4.0]
+
+        class ActualModel(ModelWithoutDeepcopy):
+            # Emulate custom implementation of the deepcopying.
+            def __deepcopy__(self, memo):
+                result = self.__new__(self.__class__)
+                memo[id(self)] = result
+                result.__dict__ = deepcopy(self.__dict__, memo)
+                return result
+
+        def check_deepcopy(m1: nn.Module, m2: nn.Module):
+            w1 = m1.parametrizations.weight.original
+            w2 = m2.parametrizations.weight.original
+            b1 = m1.parametrizations.bias.original if parametrize.is_parametrized(m1, "bias") else m1.bias
+            b2 = m2.parametrizations.bias.original if parametrize.is_parametrized(m2, "bias") else m2.bias
+            # Weights, biases and attributes should be equal but they must be different objects.
+            self.assertEqual(m1.__dict__.keys(), m2.__dict__.keys())
+            self.assertIsNot(m1, m2)
+            self.assertEqual(w1, w2)
+            self.assertIsNot(w1, w2)
+            self.assertEqual(b1, b2)
+            self.assertIsNot(b1, b2)
+            self.assertEqual(m1.attr, m2.attr)
+            self.assertIsNot(m1.attr, m2.attr)
+
+        for model in (ModelWithoutDeepcopy(), ActualModel()):
+            # General check that we are able to create deepcopy.
+            parametrize.register_parametrization(model, "weight", AddOne())
+            check_deepcopy(model, deepcopy(model))
+            # Check that this works on models with several parametrized tensors.
+            parametrize.register_parametrization(model, "bias", AddOne())
+            check_deepcopy(model, deepcopy(model))
+            # Check that this works on models where tensors have more than one parametrization.
+            parametrize.register_parametrization(model, "weight", AddOne())
+            check_deepcopy(model, deepcopy(model))
+
     def test_transfer_parametrizations_and_params(self):
         r"""Test that all parametrizations and their associated parameters are transferred."""
 
@@ -4,6 +4,8 @@ from torch.nn.parameter import Parameter
 from torch import Tensor
 
 import collections
+import copyreg
+from copy import deepcopy
 from contextlib import contextmanager
 from typing import Union, Optional, Dict, Tuple, Sequence
 
@@ -285,6 +287,21 @@ def _inject_new_class(module: Module) -> None:
     """
     cls = module.__class__
 
+    def default_deepcopy(self, memo):
+        # Just emulate a standard deepcopy procedure when __deepcopy__ doesn't exist in the current class.
+        obj = memo.get(id(self), None)
+        if obj is not None:
+            return obj
+        replica = self.__new__(self.__class__)
+        memo[id(self)] = replica
+        replica.__dict__ = deepcopy(self.__dict__, memo)
+        # Also save all slots if they exist.
+        slots_to_save = copyreg._slotnames(self.__class__)  # type: ignore[attr-defined]
+        for slot in slots_to_save:
+            if hasattr(self, slot):
+                setattr(replica, slot, deepcopy(getattr(self, slot), memo))
+        return replica
+
     def getstate(self):
         raise RuntimeError(
             "Serialization of parametrized modules is only "
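For reference, `default_deepcopy` follows the same recipe `copy.deepcopy` uses for ordinary objects: consult the memo to handle already-copied or self-referential objects, deep-copy `__dict__`, and copy any `__slots__` attributes. A small illustration of why the hunk uses `copyreg._slotnames` rather than reading `cls.__slots__` directly (the `Base`/`Child` classes are made up for this sketch):

```python
import copyreg


class Base:
    __slots__ = ("x",)


class Child(Base):
    __slots__ = ("y",)


# _slotnames walks the full MRO, so slots defined on base classes are included too.
print(copyreg._slotnames(Child))  # ['y', 'x']
```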
@@ -293,12 +310,16 @@ def _inject_new_class(module: Module) -> None:
             "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training"
         )
 
+    dct = {"__getstate__": getstate}
+    # We don't allow serialization of parametrized modules but should still allow deepcopying.
+    # Default 'deepcopy' function invokes __deepcopy__ method instead of __getstate__ when it exists.
+    if not hasattr(cls, "__deepcopy__"):
+        dct["__deepcopy__"] = default_deepcopy  # type: ignore[assignment]
+
     param_cls = type(
         f"Parametrized{cls.__name__}",
         (cls,),
-        {
-            "__getstate__": getstate,
-        },
+        dct,
     )
 
     module.__class__ = param_cls
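Net effect of the change, sketched with an illustrative module (the `nn.Linear(3, 3)` layer is made up; `AddOne` mirrors the parametrization used in the test above): deepcopy of a parametrized module now succeeds through the injected default `__deepcopy__`, while pickling the module directly still fails because `__getstate__` keeps raising, so only deepcopying is enabled, not serialization.

```python
import copy
import io

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class AddOne(nn.Module):
    def forward(self, x):
        return x + 1.0


m = nn.Linear(3, 3)
parametrize.register_parametrization(m, "weight", AddOne())

clone = copy.deepcopy(m)  # now works via the injected default __deepcopy__
assert clone is not m

try:
    torch.save(m, io.BytesIO())  # pickling still goes through the raising __getstate__
except RuntimeError as e:
    print(e)  # "Serialization of parametrized modules is only ..."
```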