Migrate test_state_dict hooks to OptimizerInfo (#119308)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119308
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #119283, #119288, #119299

This commit is contained in:
parent 5c46600f84
commit 0320e62255
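For context on what these tests exercise: Optimizer.register_state_dict_pre_hook runs a callback before state_dict() is assembled, and register_state_dict_post_hook runs one on the assembled dict (and may return a replacement). Below is a minimal sketch of that behavior, separate from this diff; the hook bodies simply mirror the test helpers shown in the hunks that follow.

import torch
from torch.optim import SGD
from torch.optim.optimizer import Optimizer

def pre_hook(optimizer: Optimizer) -> None:
    # Called before state_dict() is built; extra entries land in the "state" section.
    optimizer.state["test"] = 1

def post_hook(optimizer: Optimizer, state_dict):
    # Called with the assembled dict; edit it in place or return a new one.
    state_dict["ran_state_dict_pre_hook"] = "test" in state_dict["state"]
    state_dict["state"].pop("test", None)
    return state_dict

param = torch.rand(2, 3, requires_grad=True)
opt = SGD([param], lr=0.001)
opt.register_state_dict_pre_hook(pre_hook)
opt.register_state_dict_post_hook(post_hook)
sd = opt.state_dict()
assert "test" not in sd["state"] and sd["ran_state_dict_pre_hook"]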
@@ -747,20 +747,6 @@ class TestOptim(TestCase):
                maximize=False,
            )

    @staticmethod
    def _state_dict_pre_hook(optimizer: Optimizer) -> None:
        optimizer.state["test"] = 1

    @staticmethod
    def _state_dict_post_hook(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
        if "test" in state_dict["state"]:
            state_dict["state"].pop("test")
            state_dict["ran_state_dict_pre_hook"] = True
        else:
            state_dict["ran_state_dict_pre_hook"] = False
        return state_dict

    @staticmethod
    def _load_state_dict_pre_hook1(optimizer: Optimizer, state_dict: Dict[str, Any]) -> None:
        state_dict["param_groups"][0]["lr"] = 0.002
@@ -778,32 +764,6 @@ class TestOptim(TestCase):
        optimizer.state["ran_load_state_dict_pre_hook2"] = optimizer.param_groups[0]["lr"] == 0.003
        optimizer.state["ran_load_state_dict_post_hook"] = True

    def test_state_dict_pre_hook(self):
        param = torch.rand(2, 3, requires_grad=True)
        param.grad = torch.rand(2, 3, requires_grad=True)
        opt = SGD([param], lr=0.001)
        opt.register_state_dict_pre_hook(self._state_dict_pre_hook)
        state_dict = opt.state_dict()
        self.assertEqual(state_dict["state"]["test"], 1)

    def test_state_dict_post_hook(self):
        param = torch.rand(2, 3, requires_grad=True)
        param.grad = torch.rand(2, 3, requires_grad=True)
        opt = SGD([param], lr=0.001)
        opt.register_state_dict_post_hook(self._state_dict_post_hook)
        state_dict = opt.state_dict()
        self.assertEqual(state_dict["ran_state_dict_pre_hook"], False)

    def test_state_dict_pre_post_hook(self):
        param = torch.rand(2, 3, requires_grad=True)
        param.grad = torch.rand(2, 3, requires_grad=True)
        opt = SGD([param], lr=0.001)
        opt.register_state_dict_pre_hook(self._state_dict_pre_hook)
        opt.register_state_dict_post_hook(self._state_dict_post_hook)
        state_dict = opt.state_dict()
        self.assertFalse("test" in state_dict["state"])
        self.assertEqual(state_dict["ran_state_dict_pre_hook"], True)

    def test_load_state_dict_pre_hook_and_prepend(self):
        param = torch.rand(2, 3, requires_grad=True)
        param.grad = torch.rand(2, 3, requires_grad=True)
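This hunk also drops the old load_state_dict hook tests; their helper bodies are elided above. As a reminder of the registration APIs involved, here is an illustrative sketch that is not part of the diff: register_load_state_dict_pre_hook receives the incoming state_dict and may rewrite it before it is loaded, while register_load_state_dict_post_hook fires after loading completes.

import torch
from torch.optim import SGD
from torch.optim.optimizer import Optimizer

def load_pre_hook(optimizer: Optimizer, state_dict) -> None:
    # Runs before load_state_dict() consumes the dict; mutate it in place or return a new one.
    state_dict["param_groups"][0]["lr"] = 0.002

def load_post_hook(optimizer: Optimizer) -> None:
    # Runs after the optimizer has absorbed the new state.
    optimizer.state["ran_load_state_dict_post_hook"] = True

param = torch.rand(2, 3, requires_grad=True)
opt = SGD([param], lr=0.001)
opt.register_load_state_dict_pre_hook(load_pre_hook)   # prepend=True would run it ahead of earlier hooks
opt.register_load_state_dict_post_hook(load_post_hook)
opt.load_state_dict(opt.state_dict())
assert opt.param_groups[0]["lr"] == 0.002
assert opt.state["ran_load_state_dict_post_hook"]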
@@ -923,6 +923,71 @@ class TestOptimRenewed(TestCase):
        self.assertEqual(optimizer.state_dict(), optimizer_cuda.state_dict())

    @staticmethod
    def _state_dict_pre_hook(optimizer: Optimizer) -> None:
        optimizer.state["test"] = 1

    @staticmethod
    def _state_dict_post_hook(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
        if "test" in state_dict["state"]:
            state_dict["state"].pop("test")
            state_dict["ran_state_dict_pre_hook"] = True
        else:
            state_dict["ran_state_dict_pre_hook"] = False
        return state_dict

    @optims(optim_db, dtypes=[torch.float32])
    def test_state_dict_pre_hook(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
        for optim_input in all_optim_inputs:
            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
                    and not optim_input.kwargs.get("foreach", False)):
                continue
            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
            optim = optim_cls([param], **optim_input.kwargs)
            optim.register_state_dict_pre_hook(self.__class__._state_dict_pre_hook)
            state_dict = optim.state_dict()
            self.assertEqual(state_dict["state"]["test"], 1)

    @optims(optim_db, dtypes=[torch.float32])
    def test_state_dict_post_hook(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
        for optim_input in all_optim_inputs:
            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
                    and not optim_input.kwargs.get("foreach", False)):
                continue
            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
            optim = optim_cls([param], **optim_input.kwargs)
            optim.register_state_dict_post_hook(self.__class__._state_dict_post_hook)
            state_dict = optim.state_dict()
            self.assertFalse(state_dict["ran_state_dict_pre_hook"])

    @optims(optim_db, dtypes=[torch.float32])
    def test_state_dict_pre_post_hook(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
        for optim_input in all_optim_inputs:
            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
                    and not optim_input.kwargs.get("foreach", False)):
                continue
            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
            optim = optim_cls([param], **optim_input.kwargs)
            optim.register_state_dict_pre_hook(self.__class__._state_dict_pre_hook)
            optim.register_state_dict_post_hook(self.__class__._state_dict_post_hook)
            state_dict = optim.state_dict()
            self.assertFalse("test" in state_dict["state"])
            self.assertTrue(state_dict["ran_state_dict_pre_hook"])

    @optims(optim_db, dtypes=[torch.float32])
    def test_step_post_hook(self, device, dtype, optim_info):
        def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
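The added block is truncated right at the start of test_step_post_hook; its post_hook signature matches Optimizer.register_step_post_hook, which invokes the hook after every step() with the args and kwargs that step() received. A minimal sketch, outside this diff and with an illustrative counter:

import torch
from torch.optim import SGD

step_count = []

def post_hook(opt, args, kwargs):
    # Fired after each optimizer.step(); args/kwargs are whatever step() was called with.
    step_count.append(1)

param = torch.rand(2, 3, requires_grad=True)
param.grad = torch.rand(2, 3)
opt = SGD([param], lr=0.001)
handle = opt.register_step_post_hook(post_hook)
opt.step()
assert len(step_count) == 1
handle.remove()  # the returned RemovableHandle unregisters the hook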