pytorch/torch/distributed/algorithms/_checkpoint
Mikayla Gawarecki c38b3381a1 Make nn.Module state_dict load_state_dict pre-hook and state_dict post hook public (#126704)
Fixes https://github.com/pytorch/pytorch/issues/75287 and https://github.com/pytorch/pytorch/issues/117437

- `nn.Module._register_state_dict_hook` --> add public `nn.Module.register_state_dict_post_hook`
   - Add a test as this API was previously untested
- `nn.Module._register_load_state_dict_pre_hook` --> add public `nn.Module.register_load_state_dict_pre_hook` (remove the `with_module` flag, default it to `True`
    ~- For consistency with optimizer `load_state_dict_pre_hook` raised by @janeyx99, allow the pre-hook to return a new `state_dict`~
 - Document issue pointed out by https://github.com/pytorch/pytorch/issues/117437 regarding `_register_state_dict_hook` semantic of returning a new state_dict only being respected for the root for private hook
       - Remove this for the public `register_state_dict_post_hook`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126704
Approved by: https://github.com/albanD
ghstack dependencies: #126906
2024-06-10 21:50:17 +00:00
..
__init__.py
checkpoint_wrapper.py Make nn.Module state_dict load_state_dict pre-hook and state_dict post hook public (#126704) 2024-06-10 21:50:17 +00:00