mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Migrate load_state_dict hook tests to OptimizerInfo (#119310)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119310
Approved by: https://github.com/albanD
ghstack dependencies: #119283, #119288, #119299, #119308
This commit is contained in:
parent
0320e62255
commit
059994d2b7
|
|
@ -3,12 +3,11 @@
|
||||||
import unittest
|
import unittest
|
||||||
import functools
|
import functools
|
||||||
import itertools
|
import itertools
|
||||||
from copy import deepcopy
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
from torch.optim import (
|
from torch.optim import (
|
||||||
Adadelta, Adagrad, Adam, Adamax, AdamW, ASGD, NAdam, RAdam, RMSprop, Rprop, SGD, SparseAdam, Optimizer
|
Adadelta, Adagrad, Adam, Adamax, AdamW, ASGD, NAdam, RAdam, RMSprop, Rprop, SGD, SparseAdam
|
||||||
)
|
)
|
||||||
from torch.optim.lr_scheduler import (
|
from torch.optim.lr_scheduler import (
|
||||||
StepLR,
|
StepLR,
|
||||||
|
|
@ -28,7 +27,6 @@ from torch.testing._internal.common_utils import (
|
||||||
|
|
||||||
|
|
||||||
from torch.testing._internal.common_cuda import TEST_CUDA
|
from torch.testing._internal.common_cuda import TEST_CUDA
|
||||||
from typing import Dict, Any
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
# load_tests from common_utils is used to automatically filter tests for
|
# load_tests from common_utils is used to automatically filter tests for
|
||||||
|
|
@ -747,60 +745,6 @@ class TestOptim(TestCase):
|
||||||
maximize=False,
|
maximize=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _load_state_dict_pre_hook1(optimizer: Optimizer, state_dict: Dict[str, Any]) -> None:
|
|
||||||
state_dict["param_groups"][0]["lr"] = 0.002
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _load_state_dict_pre_hook2(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
# The typical use case for returning a state dict is to drastically modify the state dict.
|
|
||||||
# I will simulate by simply making a deep copy and ensuring that my_state_dict still gets used
|
|
||||||
my_state_dict = deepcopy(state_dict)
|
|
||||||
my_state_dict["param_groups"][0]["lr"] = 0.003
|
|
||||||
return my_state_dict
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _load_state_dict_post_hook(optimizer: Optimizer) -> None:
|
|
||||||
optimizer.state["ran_load_state_dict_pre_hook2"] = optimizer.param_groups[0]["lr"] == 0.003
|
|
||||||
optimizer.state["ran_load_state_dict_post_hook"] = True
|
|
||||||
|
|
||||||
def test_load_state_dict_pre_hook_and_prepend(self):
|
|
||||||
param = torch.rand(2, 3, requires_grad=True)
|
|
||||||
param.grad = torch.rand(2, 3, requires_grad=True)
|
|
||||||
opt = SGD([param], lr=0.001)
|
|
||||||
state_dict = opt.state_dict()
|
|
||||||
|
|
||||||
# usually one would have a new opt instance here, but it's all the same here
|
|
||||||
opt.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook1)
|
|
||||||
opt.load_state_dict(state_dict)
|
|
||||||
self.assertEqual(opt.param_groups[0]["lr"], 0.002)
|
|
||||||
|
|
||||||
opt.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook2, prepend=True)
|
|
||||||
opt.load_state_dict(state_dict)
|
|
||||||
# If prepend were False would be 0.003 but since prepend is True, the other hook overrides
|
|
||||||
self.assertEqual(opt.param_groups[0]["lr"], 0.002)
|
|
||||||
|
|
||||||
def test_load_state_dict_post_hook(self):
|
|
||||||
param = torch.rand(2, 3, requires_grad=True)
|
|
||||||
param.grad = torch.rand(2, 3, requires_grad=True)
|
|
||||||
opt = SGD([param], lr=0.001)
|
|
||||||
|
|
||||||
opt.register_load_state_dict_post_hook(self._load_state_dict_post_hook)
|
|
||||||
opt.load_state_dict(opt.state_dict())
|
|
||||||
self.assertFalse(opt.state["ran_load_state_dict_pre_hook2"])
|
|
||||||
self.assertTrue(opt.state["ran_load_state_dict_post_hook"])
|
|
||||||
|
|
||||||
def test_load_state_dict_pre_post_hook(self):
|
|
||||||
param = torch.rand(2, 3, requires_grad=True)
|
|
||||||
param.grad = torch.rand(2, 3, requires_grad=True)
|
|
||||||
opt = SGD([param], lr=0.001)
|
|
||||||
|
|
||||||
opt.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook2)
|
|
||||||
opt.register_load_state_dict_post_hook(self._load_state_dict_post_hook)
|
|
||||||
opt.load_state_dict(opt.state_dict())
|
|
||||||
self.assertTrue(opt.state["ran_load_state_dict_pre_hook2"])
|
|
||||||
self.assertTrue(opt.state["ran_load_state_dict_post_hook"])
|
|
||||||
|
|
||||||
|
|
||||||
def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
|
def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
|
||||||
# Ignored is the list of values in `opt_differentiable_state`, we do this
|
# Ignored is the list of values in `opt_differentiable_state`, we do this
|
||||||
|
|
|
||||||
|
|
@ -988,6 +988,87 @@ class TestOptimRenewed(TestCase):
|
||||||
self.assertTrue(state_dict["ran_state_dict_pre_hook"])
|
self.assertTrue(state_dict["ran_state_dict_pre_hook"])
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load_state_dict_pre_hook1(optimizer: Optimizer, state_dict: Dict[str, Any]) -> None:
|
||||||
|
state_dict["param_groups"][0]["lr"] = 0.002
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load_state_dict_pre_hook2(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
# The typical use case for returning a state dict is to drastically modify the state dict.
|
||||||
|
# I will simulate by simply making a deep copy and ensuring that my_state_dict still gets used
|
||||||
|
my_state_dict = deepcopy(state_dict)
|
||||||
|
my_state_dict["param_groups"][0]["lr"] = 0.003
|
||||||
|
return my_state_dict
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load_state_dict_post_hook(optimizer: Optimizer) -> None:
|
||||||
|
optimizer.state["ran_load_state_dict_pre_hook2"] = optimizer.param_groups[0]["lr"] == 0.003
|
||||||
|
optimizer.state["ran_load_state_dict_post_hook"] = True
|
||||||
|
|
||||||
|
|
||||||
|
@optims(optim_db, dtypes=[torch.float32])
|
||||||
|
def test_load_state_dict_pre_hook_and_prepend(self, device, dtype, optim_info):
|
||||||
|
optim_cls = optim_info.optim_cls
|
||||||
|
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
|
||||||
|
for optim_input in all_optim_inputs:
|
||||||
|
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
|
||||||
|
and not optim_input.kwargs.get("foreach", False)):
|
||||||
|
continue
|
||||||
|
|
||||||
|
param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
|
||||||
|
optim = optim_cls([param], **optim_input.kwargs)
|
||||||
|
state_dict = optim.state_dict()
|
||||||
|
|
||||||
|
# usually one would have a new optim instance here, but it's all the same here
|
||||||
|
optim.register_load_state_dict_pre_hook(self.__class__._load_state_dict_pre_hook1)
|
||||||
|
optim.load_state_dict(state_dict)
|
||||||
|
self.assertEqual(optim.param_groups[0]["lr"], 0.002)
|
||||||
|
|
||||||
|
optim.register_load_state_dict_pre_hook(self.__class__._load_state_dict_pre_hook2, prepend=True)
|
||||||
|
optim.load_state_dict(state_dict)
|
||||||
|
# If prepend were False would be 0.003 but since prepend is True, the other hook overrides
|
||||||
|
self.assertEqual(optim.param_groups[0]["lr"], 0.002)
|
||||||
|
|
||||||
|
|
||||||
|
@optims(optim_db, dtypes=[torch.float32])
|
||||||
|
def test_load_state_dict_post_hook(self, device, dtype, optim_info):
|
||||||
|
optim_cls = optim_info.optim_cls
|
||||||
|
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
|
||||||
|
for optim_input in all_optim_inputs:
|
||||||
|
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
|
||||||
|
and not optim_input.kwargs.get("foreach", False)):
|
||||||
|
continue
|
||||||
|
|
||||||
|
param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
|
||||||
|
optim = optim_cls([param], **optim_input.kwargs)
|
||||||
|
|
||||||
|
optim.register_load_state_dict_post_hook(self.__class__._load_state_dict_post_hook)
|
||||||
|
optim.load_state_dict(optim.state_dict())
|
||||||
|
self.assertFalse(optim.state["ran_load_state_dict_pre_hook2"])
|
||||||
|
self.assertTrue(optim.state["ran_load_state_dict_post_hook"])
|
||||||
|
|
||||||
|
|
||||||
|
@optims(optim_db, dtypes=[torch.float32])
|
||||||
|
def test_load_state_dict_pre_post_hook(self, device, dtype, optim_info):
|
||||||
|
optim_cls = optim_info.optim_cls
|
||||||
|
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
|
||||||
|
for optim_input in all_optim_inputs:
|
||||||
|
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
|
||||||
|
and not optim_input.kwargs.get("foreach", False)):
|
||||||
|
continue
|
||||||
|
|
||||||
|
param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
|
||||||
|
optim = optim_cls([param], **optim_input.kwargs)
|
||||||
|
|
||||||
|
optim.register_load_state_dict_pre_hook(self.__class__._load_state_dict_pre_hook2)
|
||||||
|
optim.register_load_state_dict_post_hook(self.__class__._load_state_dict_post_hook)
|
||||||
|
optim.load_state_dict(optim.state_dict())
|
||||||
|
self.assertTrue(optim.state["ran_load_state_dict_pre_hook2"])
|
||||||
|
self.assertTrue(optim.state["ran_load_state_dict_post_hook"])
|
||||||
|
|
||||||
|
|
||||||
@optims(optim_db, dtypes=[torch.float32])
|
@optims(optim_db, dtypes=[torch.float32])
|
||||||
def test_step_post_hook(self, device, dtype, optim_info):
|
def test_step_post_hook(self, device, dtype, optim_info):
|
||||||
def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
|
def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user