Fixes https://github.com/pytorch/pytorch/issues/75287 and https://github.com/pytorch/pytorch/issues/117437
- `nn.Module._register_state_dict_hook` --> add public `nn.Module.register_state_dict_post_hook`
- Add a test as this API was previously untested
- `nn.Module._register_load_state_dict_pre_hook` --> add public `nn.Module.register_load_state_dict_pre_hook` (remove the `with_module` flag, default it to `True`
~- For consistency with optimizer `load_state_dict_pre_hook` raised by @janeyx99, allow the pre-hook to return a new `state_dict`~
- Document issue pointed out by https://github.com/pytorch/pytorch/issues/117437 regarding `_register_state_dict_hook` semantic of returning a new state_dict only being respected for the root for private hook
- Remove this for the public `register_state_dict_post_hook`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126704
Approved by: https://github.com/albanD
ghstack dependencies: #126906
Make `torch.__future__.get_swap_module_params_on_conversion() == True` account for `assign` argument to `nn.Module.load_state_dict`
Similar to when `torch.__future__.set_swap_module_params_on_conversion()` is `False`, `assign=True` means that we do not incur a `self.copy_(other)` and the properties of `other` will be preserved
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121158
Approved by: https://github.com/albanD
ghstack dependencies: #121157
Always preserve requires_grad of param in module. Documentation fixed in PR stacked above.
Also fix test case to test load a state_dict generated with `keep_vars=False` (the default)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121157
Approved by: https://github.com/albanD
Added a `torch.Tensor` method that defines how to transform `other`, a value in the state dictionary, to be loaded into `self`, a param/buffer in an `nn.Module` before swapping via `torch.utils.swap_tensors`
* `param.module_load(sd[key])`
This method can be overridden using `__torch_function__`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117913
Approved by: https://github.com/albanD