diff --git a/test/optim/test_optim.py b/test/optim/test_optim.py
index 146ae73b1fd..fe27c24d0c0 100644
--- a/test/optim/test_optim.py
+++ b/test/optim/test_optim.py
@@ -747,20 +747,6 @@ class TestOptim(TestCase):
             maximize=False,
         )
 
-
-    @staticmethod
-    def _state_dict_pre_hook(optimizer: Optimizer) -> None:
-        optimizer.state["test"] = 1
-
-    @staticmethod
-    def _state_dict_post_hook(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
-        if "test" in state_dict["state"]:
-            state_dict["state"].pop("test")
-            state_dict["ran_state_dict_pre_hook"] = True
-        else:
-            state_dict["ran_state_dict_pre_hook"] = False
-        return state_dict
-
     @staticmethod
     def _load_state_dict_pre_hook1(optimizer: Optimizer, state_dict: Dict[str, Any]) -> None:
         state_dict["param_groups"][0]["lr"] = 0.002
@@ -778,32 +764,6 @@ class TestOptim(TestCase):
         optimizer.state["ran_load_state_dict_pre_hook2"] = optimizer.param_groups[0]["lr"] == 0.003
         optimizer.state["ran_load_state_dict_post_hook"] = True
 
-    def test_state_dict_pre_hook(self):
-        param = torch.rand(2, 3, requires_grad=True)
-        param.grad = torch.rand(2, 3, requires_grad=True)
-        opt = SGD([param], lr=0.001)
-        opt.register_state_dict_pre_hook(self._state_dict_pre_hook)
-        state_dict = opt.state_dict()
-        self.assertEqual(state_dict["state"]["test"], 1)
-
-    def test_state_dict_post_hook(self):
-        param = torch.rand(2, 3, requires_grad=True)
-        param.grad = torch.rand(2, 3, requires_grad=True)
-        opt = SGD([param], lr=0.001)
-        opt.register_state_dict_post_hook(self._state_dict_post_hook)
-        state_dict = opt.state_dict()
-        self.assertEqual(state_dict["ran_state_dict_pre_hook"], False)
-
-    def test_state_dict_pre_post_hook(self):
-        param = torch.rand(2, 3, requires_grad=True)
-        param.grad = torch.rand(2, 3, requires_grad=True)
-        opt = SGD([param], lr=0.001)
-        opt.register_state_dict_pre_hook(self._state_dict_pre_hook)
-        opt.register_state_dict_post_hook(self._state_dict_post_hook)
-        state_dict = opt.state_dict()
-        self.assertFalse("test" in state_dict["state"])
-        self.assertEqual(state_dict["ran_state_dict_pre_hook"], True)
-
     def test_load_state_dict_pre_hook_and_prepend(self):
         param = torch.rand(2, 3, requires_grad=True)
         param.grad = torch.rand(2, 3, requires_grad=True)
diff --git a/test/test_optim.py b/test/test_optim.py
index 86fde42e20a..14ba15313eb 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -923,6 +923,71 @@ class TestOptimRenewed(TestCase):
         self.assertEqual(optimizer.state_dict(), optimizer_cuda.state_dict())
 
 
+    @staticmethod
+    def _state_dict_pre_hook(optimizer: Optimizer) -> None:
+        optimizer.state["test"] = 1
+
+
+    @staticmethod
+    def _state_dict_post_hook(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
+        if "test" in state_dict["state"]:
+            state_dict["state"].pop("test")
+            state_dict["ran_state_dict_pre_hook"] = True
+        else:
+            state_dict["ran_state_dict_pre_hook"] = False
+        return state_dict
+
+
+    @optims(optim_db, dtypes=[torch.float32])
+    def test_state_dict_pre_hook(self, device, dtype, optim_info):
+        optim_cls = optim_info.optim_cls
+        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
+        for optim_input in all_optim_inputs:
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
+            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
+            optim = optim_cls([param], **optim_input.kwargs)
+            optim.register_state_dict_pre_hook(self.__class__._state_dict_pre_hook)
+            state_dict = optim.state_dict()
+            self.assertEqual(state_dict["state"]["test"], 1)
+
+
+    @optims(optim_db, dtypes=[torch.float32])
+    def test_state_dict_post_hook(self, device, dtype, optim_info):
+        optim_cls = optim_info.optim_cls
+        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
+        for optim_input in all_optim_inputs:
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
+            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
+            optim = optim_cls([param], **optim_input.kwargs)
+            optim.register_state_dict_post_hook(self.__class__._state_dict_post_hook)
+            state_dict = optim.state_dict()
+            self.assertFalse(state_dict["ran_state_dict_pre_hook"])
+
+
+    @optims(optim_db, dtypes=[torch.float32])
+    def test_state_dict_pre_post_hook(self, device, dtype, optim_info):
+        optim_cls = optim_info.optim_cls
+        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
+        for optim_input in all_optim_inputs:
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
+            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
+            optim = optim_cls([param], **optim_input.kwargs)
+            optim.register_state_dict_pre_hook(self.__class__._state_dict_pre_hook)
+            optim.register_state_dict_post_hook(self.__class__._state_dict_post_hook)
+            state_dict = optim.state_dict()
+            self.assertFalse("test" in state_dict["state"])
+            self.assertTrue(state_dict["ran_state_dict_pre_hook"])
+
+
     @optims(optim_db, dtypes=[torch.float32])
     def test_step_post_hook(self, device, dtype, optim_info):
         def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
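
For readers skimming the patch, a minimal standalone sketch of the state_dict hook behavior these tests exercise, assuming a PyTorch build where torch.optim.Optimizer exposes register_state_dict_pre_hook and register_state_dict_post_hook (the APIs used above). The hook bodies mirror the _state_dict_pre_hook and _state_dict_post_hook helpers added to TestOptimRenewed; plain SGD at lr=0.001 stands in for the optim_db-parametrized optimizers, and the local names pre_hook, post_hook, and sd are illustrative only:

    import torch
    from torch.optim import SGD

    def pre_hook(optimizer):
        # Runs before state_dict() packs anything, so this extra key ends up in the packed "state".
        optimizer.state["test"] = 1

    def post_hook(optimizer, state_dict):
        # Runs on the packed dict; returning a dict replaces the result.
        state_dict["ran_state_dict_pre_hook"] = "test" in state_dict["state"]
        state_dict["state"].pop("test", None)
        return state_dict

    param = torch.rand(2, 3, requires_grad=True)
    opt = SGD([param], lr=0.001)
    opt.register_state_dict_pre_hook(pre_hook)
    opt.register_state_dict_post_hook(post_hook)
    sd = opt.state_dict()
    assert "test" not in sd["state"] and sd["ran_state_dict_pre_hook"]

With both hooks registered, this is the scenario test_state_dict_pre_post_hook asserts per optimizer input: the pre-hook's marker is observed by the post-hook and then stripped from the returned dict.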