Reland https://github.com/pytorch/pytorch/pull/126704
#### Fixes the problematic change to the type of `nn.Module._state_dict_hooks` made in that PR:
Instead of storing a `Tuple[Callable, bool]` to track whether the hook was registered via the private `_register_state_dict_hook` or the public `register_state_dict_post_hook` API (and toggling the behavior accordingly), I set an attribute on the callable in the private API; this attribute is never cleaned up.
If a callable previously registered using the private API is then registered via the public API, a `RuntimeError` is raised.
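A minimal sketch of the approach (the attribute name `_from_private_api` and the stripped-down class are illustrative, not the actual code from the PR):
```
from torch.utils import hooks

class _ModuleSketch:
    """Toy stand-in for nn.Module, showing only the hook bookkeeping."""

    def __init__(self):
        self._state_dict_hooks = {}

    def _register_state_dict_hook(self, hook):
        # Private API: tag the callable itself (hypothetical attribute name);
        # the tag is never cleaned up, even after the handle is removed.
        hook._from_private_api = True
        handle = hooks.RemovableHandle(self._state_dict_hooks)
        self._state_dict_hooks[handle.id] = hook
        return handle

    def register_state_dict_post_hook(self, hook):
        # Public API: reject callables previously registered through the
        # private API instead of silently changing their semantics.
        if getattr(hook, "_from_private_api", False):
            raise RuntimeError(
                "Cannot register the same callable with both the private and "
                "public state_dict post-hook APIs"
            )
        handle = hooks.RemovableHandle(self._state_dict_hooks)
        self._state_dict_hooks[handle.id] = hook
        return handle
```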
#### Copied from previous PR description
Fixes https://github.com/pytorch/pytorch/issues/75287 and https://github.com/pytorch/pytorch/issues/117437
- `nn.Module._register_state_dict_hook` --> add public `nn.Module.register_state_dict_post_hook`
- Add a test as this API was previously untested
- `nn.Module._register_load_state_dict_pre_hook` --> add public `nn.Module.register_load_state_dict_pre_hook` (remove the `with_module` flag and default it to `True`)
- ~~For consistency with the optimizer `load_state_dict_pre_hook`, as raised by @janeyx99, allow the pre-hook to return a new `state_dict`~~
- For the issue pointed out in https://github.com/pytorch/pytorch/issues/117437, where the `_register_state_dict_hook` semantic of returning a new `state_dict` is only respected for the root module:
  - Document this for the private `_register_state_dict_hook`
  - Remove this for the public `register_state_dict_post_hook`
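For reference, a short usage sketch of the new public APIs (hook signatures follow the corresponding private hooks; treat exact signatures as indicative):
```
import torch.nn as nn

model = nn.Linear(4, 4)

def drop_scratch_buffers(module, state_dict, prefix, local_metadata):
    # Post-hook: mutate state_dict in place; unlike the private hook, a
    # returned replacement state_dict is not honored by the public API.
    state_dict.pop(prefix + "scratch_buffer", None)

def log_incoming_keys(module, state_dict, prefix, local_metadata, strict,
                      missing_keys, unexpected_keys, error_msgs):
    # Pre-hook: runs with the module bound (no `with_module` flag anymore).
    print(f"loading {len(state_dict)} keys into {type(module).__name__}")

post_handle = model.register_state_dict_post_hook(drop_scratch_buffers)
pre_handle = model.register_load_state_dict_pre_hook(log_incoming_keys)

model.load_state_dict(model.state_dict())
post_handle.remove()
pre_handle.remove()
```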
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131690
Approved by: https://github.com/albanD
Fixes https://github.com/pytorch/pytorch/issues/75287 and https://github.com/pytorch/pytorch/issues/117437
- `nn.Module._register_state_dict_hook` --> add public `nn.Module.register_state_dict_post_hook`
- Add a test as this API was previously untested
- `nn.Module._register_load_state_dict_pre_hook` --> add public `nn.Module.register_load_state_dict_pre_hook` (remove the `with_module` flag and default it to `True`)
- ~~For consistency with the optimizer `load_state_dict_pre_hook`, as raised by @janeyx99, allow the pre-hook to return a new `state_dict`~~
- Document the issue pointed out in https://github.com/pytorch/pytorch/issues/117437: the private `_register_state_dict_hook` semantic of returning a new `state_dict` is only respected for the root module
- Remove this for the public `register_state_dict_post_hook`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126704
Approved by: https://github.com/albanD
ghstack dependencies: #126906
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
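For example (illustrative function names; `old_api` / `new_api` are not real PyTorch APIs):
```
import warnings
from typing_extensions import deprecated

def new_api():
    return 42

# Preferred: annotate the callable itself so both static checkers and the
# runtime see the deprecation, with FutureWarning as the category.
@deprecated("`old_api` is deprecated, use `new_api` instead", category=FutureWarning)
def old_api():
    return new_api()

# Fallback for warnings that cannot use the decorator: pass an explicit
# category instead of relying on the default (UserWarning).
def legacy_path():
    warnings.warn(
        "`legacy_path` is deprecated, use `new_api` instead",
        category=FutureWarning,
        stacklevel=2,
    )
    return new_api()
```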
Resolves #126888
- #126888
This PR is split from PR #126898.
- #126898
------
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.
Resolves #126888
- #126888
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
This PR makes some code-organization improvements.
- It renames `_FSDPPolicy` to `_Policy` to show that it is not only for FSDP but for any module-level API.
- It formalizes the contract that such a policy should return something like `target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]]` that maps each module to wrap to its kwargs. It does so by requiring a `_run_policy` abstract method (this time private since users do not need to care about it). Then, our auto wrapping can just call `_run_policy()` to generate the dict and do any validation or post-processing.
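Roughly, the contract looks like the following sketch (the exact `_run_policy` signature is an assumption based on the description above):
```
from abc import ABC, abstractmethod
from typing import Any, Dict, Set

import torch.nn as nn

class _Policy(ABC):
    """Module-level wrapping policy; not specific to FSDP."""

    @abstractmethod
    def _run_policy(
        self,
        root_module: nn.Module,
        ignored_modules: Set[nn.Module],
        root_kwargs: Dict[str, Any],
    ) -> Dict[nn.Module, Dict[str, Any]]:
        """Return target_module_to_kwargs, mapping each module to wrap to the
        kwargs used to construct its wrapper."""
        ...

class WrapChildren(_Policy):
    # Toy policy: wrap every direct child with the root kwargs unchanged.
    def _run_policy(self, root_module, ignored_modules, root_kwargs):
        return {
            child: dict(root_kwargs)
            for child in root_module.children()
            if child not in ignored_modules
        }
```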
This PR is technically BC-breaking because it removes the public `ModuleWrapPolicy.policy`. However, I do not think anyone was using that anyway, so this is a pretty safe breakage.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104969
Approved by: https://github.com/rohan-varma
ghstack dependencies: #104427, #104967, #104999
This is for consistency with FSDP.
- `_FSDP_WRAPPED_MODULE` and `_CHECKPOINT_WRAPPED_MODULE` are exactly the wrapped module variable name, meaning you can call `getattr(module, _FSDP_WRAPPED_MODULE)` or `getattr(module, _CHECKPOINT_WRAPPED_MODULE)`.
- `_FSDP_PREFIX` and `_CHECKPOINT_PREFIX` include the trailing `"."` and are only used for FQNs.
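For illustration (the string values shown are the conventional ones; treat them as indicative):
```
# Attribute-name constant vs. FQN-prefix constant (note the trailing ".").
_CHECKPOINT_WRAPPED_MODULE = "_checkpoint_wrapped_module"
_CHECKPOINT_PREFIX = _CHECKPOINT_WRAPPED_MODULE + "."

def unwrap(wrapper):
    # The attribute-name constant works directly with getattr ...
    return getattr(wrapper, _CHECKPOINT_WRAPPED_MODULE)

def clean_fqn(fqn: str) -> str:
    # ... while the prefix constant is only used when manipulating FQNs.
    return fqn.replace(_CHECKPOINT_PREFIX, "")
```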
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87951
Approved by: https://github.com/zhaojuanmao
We change `.module` to pass through `ActivationWrapper` directly to the inner wrapped module. This should fix the state dict issues.
Given the invariant that `.module` always returns the inner wrapped module, FSDP always registers the `FlatParameter` on the inner wrapped module, regardless of if there is an intermediate `ActivationWrapper` or not. This avoids casing on whether `ActivationWrapper` is added before or after FSDP construction.
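Schematically (a toy sketch of the invariant, not the actual FSDP code):
```
import torch.nn as nn

class ActivationWrapperSketch(nn.Module):
    # Stand-in for the checkpoint/offload ActivationWrapper.
    def __init__(self, mod: nn.Module):
        super().__init__()
        self._checkpoint_wrapped_module = mod

class FSDPSketch(nn.Module):
    # Stand-in for FSDP, showing only the `.module` pass-through.
    def __init__(self, mod: nn.Module):
        super().__init__()
        self._fsdp_wrapped_module = mod

    @property
    def module(self) -> nn.Module:
        inner = self._fsdp_wrapped_module
        # Skip an intermediate ActivationWrapper so `.module` always returns
        # the innermost wrapped module, where the FlatParameter lives.
        if isinstance(inner, ActivationWrapperSketch):
            inner = inner._checkpoint_wrapped_module
        return inner
```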
This PR removes the unit test added in `test_fsdp_misc.py` for changing the wrapped module because I would rather not complicate the `_lazy_init()` logic just to support that kind of adversarial behavior. The user should not be swapping out the wrapped module arbitrarily or deleting the `FlatParameter`. I mainly had those tests to make sure that all branches of the code I added were correct.
Differential Revision: [D40799961](https://our.internmc.facebook.com/intern/diff/D40799961)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87950
Approved by: https://github.com/zhaojuanmao
`_recursive_wrap()` returns `Tuple[nn.Module, int]`, where the `nn.Module` is the in-place modified module and the `int` is the numel wrapped. In that sense, the return value is not meant to be publicly used. The `apply_activation_checkpointing()` docs already suggest that the function returns `None`, so this PR simply follows that.
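Usage then looks like this (import path and keyword names as in current PyTorch; treat them as indicative):
```
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
# The model is modified in place; the function now returns None, so there is
# no return value to capture.
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    check_fn=lambda m: isinstance(m, nn.Linear),
)
```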
**Test Plan**
CI
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87871
Approved by: https://github.com/zhaojuanmao
Passing `offload_to_cpu=True` to `checkpoint_wrapper` is a bit confusing, because it causes the activation checkpoint args to be ignored and CPU offloading to be done instead. This isn't ideal from an API design perspective, so this proposes making `offload_wrapper` its own concept.
Now, offload-to-CPU and checkpoint can be composed together, for example:
```
# apply AC to transformer layers
apply_ac_wrapper(model, checkpoint_wrapper, check_fn=lambda mod: isinstance(mod, TransformerLayer))
# offload the rest of activations to CPU
model = offload_wrapper(model)
```
Will polish / add tests if this proposal sounds good.
Differential Revision: [D39719854](https://our.internmc.facebook.com/intern/diff/D39719854/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85459
Approved by: https://github.com/awgu
This fixes the activation offload for checkpoint wrapper, which was previously broken. It was broken because it was tightly coupled with activation checkpoint, i.e. we did:
```
with save_on_cpu:
    checkpoint(module_forward())
```
which would not offload any activation tensors to CPU, since those activations are already not saved by autograd because the checkpoint implementation takes priority.
Now, if `offload_to_cpu` is specified, we only do `save_on_cpu` and no checkpoint, so all intermediate tensors are offloaded to CPU instead of checkpointed.
These wrappers can be composed, i.e. if we have
`(Linear, Linear) -> (Linear, Linear) -> (Linear, Linear)`
we can do
`Offload( checkpoint(Linear, Linear) -> checkpoint(Linear, Linear) -> checkpoint(Linear, Linear))`
and the inner activations would be checkpointed while the outer ones would be offloaded.
Differential Revision: [D39448882](https://our.internmc.facebook.com/intern/diff/D39448882/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84907
Approved by: https://github.com/awgu
Allow checkpoint_wrapper to take in the checkpoint functional impl. This decouples it from torch.utils.checkpoint and allows other checkpoint implementations to be used.
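A usage sketch (the keyword name `checkpoint_fn` is an assumption here; the exact signature may differ):
```
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

# Stand-in for a custom checkpoint implementation; here it simply runs the
# function eagerly, but it could be any recomputation scheme.
def my_checkpoint_fn(fn, *args, **kwargs):
    return fn(*args, **kwargs)

layer = checkpoint_wrapper(nn.Linear(8, 8), checkpoint_fn=my_checkpoint_fn)
out = layer(torch.randn(2, 8))
```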
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83035
Approved by: https://github.com/awgu
Remove `_checkpoint_wrapped_module` prefixes when creating keys for optimizer state_dict.
Having these prefixes does not actually create an issue for optim_state_dict save / load, but we'd like to strip them out for downstream code that consumes these APIs and typically expects checkpointing prefixes not to exist (checkpointing should be a transparent operation that does not change module / parameter names).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80480
Approved by: https://github.com/awgu, https://github.com/fegin
- Uses state dict / load state dict hooks to ensure that modules wrapped with `CheckpointWrapper` can be loaded into a non-checkpoint-wrapped module.
This is because a training run can use activation checkpointing and save a `state_dict`, while a future run may not want to wrap modules with activation checkpointing, or may change the wrapping structure. To support this, we add hooks to remove / add the relevant prefix as needed.
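The idea, roughly (not the actual hook bodies):
```
_CHECKPOINT_PREFIX = "_checkpoint_wrapped_module."

def post_state_dict_hook(module, state_dict, prefix, local_metadata):
    # On save: strip the wrapper prefix so saved FQNs look non-checkpointed.
    for key in sorted(state_dict.keys()):
        if _CHECKPOINT_PREFIX in key:
            state_dict[key.replace(_CHECKPOINT_PREFIX, "")] = state_dict.pop(key)

def pre_load_state_dict_hook(state_dict, prefix, *args):
    # On load into a checkpoint-wrapped module: re-insert the prefix so keys
    # line up with the wrapped module's parameter names.
    for key in sorted(state_dict.keys()):
        if key.startswith(prefix):
            new_key = prefix + _CHECKPOINT_PREFIX + key[len(prefix):]
            state_dict[new_key] = state_dict.pop(key)
```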
Tests are added to ensure we can load a state_dict from a CheckpointWrapper-wrapped module into both a CheckpointWrapper module and a local module. state_dict with FSDP is also verified.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77224
Approved by: https://github.com/zhaojuanmao
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70165
Implements activation offload support in checkpoint_wrapper API via
save_on_cpu hooks. We avoid modifying the torch.utils.checkpoint implementation
and instead compose offload + checkpoint using the save_on_cpu hook for the
former.
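Conceptually, the composition looks like this toy sketch (not the actual checkpoint_wrapper code):
```
import torch
import torch.nn as nn
from torch.autograd.graph import save_on_cpu

class OffloadToCPUSketch(nn.Module):
    """Wraps a module so its saved activations are offloaded to CPU via the
    save_on_cpu saved-tensor hooks, leaving torch.utils.checkpoint untouched."""

    def __init__(self, mod: nn.Module):
        super().__init__()
        self._checkpoint_wrapped_module = mod

    def forward(self, *args, **kwargs):
        with save_on_cpu():
            return self._checkpoint_wrapped_module(*args, **kwargs)

wrapped = OffloadToCPUSketch(nn.Linear(16, 16))
out = wrapped(torch.randn(4, 16, requires_grad=True))
out.sum().backward()
```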
ghstack-source-id: 146078900
Test Plan: CI
Reviewed By: zhaojuanmao
Differential Revision: D33228820
fbshipit-source-id: 98b4da0828462c41c381689ee07360ad014e808a
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70164
Implement Alban's suggestion to make checkpoint_wrapper an nn.Module
instead of patching the forward pass, which is too hacky.
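The shape of the change, roughly (a toy sketch using today's explicit `use_reentrant` flag; the real wrapper has more bookkeeping):
```
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointWrapperSketch(nn.Module):
    # An nn.Module that owns the wrapped module and routes its forward through
    # checkpoint(), instead of monkey-patching module.forward.
    def __init__(self, mod: nn.Module):
        super().__init__()
        self._checkpoint_wrapped_module = mod

    def forward(self, *args):
        # Only the reentrant implementation was supported at the time.
        return checkpoint(self._checkpoint_wrapped_module, *args, use_reentrant=True)
```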
ghstack-source-id: 146011215
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D33214696
fbshipit-source-id: dc4b3e928d66fbde828ab60d90b314a8048ff7a2
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69955
Implements a checkpoint_wrapper function, which wraps an nn.Module with checkpointing so the user won't have to call checkpoint() every time they want to checkpoint the module.
Currently, only support for reentrant-based checkpointing is added, and it is only tested with FSDP to unblock a use case.
Future work is to add support for the new checkpointing API, add more tests, and upstream to torch.utils.checkpoint.
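Usage is along these lines (import path as in current PyTorch; treat as indicative):
```
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
# Wrap once; every forward pass is then checkpointed without the user calling
# torch.utils.checkpoint.checkpoint() at each call site.
block = checkpoint_wrapper(block)
out = block(torch.randn(2, 16, requires_grad=True))
out.sum().backward()
```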
ghstack-source-id: 145811242
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D33107276
fbshipit-source-id: c4a1c68d71d65713a929994940a8750f73fbdbdb