Migrate test_state_dict hooks to OptimizerInfo (#119308)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119308
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #119283, #119288, #119299
This commit is contained in:
Jane Xu 2024-02-06 12:20:21 -08:00 committed by PyTorch MergeBot
parent 5c46600f84
commit 0320e62255
2 changed files with 65 additions and 40 deletions

View File

@ -747,20 +747,6 @@ class TestOptim(TestCase):
maximize=False, maximize=False,
) )
@staticmethod
def _state_dict_pre_hook(optimizer: Optimizer) -> None:
optimizer.state["test"] = 1
@staticmethod
def _state_dict_post_hook(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
if "test" in state_dict["state"]:
state_dict["state"].pop("test")
state_dict["ran_state_dict_pre_hook"] = True
else:
state_dict["ran_state_dict_pre_hook"] = False
return state_dict
@staticmethod @staticmethod
def _load_state_dict_pre_hook1(optimizer: Optimizer, state_dict: Dict[str, Any]) -> None: def _load_state_dict_pre_hook1(optimizer: Optimizer, state_dict: Dict[str, Any]) -> None:
state_dict["param_groups"][0]["lr"] = 0.002 state_dict["param_groups"][0]["lr"] = 0.002
@ -778,32 +764,6 @@ class TestOptim(TestCase):
optimizer.state["ran_load_state_dict_pre_hook2"] = optimizer.param_groups[0]["lr"] == 0.003 optimizer.state["ran_load_state_dict_pre_hook2"] = optimizer.param_groups[0]["lr"] == 0.003
optimizer.state["ran_load_state_dict_post_hook"] = True optimizer.state["ran_load_state_dict_post_hook"] = True
def test_state_dict_pre_hook(self):
param = torch.rand(2, 3, requires_grad=True)
param.grad = torch.rand(2, 3, requires_grad=True)
opt = SGD([param], lr=0.001)
opt.register_state_dict_pre_hook(self._state_dict_pre_hook)
state_dict = opt.state_dict()
self.assertEqual(state_dict["state"]["test"], 1)
def test_state_dict_post_hook(self):
param = torch.rand(2, 3, requires_grad=True)
param.grad = torch.rand(2, 3, requires_grad=True)
opt = SGD([param], lr=0.001)
opt.register_state_dict_post_hook(self._state_dict_post_hook)
state_dict = opt.state_dict()
self.assertEqual(state_dict["ran_state_dict_pre_hook"], False)
def test_state_dict_pre_post_hook(self):
param = torch.rand(2, 3, requires_grad=True)
param.grad = torch.rand(2, 3, requires_grad=True)
opt = SGD([param], lr=0.001)
opt.register_state_dict_pre_hook(self._state_dict_pre_hook)
opt.register_state_dict_post_hook(self._state_dict_post_hook)
state_dict = opt.state_dict()
self.assertFalse("test" in state_dict["state"])
self.assertEqual(state_dict["ran_state_dict_pre_hook"], True)
def test_load_state_dict_pre_hook_and_prepend(self): def test_load_state_dict_pre_hook_and_prepend(self):
param = torch.rand(2, 3, requires_grad=True) param = torch.rand(2, 3, requires_grad=True)
param.grad = torch.rand(2, 3, requires_grad=True) param.grad = torch.rand(2, 3, requires_grad=True)

View File

@ -923,6 +923,71 @@ class TestOptimRenewed(TestCase):
self.assertEqual(optimizer.state_dict(), optimizer_cuda.state_dict()) self.assertEqual(optimizer.state_dict(), optimizer_cuda.state_dict())
@staticmethod
def _state_dict_pre_hook(optimizer: Optimizer) -> None:
optimizer.state["test"] = 1
@staticmethod
def _state_dict_post_hook(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
if "test" in state_dict["state"]:
state_dict["state"].pop("test")
state_dict["ran_state_dict_pre_hook"] = True
else:
state_dict["ran_state_dict_pre_hook"] = False
return state_dict
@optims(optim_db, dtypes=[torch.float32])
def test_state_dict_pre_hook(self, device, dtype, optim_info):
optim_cls = optim_info.optim_cls
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
for optim_input in all_optim_inputs:
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
and not optim_input.kwargs.get("foreach", False)):
continue
param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
optim = optim_cls([param], **optim_input.kwargs)
optim.register_state_dict_pre_hook(self.__class__._state_dict_pre_hook)
state_dict = optim.state_dict()
self.assertEqual(state_dict["state"]["test"], 1)
@optims(optim_db, dtypes=[torch.float32])
def test_state_dict_post_hook(self, device, dtype, optim_info):
optim_cls = optim_info.optim_cls
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
for optim_input in all_optim_inputs:
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
and not optim_input.kwargs.get("foreach", False)):
continue
param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
optim = optim_cls([param], **optim_input.kwargs)
optim.register_state_dict_post_hook(self.__class__._state_dict_post_hook)
state_dict = optim.state_dict()
self.assertFalse(state_dict["ran_state_dict_pre_hook"])
@optims(optim_db, dtypes=[torch.float32])
def test_state_dict_pre_post_hook(self, device, dtype, optim_info):
optim_cls = optim_info.optim_cls
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
for optim_input in all_optim_inputs:
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
and not optim_input.kwargs.get("foreach", False)):
continue
param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
optim = optim_cls([param], **optim_input.kwargs)
optim.register_state_dict_pre_hook(self.__class__._state_dict_pre_hook)
optim.register_state_dict_post_hook(self.__class__._state_dict_post_hook)
state_dict = optim.state_dict()
self.assertFalse("test" in state_dict["state"])
self.assertTrue(state_dict["ran_state_dict_pre_hook"])
@optims(optim_db, dtypes=[torch.float32]) @optims(optim_db, dtypes=[torch.float32])
def test_step_post_hook(self, device, dtype, optim_info): def test_step_post_hook(self, device, dtype, optim_info):
def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]): def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):