mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Reland https://github.com/pytorch/pytorch/pull/126704

#### Fixes the issue with type of `nn.Module._state_dict_hooks` being changed in that PR, which was problematic

Instead of using `Tuple(Callable, bool)` to keep track of whether the private `_register_state_dict_hook` or the public `register_state_dict_post_hook` API was used to register the hook and toggle the behavior accordingly, I set an attribute on the Callable in the private API, which is never cleaned up. If a callable previously registered using the private API is registered via the public API, a RuntimeError will be raised.

#### Copied from previous PR description

Fixes https://github.com/pytorch/pytorch/issues/75287 and https://github.com/pytorch/pytorch/issues/117437

- `nn.Module._register_state_dict_hook` --> add public `nn.Module.register_state_dict_post_hook`
  - Add a test as this API was previously untested
- `nn.Module._register_load_state_dict_pre_hook` --> add public `nn.Module.register_load_state_dict_pre_hook` (remove the `with_module` flag, default it to `True`)
  ~- For consistency with optimizer `load_state_dict_pre_hook` raised by @janeyx99, allow the pre-hook to return a new `state_dict`~
- For the issue raised in https://github.com/pytorch/pytorch/issues/117437 regarding the `_register_state_dict_hook` semantic of returning a new state_dict only being respected for the root for the private hook
  - Document this for private `_register_state_dict_hook`
  - Remove this for the public `register_state_dict_post_hook`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131690
Approved by: https://github.com/albanD
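The bookkeeping described in the commit message (marking the callable in the private API and rejecting it in the public one) can be illustrated with a small pure-Python sketch. This is not the PyTorch implementation: `TinyModule`, the two-argument hook signature, and the marker name `_registered_via_private_api` are hypothetical, and the choice to ignore the return value of publicly registered hooks is an assumption made only to keep the example short.

```python
from collections import OrderedDict
from typing import Any, Callable, Dict, Optional

# Hypothetical marker name; the attribute PyTorch actually sets is not
# stated in the commit message.
_PRIVATE_MARK = "_registered_via_private_api"

Hook = Callable[["TinyModule", Dict[str, Any]], Optional[Dict[str, Any]]]


class TinyModule:
    """Toy stand-in for nn.Module with just enough state to show the idea."""

    def __init__(self) -> None:
        self.weight = 1.0
        # Hooks stay plain callables; no (callable, bool) tuples are stored.
        self._state_dict_hooks: "OrderedDict[int, Hook]" = OrderedDict()

    def _register_state_dict_hook(self, hook: Hook) -> None:
        # Private API: tag the callable itself. The tag is never removed.
        setattr(hook, _PRIVATE_MARK, True)
        self._state_dict_hooks[len(self._state_dict_hooks)] = hook

    def register_state_dict_post_hook(self, hook: Hook) -> None:
        # Public API: refuse a callable that already carries the private tag.
        if getattr(hook, _PRIVATE_MARK, False):
            raise RuntimeError(
                "hook was previously registered via the private API"
            )
        self._state_dict_hooks[len(self._state_dict_hooks)] = hook

    def state_dict(self) -> Dict[str, Any]:
        destination: Dict[str, Any] = {"weight": self.weight}
        for hook in self._state_dict_hooks.values():
            result = hook(self, destination)
            # Only privately registered hooks may swap in a new dict.
            # Assumption: public hooks mutate in place, return value ignored.
            if getattr(hook, _PRIVATE_MARK, False) and result is not None:
                destination = result
        return destination


def add_version(module: TinyModule, state_dict: Dict[str, Any]) -> None:
    state_dict["version"] = 2  # in-place edit, suitable for the public API


def replace_all(module: TinyModule, state_dict: Dict[str, Any]) -> Dict[str, Any]:
    return {"replaced": True}  # returns a new dict, private-API style
```

Registering `add_version` through `register_state_dict_post_hook` adds a key in place, while `replace_all` registered through `_register_state_dict_hook` replaces the whole dictionary. Passing `replace_all` to the public method afterwards raises `RuntimeError`, because the marker set by the private method is still on the function.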
||
|---|---|---|
| __init__.py | ||
| checkpoint_wrapper.py | ||