We already have a context manager, `set_checkpoint_early_stop`. This PR adds a kwarg that toggles the same setting.
A kwarg version of the setting is useful in addition to the context manager because it is awkward to apply a context manager when activation checkpointing (AC) is applied via `CheckpointWrapper`.
Similar to the `debug` kwarg and its corresponding `set_checkpoint_debug_enabled` context manager, the context manager defaults to None and overrides the local kwarg setting when non-None.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160781
Approved by: https://github.com/tianyu-l
**Summary**
Since much of ReplicateState's functionality is copied from FSDPState, I fixed any remaining comments that incorrectly said FSDP instead of Replicate. In addition, instead of labeling modules FSDPModule or FSDPLinear, the code now uses the corresponding Replicate-prefixed names. Finally, I removed some leftover code from the DDP implementation. I have included test cases to verify correctness.
**Test Case**
1. pytest test/distributed/_composable/test_replicate_with_fsdp.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160133
Approved by: https://github.com/mori360
ghstack dependencies: #160128
**Summary**
Users would like to use Replicate with TP. Currently, the replicate function uses DDP, which has not been maintained, resulting in a lack of integration options. Since users can already use FSDP with TP, this PR makes the replicate function use FSDP instead of DDP so that replicate gets the same integration. One blocker I ran into is that the replicate contract assigns a "replicate" attribute to the module in the registry, which means fully_shard's is_composable requirement would not be satisfied, making it impossible to apply fully_shard to a replicated module. The solution was to copy the fully_shard function and state and modify them for replicate. In the future, we should explore making ReplicateState inherit from FSDPState to eliminate the code duplication. I have attached below a profiler trace of a replicated Net module.
https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/anshulsi_270fcc36-194a-42f5-9841-cace984c2132_devgpu263.prn2.facebook.com_1792146.1753232748025155780.pt.trace.json
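For context, a minimal single-process sketch of how the composable replicate API is invoked (the import path and the CPU/gloo setup are assumptions for illustration; after this PR the call is backed by FSDP-style state rather than DDP):
```
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable import replicate  # assumed import path

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

net = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
replicate(net)  # marks the module for replication; now driven by FSDP-style state

loss = net(torch.randn(2, 8)).sum()
loss.backward()  # gradients are all-reduced across replicas (trivial with world_size=1)

dist.destroy_process_group()
```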
**Test Case**
1. pytest test/distributed/_composable/test_replicate_with_fsdp.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158207
Approved by: https://github.com/weifengpy
Co-authored-by: Anshul Sinha <50644008+sinhaanshul@users.noreply.github.com>
Internal only: the previous arrangement meant that `from torch.distributed._composable.fsdp import fully_shard` was importing the `fully_shard.py` module, not the function `fully_shard`. For some reason, the resolution order differs from open source.
To fix this, we match the old import as closely as possible. Namely, we import `fully_shard.py` contents from `.fully_shard`. This should force that import to take precedence.
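A sketch of the idea (the exact `__init__.py` contents here are hypothetical): importing the function from the `.fully_shard` submodule rebinds the package attribute `fully_shard` from the submodule object to the function, so the function wins regardless of resolution order.
```
# torch/distributed/_composable/fsdp/__init__.py (illustrative only)
# After the submodule import below, the package attribute `fully_shard`
# refers to the function, not to the fully_shard.py module object.
from .fully_shard import FSDPModule, fully_shard

__all__ = ["FSDPModule", "fully_shard"]
```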
@diff-train-skip-merge
Differential Revision: [D66990327](https://our.internmc.facebook.com/intern/diff/D66990327)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142419
Approved by: https://github.com/weifengpy
**Overview**
This PR moves `torch/distributed/_composable/fsdp` to `torch/distributed/fsdp/_fully_shard` and makes public APIs available from `torch.distributed.fsdp`, e.g.:
```
from torch.distributed.fsdp import fully_shard
```
This is targeting the 2.6 release. I rewrote some of the documentation with (hopefully) improved phrasing.
**Changes for Reland**
- Preserved the public objects from `torch/distributed/_composable/fsdp/fully_shard.py` so that the import path still works internally
- Added a unit test verifying that `from torch.distributed._composable.fsdp.fully_shard import FSDPModule` still works (see the sketch below)
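A quick sketch of the two import paths this reland is expected to keep working (that both paths resolve to the same re-exported objects is my assumption):
```
# New public location targeted by this stack.
from torch.distributed.fsdp import FSDPModule, fully_shard

# Old internal location, preserved for backward compatibility by this reland.
from torch.distributed._composable.fsdp import fully_shard as _legacy_fully_shard
from torch.distributed._composable.fsdp.fully_shard import FSDPModule as _LegacyFSDPModule

# Both paths should refer to the same objects (assumed re-exports).
assert fully_shard is _legacy_fully_shard
assert FSDPModule is _LegacyFSDPModule
```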
Differential Revision: [D66890387](https://our.internmc.facebook.com/intern/diff/D66890387)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141868
Approved by: https://github.com/kwen2501, https://github.com/wconstab, https://github.com/weifengpy, https://github.com/fegin, https://github.com/XilunWu
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
**Overview**
This PR moves `torch/distributed/_composable/fsdp` to `torch/distributed/fsdp/_fully_shard` and makes public APIs available from `torch.distributed.fsdp`, e.g.:
```
from torch.distributed.fsdp import fully_shard
```
This is targeting the 2.6 release. I rewrote some of the documentation with (hopefully) improved phrasing.
**Follow-Ups**
- [x] Add some explanation in the docs about FSDP1 vs. FSDP2
- [ ] Move unit tests from `test/distributed/_composable/fsdp` to `test/distributed/fsdp/fully_shard/`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141868
Approved by: https://github.com/kwen2501, https://github.com/wconstab, https://github.com/weifengpy
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
FSDP1's `fully_shard` frontend was an exploration at the end of 2022 H2 as part of the `torch/distributed/_composable` APIs to avoid `nn.Module` wrappers. It calls into the same backend code as FSDP1's `FullyShardedDataParallel`.
The API did not gain traction internally, so we instead reused the name `fully_shard` for FSDP2, which similarly is not an `nn.Module` wrapper and follows design principles similar to those of FSDP1's `fully_shard`.
To the best of our knowledge, we have removed all instances of FSDP1's `fully_shard` internally, and we put the deprecation warning in open source in 2.4 saying it will be removed after 2.5 (which is now):
4959784dac/torch/distributed/_composable/fully_shard.py (L40-L48)
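For reference, a sketch contrasting the two import paths (the FSDP1 path is the one deleted here; the FSDP2 path is the composable namespace referenced elsewhere in this stack):
```
# FSDP1 composable frontend, deprecated in 2.4 and removed by this PR:
# from torch.distributed._composable import fully_shard   # no longer available

# FSDP2 reuses the name under the composable FSDP namespace:
from torch.distributed._composable.fsdp import fully_shard
```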
We are skipping the PR sanity check because this PR only removes code (it adds none) and should not require that check.
Differential Revision: [D66664988](https://our.internmc.facebook.com/intern/diff/D66664988)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141875
Approved by: https://github.com/weifengpy
Assuming the forward pass user code looks like:
```
for _ in range(2):
    x = layer(x)
```
and we have `fully_shard(layer)`, then:
- the forward pass will be like: "unshard layer -> call layer 1st time -> reshard layer -> unshard layer -> call layer 2nd time -> reshard layer" (currently same for both eager and compile)
- the backward pass will be like: "unshard layer -> call layer 1st time -> reshard layer -> unshard layer -> call layer 2nd time -> reshard layer" in eager, but currently it's "unshard layer -> call layer 1st time -> call layer 2nd time -> reshard layer" in compile
The behavior in the backward pass is different between eager and compile, which is not ideal.
I am currently looking for a way to fix this non-ideal compile behavior and have tried a few things:
1. Tracing the RegisterPostBackwardFunction custom autograd function - this still seems to be a no-go, due to HOP not supporting side effects.
2. Instead of custom autograd function, do a "multi-grad hook" to wait for all gradients to be ready before triggering post_backward. However, this approach seems to have bad interaction with register_hook of pre_backward, in the sense that it's unclear which of them will be triggered first in practice.
3. Force-execute any pending post_backward before the unshard in the pre_backward hook, and rely on the compiler to move the reshard to the right place to optimize peak memory. -> This PR (sketched below)
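A conceptual sketch of option 3 with hypothetical names (this is not the actual FSDP2 hook implementation, just the shape of the idea):
```
# Hypothetical pre-backward hook: before unsharding a layer for its backward,
# flush any post_backward work still pending from an earlier use of a layer
# in this same backward pass.
def pre_backward_hook(state, pending_post_backwards):
    for pending in list(pending_post_backwards):
        pending.post_backward()            # reshard params, reduce-scatter grads now
        pending_post_backwards.remove(pending)
    state.unshard()                        # all-gather this layer's parameters
    # Under compile, the compiler is then free to move the reshard to the point
    # that minimizes peak memory.
```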
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139671
Approved by: https://github.com/awgu
## Overview
This PR adds a `shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]]` arg to `fully_shard` that allows users to specify FSDP sharding on a nonzero tensor dim. If doing so, the sharded tensor dim's size must be divisible by the FSDP shard world size.
```
# Example:
def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
    largest_dim = largest_dim_size = -1
    for dim, dim_size in enumerate(param.shape):
        if dim_size > largest_dim_size:
            largest_dim = dim
            largest_dim_size = dim_size
    return Shard(largest_dim)

fully_shard(module, shard_placement_fn=shard_placement_fn)
```
## Follow-Ups
- **Copy kernels:** For all-gather copy-out, we currently copy out to temporaries and then chunk-dim-0 -> cat-shard-dim, incurring an extra copy for parameters sharded on a nonzero tensor dim (illustrated below). Similarly, for reduce-scatter copy-in, we currently chunk-shard-dim -> cat-dim-0, incurring an extra copy for gradients sharded on a nonzero tensor dim. @yifuwang has ideas for adding additional split size args to the copy ops that allow fusing these extra copies into the existing all-gather copy-out and reduce-scatter copy-in.
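To make the extra copy concrete, here is a small standalone illustration with made-up shapes of the chunk-dim-0 -> cat-shard-dim step on the all-gather path (not the actual kernel code):
```
import torch

world_size = 4
full = torch.randn(8, 16)                  # hypothetical parameter, sharded on dim 1
shards = full.chunk(world_size, dim=1)     # each rank's local shard: (8, 4)

# All-gather lays the world_size shards out contiguously along dim 0: (32, 4).
gathered = torch.cat([s.contiguous() for s in shards], dim=0)

# Extra copy today: chunk along dim 0, then cat along the shard dim to rebuild (8, 16).
restored = torch.cat(gathered.chunk(world_size, dim=0), dim=1)
assert torch.equal(restored, full)
```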
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137496
Approved by: https://github.com/weifengpy
ghstack dependencies: #137593