Migrate load_state_dict hook tests to OptimizerInfo (#119310)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119310
Approved by: https://github.com/albanD
ghstack dependencies: #119283, #119288, #119299, #119308
Jane Xu 2024-02-06 12:30:41 -08:00 committed by PyTorch MergeBot
parent 0320e62255
commit 059994d2b7
2 changed files with 82 additions and 57 deletions
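
For context, the tests moved in this commit exercise torch.optim.Optimizer's load_state_dict hooks. A minimal sketch of that API, separate from the diff below (the hook functions double_lr and report_loaded are illustrative and not part of the test suite):

import torch
from torch.optim import SGD

param = torch.rand(2, 3, requires_grad=True)
opt = SGD([param], lr=0.001)

def double_lr(optimizer, state_dict):
    # Pre-hooks receive the state dict about to be loaded; they may mutate it
    # in place, or return a new dict to be used instead.
    state_dict["param_groups"][0]["lr"] *= 2

def report_loaded(optimizer):
    # Post-hooks run after load_state_dict has finished.
    print("loaded lr:", optimizer.param_groups[0]["lr"])

opt.register_load_state_dict_pre_hook(double_lr)    # prepend=True would run it before existing pre-hooks
opt.register_load_state_dict_post_hook(report_loaded)
opt.load_state_dict(opt.state_dict())               # prints "loaded lr: 0.002"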


@@ -3,12 +3,11 @@
 import unittest
 import functools
 import itertools
-from copy import deepcopy
 import torch
 from torch.nn import Parameter
 from torch.optim import (
-    Adadelta, Adagrad, Adam, Adamax, AdamW, ASGD, NAdam, RAdam, RMSprop, Rprop, SGD, SparseAdam, Optimizer
+    Adadelta, Adagrad, Adam, Adamax, AdamW, ASGD, NAdam, RAdam, RMSprop, Rprop, SGD, SparseAdam
 )
 from torch.optim.lr_scheduler import (
     StepLR,
@@ -28,7 +27,6 @@ from torch.testing._internal.common_utils import (
 from torch.testing._internal.common_cuda import TEST_CUDA
-from typing import Dict, Any
 from unittest.mock import patch

 # load_tests from common_utils is used to automatically filter tests for
@@ -747,60 +745,6 @@ class TestOptim(TestCase):
             maximize=False,
         )

-    @staticmethod
-    def _load_state_dict_pre_hook1(optimizer: Optimizer, state_dict: Dict[str, Any]) -> None:
-        state_dict["param_groups"][0]["lr"] = 0.002
-
-    @staticmethod
-    def _load_state_dict_pre_hook2(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
-        # The typical use case for returning a state dict is to drastically modify the state dict.
-        # I will simulate by simply making a deep copy and ensuring that my_state_dict still gets used
-        my_state_dict = deepcopy(state_dict)
-        my_state_dict["param_groups"][0]["lr"] = 0.003
-        return my_state_dict
-
-    @staticmethod
-    def _load_state_dict_post_hook(optimizer: Optimizer) -> None:
-        optimizer.state["ran_load_state_dict_pre_hook2"] = optimizer.param_groups[0]["lr"] == 0.003
-        optimizer.state["ran_load_state_dict_post_hook"] = True
-
-    def test_load_state_dict_pre_hook_and_prepend(self):
-        param = torch.rand(2, 3, requires_grad=True)
-        param.grad = torch.rand(2, 3, requires_grad=True)
-        opt = SGD([param], lr=0.001)
-        state_dict = opt.state_dict()
-
-        # usually one would have a new opt instance here, but it's all the same here
-        opt.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook1)
-        opt.load_state_dict(state_dict)
-        self.assertEqual(opt.param_groups[0]["lr"], 0.002)
-
-        opt.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook2, prepend=True)
-        opt.load_state_dict(state_dict)
-        # If prepend were False would be 0.003 but since prepend is True, the other hook overrides
-        self.assertEqual(opt.param_groups[0]["lr"], 0.002)
-
-    def test_load_state_dict_post_hook(self):
-        param = torch.rand(2, 3, requires_grad=True)
-        param.grad = torch.rand(2, 3, requires_grad=True)
-        opt = SGD([param], lr=0.001)
-
-        opt.register_load_state_dict_post_hook(self._load_state_dict_post_hook)
-        opt.load_state_dict(opt.state_dict())
-        self.assertFalse(opt.state["ran_load_state_dict_pre_hook2"])
-        self.assertTrue(opt.state["ran_load_state_dict_post_hook"])
-
-    def test_load_state_dict_pre_post_hook(self):
-        param = torch.rand(2, 3, requires_grad=True)
-        param.grad = torch.rand(2, 3, requires_grad=True)
-        opt = SGD([param], lr=0.001)
-
-        opt.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook2)
-        opt.register_load_state_dict_post_hook(self._load_state_dict_post_hook)
-        opt.load_state_dict(opt.state_dict())
-        self.assertTrue(opt.state["ran_load_state_dict_pre_hook2"])
-        self.assertTrue(opt.state["ran_load_state_dict_post_hook"])
-
 def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
     # Ignored is the list of values in `opt_differentiable_state`, we do this


@@ -988,6 +988,87 @@ class TestOptimRenewed(TestCase):
             self.assertTrue(state_dict["ran_state_dict_pre_hook"])

+    @staticmethod
+    def _load_state_dict_pre_hook1(optimizer: Optimizer, state_dict: Dict[str, Any]) -> None:
+        state_dict["param_groups"][0]["lr"] = 0.002
+
+    @staticmethod
+    def _load_state_dict_pre_hook2(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
+        # The typical use case for returning a state dict is to drastically modify the state dict.
+        # I will simulate by simply making a deep copy and ensuring that my_state_dict still gets used
+        my_state_dict = deepcopy(state_dict)
+        my_state_dict["param_groups"][0]["lr"] = 0.003
+        return my_state_dict
+
+    @staticmethod
+    def _load_state_dict_post_hook(optimizer: Optimizer) -> None:
+        optimizer.state["ran_load_state_dict_pre_hook2"] = optimizer.param_groups[0]["lr"] == 0.003
+        optimizer.state["ran_load_state_dict_post_hook"] = True
+
+    @optims(optim_db, dtypes=[torch.float32])
+    def test_load_state_dict_pre_hook_and_prepend(self, device, dtype, optim_info):
+        optim_cls = optim_info.optim_cls
+        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
+        for optim_input in all_optim_inputs:
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
+            optim = optim_cls([param], **optim_input.kwargs)
+            state_dict = optim.state_dict()
+
+            # usually one would have a new optim instance here, but it's all the same here
+            optim.register_load_state_dict_pre_hook(self.__class__._load_state_dict_pre_hook1)
+            optim.load_state_dict(state_dict)
+            self.assertEqual(optim.param_groups[0]["lr"], 0.002)
+
+            optim.register_load_state_dict_pre_hook(self.__class__._load_state_dict_pre_hook2, prepend=True)
+            optim.load_state_dict(state_dict)
+            # If prepend were False would be 0.003 but since prepend is True, the other hook overrides
+            self.assertEqual(optim.param_groups[0]["lr"], 0.002)
+
+    @optims(optim_db, dtypes=[torch.float32])
+    def test_load_state_dict_post_hook(self, device, dtype, optim_info):
+        optim_cls = optim_info.optim_cls
+        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
+        for optim_input in all_optim_inputs:
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
+            optim = optim_cls([param], **optim_input.kwargs)
+
+            optim.register_load_state_dict_post_hook(self.__class__._load_state_dict_post_hook)
+            optim.load_state_dict(optim.state_dict())
+            self.assertFalse(optim.state["ran_load_state_dict_pre_hook2"])
+            self.assertTrue(optim.state["ran_load_state_dict_post_hook"])
+
+    @optims(optim_db, dtypes=[torch.float32])
+    def test_load_state_dict_pre_post_hook(self, device, dtype, optim_info):
+        optim_cls = optim_info.optim_cls
+        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
+        for optim_input in all_optim_inputs:
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
+            optim = optim_cls([param], **optim_input.kwargs)
+
+            optim.register_load_state_dict_pre_hook(self.__class__._load_state_dict_pre_hook2)
+            optim.register_load_state_dict_post_hook(self.__class__._load_state_dict_post_hook)
+            optim.load_state_dict(optim.state_dict())
+            self.assertTrue(optim.state["ran_load_state_dict_pre_hook2"])
+            self.assertTrue(optim.state["ran_load_state_dict_post_hook"])
+
     @optims(optim_db, dtypes=[torch.float32])
     def test_step_post_hook(self, device, dtype, optim_info):
         def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
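
A side note on the subtlest assertion above: in test_load_state_dict_pre_hook_and_prepend, the learning rate still ends at 0.002 even though the prepended hook sets 0.003. Each pre-hook receives the state dict produced by the hook before it (and a hook that returns a dict replaces the one passed to later hooks), so the hook registered first runs last and has the final say. A standalone sketch of that ordering using plain SGD rather than the test harness (the hook names here are illustrative):

from copy import deepcopy

import torch
from torch.optim import SGD

param = torch.rand(2, 3, requires_grad=True)
opt = SGD([param], lr=0.001)

def set_lr_002(optimizer, state_dict):
    # Registered first, so it runs last; mutates the incoming dict in place.
    state_dict["param_groups"][0]["lr"] = 0.002

def set_lr_003(optimizer, state_dict):
    # Prepended, so it runs first; returns a fresh dict that later hooks then see.
    new_sd = deepcopy(state_dict)
    new_sd["param_groups"][0]["lr"] = 0.003
    return new_sd

opt.register_load_state_dict_pre_hook(set_lr_002)
opt.register_load_state_dict_pre_hook(set_lr_003, prepend=True)
opt.load_state_dict(opt.state_dict())
assert opt.param_groups[0]["lr"] == 0.002  # set_lr_002 ran last and won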