Commit Graph

9 Commits

Author SHA1 Message Date
PyTorch MergeBot
1d233b8f50 Revert "Make nn.Module state_dict load_state_dict pre-hook and state_dict post hook public (#126704)"
This reverts commit c38b3381a1.

Reverted https://github.com/pytorch/pytorch/pull/126704 on behalf of https://github.com/clee2000 due to broke internal typecheck D58394110 (which probably means the code wouldn't work either but I guess it didn't run on the diff). Probably an easy fix? ([comment](https://github.com/pytorch/pytorch/pull/126704#issuecomment-2161299193))
2024-06-11 17:45:20 +00:00
Mikayla Gawarecki
c38b3381a1 Make nn.Module state_dict load_state_dict pre-hook and state_dict post hook public (#126704)
Fixes https://github.com/pytorch/pytorch/issues/75287 and https://github.com/pytorch/pytorch/issues/117437

- `nn.Module._register_state_dict_hook` --> add public `nn.Module.register_state_dict_post_hook`
   - Add a test as this API was previously untested
- `nn.Module._register_load_state_dict_pre_hook` --> add public `nn.Module.register_load_state_dict_pre_hook` (remove the `with_module` flag, default it to `True`
    ~- For consistency with optimizer `load_state_dict_pre_hook` raised by @janeyx99, allow the pre-hook to return a new `state_dict`~
 - Document issue pointed out by https://github.com/pytorch/pytorch/issues/117437 regarding `_register_state_dict_hook` semantic of returning a new state_dict only being respected for the root for private hook
       - Remove this for the public `register_state_dict_post_hook`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126704
Approved by: https://github.com/albanD
ghstack dependencies: #126906
2024-06-10 21:50:17 +00:00
Mikayla Gawarecki
a2d4fea872 [easy] Move state_dict hooks tests to test_module_hooks and decorate tests that call load_state_dict with swap (#126906)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126906
Approved by: https://github.com/albanD
2024-06-10 21:50:17 +00:00
FFFrog
d6f88105ce Fix the problem about load_state_dict with unexpected key whose prefix matches a valid key (#124385)
Fixes https://github.com/pytorch/pytorch/issues/123510

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124385
Approved by: https://github.com/mikaylagawarecki
2024-04-20 23:19:25 +00:00
Yuanhao Ji
a625705290 Enable UFMT on all of test/nn (#123809)
Part of: #123062

Ran lintrunner on:

- `test/nn`

with command:

```bash
lintrunner -a --take UFMT --all-files
```

Co-authored-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123809
Approved by: https://github.com/mikaylagawarecki
2024-04-12 18:32:25 +00:00
Mikayla Gawarecki
4b3903379a Add assign argument to torch.Tensor.module_load (#121158)
Make `torch.__future__.get_swap_module_params_on_conversion() == True` account for `assign` argument to `nn.Module.load_state_dict`

Similar to when `torch.__future__.set_swap_module_params_on_conversion()` is `False`, `assign=True` means that we do not incur a `self.copy_(other)` and the properties of `other` will be preserved

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121158
Approved by: https://github.com/albanD
ghstack dependencies: #121157
2024-03-06 01:32:06 +00:00
Mikayla Gawarecki
27389e03f0 [easy] Fixed requires_grad preservation for nn.Module.load_state_dict(assign=True) (#121157)
Always preserve requires_grad of param in module. Documentation fixed in PR stacked above.
Also fix test case to test load a state_dict generated with `keep_vars=False` (the default)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121157
Approved by: https://github.com/albanD
2024-03-06 01:32:06 +00:00
Mikayla Gawarecki
3372aa51b4 Integrate swap_tensors into nn.Module.load_state_dict (#117913)
Added a `torch.Tensor` method that defines how to transform `other`, a value in the state dictionary, to be loaded into `self`, a param/buffer in an `nn.Module` before swapping via `torch.utils.swap_tensors`
* `param.module_load(sd[key])`

This method can be overridden using `__torch_function__`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117913
Approved by: https://github.com/albanD
2024-02-09 22:32:29 +00:00
Mikayla Gawarecki
b92819a039 Move nn.Module.load_state_dict tests from test_nn.py to separate file (#118028)
Move these tests out so in https://github.com/pytorch/pytorch/pull/117913 where we can to run these tests with both `torch.nn.utils.set_swap_module_params_on_conversion({True/False})`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118028
Approved by: https://github.com/albanD
2024-02-05 20:17:28 +00:00