Add deepcopy functionality to parametrized modules (#80811)

Fixes #69413

After applying a parametrization to any `nn.Module`, we lose the ability to create a deepcopy of it, which, for example, makes it impossible to wrap the module in an `AveragedModel`.
Specifically, the problem is that `deepcopy` falls back to `__getstate__` when the object doesn't implement its own `__deepcopy__` magic method, but we don't allow serialization of parametrized modules: their `__getstate__` raises an error.
My solution is simply to provide a default `__deepcopy__` method when the class doesn't already define one.
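
For illustration, here is a minimal sketch (not part of this PR) of the workflow this change unblocks; the `Symmetric` parametrization and the layer sizes are made up for the example:

```python
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class Symmetric(nn.Module):
    # Illustrative parametrization: constrain a square weight to be symmetric.
    def forward(self, X):
        return X.triu() + X.triu(1).transpose(-1, -2)


layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Symmetric())

# Previously this raised the RuntimeError coming from the injected __getstate__;
# with the default __deepcopy__ it returns an independent copy.
layer_copy = deepcopy(layer)
assert layer_copy is not layer

# This in turn makes it possible to wrap a parametrized module in AveragedModel,
# which deepcopies the model it wraps.
averaged = torch.optim.swa_utils.AveragedModel(layer)
```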

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80811
Approved by: https://github.com/pearu, https://github.com/albanD
n.zhuravlev 2022-07-15 09:06:45 +00:00 committed by PyTorch MergeBot
parent b95ae2909e
commit 7af0200a46
2 changed files with 72 additions and 3 deletions


@@ -3297,6 +3297,54 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
            parametrize.type_before_parametrizations(model) == original_type
        )

    def test_deepcopy_after_parametrization(self):
        r"""Test that we are able to create a deepcopy of the module when it's parametrized."""
        class AddOne(nn.Module):
            def forward(self, x):
                return x + 1.0

        class ModelWithoutDeepcopy(nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = nn.Parameter(torch.tensor([1., 1., 1., 1.]), requires_grad=True)
                self.bias = nn.Parameter(torch.tensor([0., 0., 0., 0.]), requires_grad=True)
                self.attr = [1.0, 2.0, 3.0, 4.0]

        class ActualModel(ModelWithoutDeepcopy):
            # Emulate custom implementation of the deepcopying.
            def __deepcopy__(self, memo):
                result = self.__new__(self.__class__)
                memo[id(self)] = result
                result.__dict__ = deepcopy(self.__dict__, memo)
                return result

        def check_deepcopy(m1: nn.Module, m2: nn.Module):
            w1 = m1.parametrizations.weight.original
            w2 = m2.parametrizations.weight.original
            b1 = m1.parametrizations.bias.original if parametrize.is_parametrized(m1, "bias") else m1.bias
            b2 = m2.parametrizations.bias.original if parametrize.is_parametrized(m2, "bias") else m2.bias
            # Weights, biases and attributes should be equal but they must be different objects.
            self.assertEqual(m1.__dict__.keys(), m2.__dict__.keys())
            self.assertIsNot(m1, m2)
            self.assertEqual(w1, w2)
            self.assertIsNot(w1, w2)
            self.assertEqual(b1, b2)
            self.assertIsNot(b1, b2)
            self.assertEqual(m1.attr, m2.attr)
            self.assertIsNot(m1.attr, m2.attr)

        for model in (ModelWithoutDeepcopy(), ActualModel()):
            # General check that we are able to create deepcopy.
            parametrize.register_parametrization(model, "weight", AddOne())
            check_deepcopy(model, deepcopy(model))
            # Check that this works on models with several parametrized tensors.
            parametrize.register_parametrization(model, "bias", AddOne())
            check_deepcopy(model, deepcopy(model))
            # Check that this works on models where tensors have more than one parametrization.
            parametrize.register_parametrization(model, "weight", AddOne())
            check_deepcopy(model, deepcopy(model))

    def test_transfer_parametrizations_and_params(self):
        r"""Test that all parametrizations and their associated parameters are transferred."""


@@ -4,6 +4,8 @@ from torch.nn.parameter import Parameter
from torch import Tensor
import collections
import copyreg
from copy import deepcopy
from contextlib import contextmanager
from typing import Union, Optional, Dict, Tuple, Sequence
@@ -285,6 +287,21 @@ def _inject_new_class(module: Module) -> None:
    """
    cls = module.__class__

    def default_deepcopy(self, memo):
        # Just emulate a standard deepcopy procedure when __deepcopy__ doesn't exist in the current class.
        obj = memo.get(id(self), None)
        if obj is not None:
            return obj
        replica = self.__new__(self.__class__)
        memo[id(self)] = replica
        replica.__dict__ = deepcopy(self.__dict__, memo)
        # Also save all slots if they exist.
        slots_to_save = copyreg._slotnames(self.__class__)  # type: ignore[attr-defined]
        for slot in slots_to_save:
            if hasattr(self, slot):
                setattr(replica, slot, deepcopy(getattr(self, slot), memo))
        return replica

    def getstate(self):
        raise RuntimeError(
            "Serialization of parametrized modules is only "
@@ -293,12 +310,16 @@ def _inject_new_class(module: Module) -> None:
            "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training"
        )

    dct = {"__getstate__": getstate}
    # We don't allow serialization of parametrized modules but should still allow deepcopying.
    # Default 'deepcopy' function invokes __deepcopy__ method instead of __getstate__ when it exists.
    if not hasattr(cls, "__deepcopy__"):
        dct["__deepcopy__"] = default_deepcopy  # type: ignore[assignment]

    param_cls = type(
        f"Parametrized{cls.__name__}",
        (cls,),
        {
            "__getstate__": getstate,
        },
        dct,
    )

    module.__class__ = param_cls