Internal only: previously, `from torch.distributed._composable.fsdp import fully_shard` imported the module `fully_shard.py`, not the function `fully_shard`. For some reason, the resolution order differs from open source.
To fix this, we match the old import as closely as possible: namely, we import the contents of `fully_shard.py` from `.fully_shard`, which should force that import to take precedence.
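As a quick (hedged) check of the intended behavior, both import paths should now yield the `fully_shard` function rather than the submodule, matching the unit test added for the reland below:
```
from torch.distributed._composable.fsdp import fully_shard as legacy_fully_shard
from torch.distributed.fsdp import fully_shard

# Both names should refer to the fully_shard function, not the fully_shard.py module.
assert callable(legacy_fully_shard) and callable(fully_shard)
```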
@diff-train-skip-merge
Differential Revision: [D66990327](https://our.internmc.facebook.com/intern/diff/D66990327)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142419
Approved by: https://github.com/weifengpy
**Overview**
This PR moves `torch/distributed/_composable/fsdp` to `torch/distributed/fsdp/_fully_shard` and makes public APIs available from `torch.distributed.fsdp`, e.g.:
```
from torch.distributed.fsdp import fully_shard
```
This is targeting the 2.6 release. I rewrote some of the documentation with (hopefully) improved phrasing.
**Changes for Reland**
- Preserved the public objects from `torch/distributed/_composable/fsdp/fully_shard.py` so that the import path still works internally
- Added a unit test that we can do `from torch.distributed._composable.fsdp.fully_shard import FSDPModule`
Differential Revision: [D66890387](https://our.internmc.facebook.com/intern/diff/D66890387)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141868
Approved by: https://github.com/kwen2501, https://github.com/wconstab, https://github.com/weifengpy, https://github.com/fegin, https://github.com/XilunWu
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
**Overview**
This PR moves `torch/distributed/_composable/fsdp` to `torch/distributed/fsdp/_fully_shard` and makes public APIs available from `torch.distributed.fsdp`, e.g.:
```
from torch.distributed.fsdp import fully_shard
```
This is targeting the 2.6 release. I rewrote some of the documentation with (hopefully) improved phrasing.
**Follow-Ups**
- [x] Add some explanation in the docs about FSDP1 vs. FSDP2
- [ ] Move unit tests from `test/distributed/_composable/fsdp` to `test/distributed/fsdp/fully_shard/`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141868
Approved by: https://github.com/kwen2501, https://github.com/wconstab, https://github.com/weifengpy
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
FSDP1's `fully_shard` frontend was an exploration at the end of 2022 H2 as part of the `torch/distributed/_composable` APIs to avoid `nn.Module` wrappers. It calls into the same backend code as FSDP1's `FullyShardedDataParallel`.
The API did not gain traction internally, so we instead reused the name `fully_shard` for FSDP2, which similarly is not an `nn.Module` wrapper and follows design principles similar to those of FSDP1's `fully_shard`.
To the best of our knowledge, we have removed all instances of FSDP1's `fully_shard` internally, and we added a deprecation warning in open source in 2.4 saying it would be removed after 2.5 (which is now):
4959784dac/torch/distributed/_composable/fully_shard.py (L40-L48)
We are skipping the PR sanity check because this PR only removes code and does not add any new code.
Differential Revision: [D66664988](https://our.internmc.facebook.com/intern/diff/D66664988)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141875
Approved by: https://github.com/weifengpy
Assuming the forward pass user code looks like:
```
for _ in range(2):
    x = layer(x)
```
and we have `fully_shard(layer)`, then:
- the forward pass will be like: "unshard layer -> call layer 1st time -> reshard layer -> unshard layer -> call layer 2nd time -> reshard layer" (currently the same for both eager and compile)
- the backward pass will be like: "unshard layer -> call layer 1st time -> reshard layer -> unshard layer -> call layer 2nd time -> reshard layer" in eager, but currently it is "unshard layer -> call layer 1st time -> call layer 2nd time -> reshard layer" in compile
The behavior in the backward pass is different between eager and compile, which is not ideal.
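For concreteness, here is a minimal sketch of this setup (assuming a process group is already initialized and a CUDA device is available; the sizes and the 2.6 import path are illustrative):
```
import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard  # 2.6+ import path

dim = 16
layer = nn.Linear(dim, dim, device="cuda")
fully_shard(layer, reshard_after_forward=True)

x = torch.randn(2, dim, device="cuda")
for _ in range(2):
    x = layer(x)  # each call unshards `layer` and (in eager) reshards it afterwards
x.sum().backward()
```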
I am currently trying to find a way to fix this non-ideal behavior of compile - I tried a few things:
1. Tracing the RegisterPostBackwardFunction custom autograd function - this still seems to be a no-go, due to HOP not supporting side effects.
2. Instead of a custom autograd function, use a "multi-grad hook" to wait for all gradients to be ready before triggering post_backward. However, this approach seems to interact badly with the register_hook of pre_backward, in the sense that it is unclear which of them will be triggered first in practice.
3. Force-execute any pending post_backward before unshard in the pre_backward hook, and rely on the compiler to move the reshard to the right place to optimize peak memory. -> This PR
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139671
Approved by: https://github.com/awgu
## Overview
This PR adds a `shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]]` arg to `fully_shard` that allows users to specify FSDP sharding on a nonzero tensor dim. If doing so, then the tensor dim size must be divisible by the FSDP shard world size.
```
# Example:
def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
    largest_dim = largest_dim_size = -1
    for dim, dim_size in enumerate(param.shape):
        if dim_size > largest_dim_size:
            largest_dim = dim
            largest_dim_size = dim_size
    return Shard(largest_dim)

fully_shard(module, shard_placement_fn=shard_placement_fn)
```
## Follow-Ups
- **Copy kernels:** For all-gather copy-out, we currently copy out to temporaries and then chunk-dim-0 -> cat-shard-dim, incurring an extra copy for parameters sharded on a nonzero tensor dim. Similarly, for reduce-scatter copy-in, we currently chunk-shard-dim -> cat-dim-0, incurring an extra copy for gradients sharded on a nonzero tensor dim. @yifuwang has ideas for adding additional split-size args to the copy ops that would allow fusing these extra copies into the existing all-gather copy-out and reduce-scatter copy-in.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137496
Approved by: https://github.com/weifengpy
ghstack dependencies: #137593
If there is an `unshard` (top half) without a `wait_for_unshard` (bottom half), then the next iteration's `unshard` will be a no-op. This can unexpectedly fail to propagate the optimizer update on the sharded parameters to the unsharded parameters, so it is better to clear that `unshard` at the end of backward.
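In user-facing terms, this corresponds roughly to the following pattern with the explicit `FSDPModule.unshard` API (hedged sketch using the 2.6 import path; `use_it` is an illustrative flag):
```
from torch.distributed.fsdp import FSDPModule

def maybe_prefetch(module: FSDPModule, use_it: bool):
    handle = module.unshard(async_op=True)  # top half: issue the all-gather
    if not use_it:
        # The bottom half (handle.wait()) never runs. Before this change, the
        # pending unshard made the next iteration's unshard a no-op, so the
        # unsharded parameters could miss the optimizer update on the shards.
        return
    handle.wait()  # bottom half: switch to the unsharded parameters
```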
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137348
Approved by: https://github.com/weifengpy
Currently FSDP2 supports only CUDA; for other backends that need to use FSDP2, it won't work because streams and events are CUDA-based. To support other backends, use `_get_device_handle` keyed by device type to get the device class and use it for streams and events.
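A minimal sketch of the idea (the helper name and body here are an illustrative stand-in for the internal `_get_device_handle`, not its actual implementation):
```
import torch

def get_device_handle(device_type: str):
    # Resolve the per-backend device module (torch.cuda, torch.xpu, ...) from a
    # device type string instead of hard-coding torch.cuda.
    return torch.cuda if device_type == "cuda" else getattr(torch, device_type, None)

handle = get_device_handle("cuda")
if handle is not None and handle.is_available():
    stream = handle.Stream()  # backend-specific stream
    event = handle.Event()    # backend-specific event
    event.record(stream)      # streams/events no longer assume CUDA
```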
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136843
Approved by: https://github.com/awgu
This PR relaxes the even sharding requirement for the all-gather extensions.
`fsdp_pre_all_gather` now expects the following signature:
```diff
 def fsdp_pre_all_gather(
     self,
     mesh: DeviceMesh,
+    outer_size: torch.Size,
+    outer_stride: Tuple[int, ...],
     module: nn.Module,
     mp_policy: MixedPrecisionPolicy,
 ) -> Tuple[Tuple[torch.Tensor, ...], Any]:
```
- Since no one is using this new signature yet, we should be safe to change it.
- Currently, the `outer_stride` will always be contiguous strides since FSDP2 only supports contiguous strides for now.
- For the uneven sharding case, the user is responsible for returning a padded sharded tensor from `fsdp_pre_all_gather`. This is risky territory because if the user does not do so, then this may manifest as an NCCL timeout, since only the ranks with padding will error out. However, I am not aware of any way around this.
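For illustration, a hedged sketch of what a tensor-subclass extension might do for the uneven case (the class name and padding logic are illustrative; only `fsdp_pre_all_gather` and its signature come from this PR, and the 2.6 import paths are assumed):
```
from typing import Any, Tuple

import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import MixedPrecisionPolicy


class PaddedShardedParam(torch.Tensor):
    # Hypothetical subclass: only fsdp_pre_all_gather and its signature come
    # from this PR; the padding logic below is an illustrative sketch.
    def fsdp_pre_all_gather(
        self,
        mesh: DeviceMesh,
        outer_size: torch.Size,
        outer_stride: Tuple[int, ...],
        module: nn.Module,
        mp_policy: MixedPrecisionPolicy,
    ) -> Tuple[Tuple[torch.Tensor, ...], Any]:
        world_size = mesh.size()
        # Pad the local shard so that shard_numel * world_size covers the
        # unsharded (outer) numel; otherwise only some ranks error out, which
        # can surface as an NCCL timeout.
        padded_numel = -(-outer_size.numel() // world_size)  # ceil division
        local = self.flatten()
        if local.numel() < padded_numel:
            local = torch.nn.functional.pad(local, (0, padded_numel - local.numel()))
        return (local,), None
```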
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137005
Approved by: https://github.com/weifengpy
Since our implementation currently assumes contiguous strides, let us add an explicit check and raise an error at construction time if the parameter is not contiguous.
We can try to support this in the future. Mainly, I want to first learn more about how DTensor support for non-contiguous memory formats works.
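A minimal sketch of the kind of construction-time check this adds (the helper name and error message are illustrative, not the exact code):
```
import torch.nn as nn

def check_contiguous_params(module: nn.Module) -> None:
    for name, param in module.named_parameters():
        if not param.is_contiguous():
            raise NotImplementedError(
                f"fully_shard currently assumes contiguous parameters, but {name} "
                f"has strides {param.stride()} for shape {tuple(param.shape)}"
            )
```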
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137000
Approved by: https://github.com/weifengpy
- Sometimes having access to the `MixedPrecisionPolicy` in the `fsdp_pre_all_gather` is useful. See [here](https://github.com/pytorch/ao/pull/748/files#r1760375325) in the torchao INT8 mixed precision training PR.
- Sometimes having access to the owning `nn.Module` allows for using it for saving state. See [here](https://github.com/pytorch/pytorch/issues/114299#issuecomment-2298692762) for an example.
The major pain point here is how to deal with backward compatibility. For now, we use `inspect.signature` to check whether the user subclass follows the old or the new signature. However, with the new signature, the `param_dtype` in the post-all-gather is redundant, since if the user needs it, they can now save it from the `mp_policy` passed to the pre-all-gather.
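A hedged sketch of such a check (the helper name is hypothetical, and the parameter count assumes the old signature took only `mesh`):
```
import inspect

def uses_new_pre_all_gather_signature(fsdp_pre_all_gather_method) -> bool:
    # Old signature: fsdp_pre_all_gather(self, mesh); the new one additionally
    # takes module and mp_policy. Counting the bound method's parameters
    # distinguishes the two.
    return len(inspect.signature(fsdp_pre_all_gather_method).parameters) > 1
```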
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136129
Approved by: https://github.com/weifengpy
During enablement of Traceable FSDP2 on internal models, sometimes the user only applies torch.compile to some of the FSDP2 instances but not all of them. Such a mixed usage pattern is not supported by compiled autograd, so here we catch it and raise an error so that the user can fix the usage.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135824
Approved by: https://github.com/awgu
When CPU offloading is enabled, if the user loads a GPU state dict, FSDP2 throws a less obvious error during backward:
```
RuntimeError: attempting to assign a gradient with device type 'cpu' to a tensor with device type 'cuda'. Please ensure that the gradient and the tensor are on the same device
```
This PR throws the error more explicitly by specifying which parameters should be moved because of CPU offloading:
```
FSDP parameters should be materialized on cpu when enabling cpu offloading. For example, load cpu state dict or call module.to_empty(device="cpu"). Found following parameters on non-cpu device: ['0.weight']
```
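For reference, a hedged sketch of the expected flow (assumes a process group is already initialized; the 2.6 import path and module are illustrative):
```
import torch.nn as nn
from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard

model = nn.Linear(8, 8, device="meta")
fully_shard(model, offload_policy=CPUOffloadPolicy())
# With CPU offloading, materialize the sharded parameters on CPU (or load a CPU
# state dict); leaving them on a CUDA device triggers the error above.
model.to_empty(device="cpu")
```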
`pytest -s test/distributed/_composable/fsdp/test_fully_shard_state_dict.py -k test_dp_state_dict_cpu_offload`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135156
Approved by: https://github.com/awgu
Using `fsdp.set_` for the unsharded_param in-place update causes difficult-to-debug errors when enabling Traceable FSDP2 on TorchTune models. In this PR, we change it to use `fsdp.copy_`, which fixes the error and also strictly follows eager semantics: if the user explicitly stores an alias of the unsharded_param during execution of the user's module code, that alias will get updated correctly when the unsharded_param is copied into; whereas if we just swap out the unsharded_param's storage via set_, that user-saved alias will not get updated, which is not good.
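A plain-torch illustration of this aliasing difference (the FSDP2 code uses the torch.ops.fsdp variants of these ops, but the semantics are analogous):
```
import torch

unsharded_param = torch.zeros(4)
alias = unsharded_param.view(2, 2)      # user-saved alias of the unsharded param

unsharded_param.copy_(torch.ones(4))    # writes into the existing storage
assert alias.sum().item() == 4.0        # the alias sees the update

unsharded_param.set_(torch.full((4,), 2.0))  # swaps in a *different* storage
assert alias.sum().item() == 4.0        # the alias still points at the old storage
```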
This PR also implements the graph pass to remove the resizes and copy if there is a resize_(full) -> copy_ -> resize_(0) pattern.
------
Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_trace_fsdp_copy_`
- `pytest -rA test/dynamo/test_repros.py::ReproTests::test_partitioner_cse_respects_mutation_boundaries`
- `pytest -rA test/dynamo/test_repros.py::ReproTests::test_fsdp_set_input_mutation_applied_when_input_gets_no_gradients`
- `pytest -rA test/inductor/test_pattern_matcher.py::TestPatternMatcher::test_mutation_op_matching`
- `python test/inductor/test_distributed_patterns.py DistributedPatternTests.test_fake_distributed_aot_eager`
- `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=1 PYTORCH_TEST_WITH_CROSSREF=1 python test/functorch/test_aotdispatch.py TestEagerFusionOpInfoCPU.test_aot_autograd_exhaustive_norm_cpu_float32`
- `python test/distributed/test_inductor_collectives.py TestCollectivesInductor.test_backwards`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133730
Approved by: https://github.com/bdhirsh
This PR adds a private API `_set_unshard_async_op` that allows for running pre-forward and pre-backward all-gathers using the `async_op=True` path so that all-gather allocations happen in the default stream to avoid inter-stream fragmentation.
If using this option, forward requires explicit prefetching, e.g. via the `unshard(async_op=True)` API, for overlap. fp32 -> bf16 casts and the all-gather copy-in will not overlap with compute.
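A hedged sketch of what such explicit prefetching could look like (assumes `layers` is a list of `fully_shard`-applied submodules and `x` is the input activation):
```
# Prefetch the next layer's all-gather while computing the current layer.
handle = layers[0].unshard(async_op=True)
for i, layer in enumerate(layers):
    handle.wait()                                      # ensure layer i is unsharded
    if i + 1 < len(layers):
        handle = layers[i + 1].unshard(async_op=True)  # prefetch layer i + 1
    x = layer(x)
```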
Differential Revision: [D62401551](https://our.internmc.facebook.com/intern/diff/D62401551)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135523
Approved by: https://github.com/weifengpy
reland of https://github.com/pytorch/pytorch/pull/133113
I had to create a new PR because the previously reverted PR could not be rebased or imported successfully :(
----
Moving DTensor into the public namespace, to formally add the documentation page that includes all the public APIs. This includes:
* many path renames and path import fixes
* a dedicated doc page without too much content yet (more to be added in the next PRs)
* To preserve BC for users still using `torch.distributed._tensor`, I added a shim script to redirect old path calls to the new module
BC preservation is evidenced by the fact that all DTensor tests still pass without changing the public imports, so it is safe to land the changes.
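A quick (hedged) illustration of the intended behavior after the move:
```
# New public import path:
from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor

# The old private path still resolves via the BC shim:
from torch.distributed._tensor import DTensor as _LegacyDTensor

assert DTensor is _LegacyDTensor
```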
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134203
Approved by: https://github.com/tianyu-l
Before this PR, when Traceable FSDP2 + AC (activation checkpointing) is run, an error would be thrown:
```
File "/data/users/willfeng/pytorch/torch/_dynamo/variables/builtin.py", line 1449, in call_getitem
return args[0].call_method(tx, "__getitem__", args[1:], kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/willfeng/pytorch/torch/_dynamo/variables/lists.py", line 435, in call_method
return super().call_method(tx, name, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/willfeng/pytorch/torch/_dynamo/variables/lists.py", line 392, in call_method
return super().call_method(tx, name, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/willfeng/pytorch/torch/_dynamo/variables/lists.py", line 131, in call_method
return self.getitem_const(tx, value)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/willfeng/pytorch/torch/_dynamo/variables/lists.py", line 106, in getitem_const
return self.items[index]
Error: Index out of bound
from user code:
File "<eval_with_key>.5", line 105, in forward
aot0_trace_wrapped = torch__dynamo__trace_wrapped_higher_order_op_self_invoke(aot0_tangents_1, bw_state = aot0_primals_34); aot0_tangents_1 = None
File "/data/users/willfeng/pytorch/torch/_dynamo/_trace_wrapped_higher_order_op.py", line 74, in self_invoke
return _trace_wrapped_op(*args, **dyn_kwargs, **kwargs)
File "/data/users/willfeng/pytorch/torch/_dynamo/external_utils.py", line 132, in call_hook_from_backward_state
return getattr(bw_state, hook_name)(*args, **kwargs)
File "/data/users/willfeng/pytorch/torch/distributed/_composable/fsdp/_fsdp_state.py", line 271, in _pre_backward
self._fsdp_param_group.pre_backward(default_prefetch)
File "/data/users/willfeng/pytorch/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 332, in pre_backward
self._backward_prefetch()
File "/data/users/willfeng/pytorch/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 417, in _backward_prefetch
target_fsdp_param_group = self.comm_ctx.post_forward_order[target_index]
```
Since it's okay to rely on the compiler to recover the "prefetching" pattern, we will skip this `_backward_prefetch()` code path during tracing to avoid the error, and add a compiler pass (in a future PR) to achieve the equivalent prefetching overlap.
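A hedged sketch of the kind of gating this implies (the function shape and the use of the public `torch.compiler.is_compiling()` API are illustrative; the actual FSDP2 code may gate this differently):
```
import torch

def maybe_backward_prefetch(fsdp_param_group, default_prefetch: bool) -> None:
    # Under compile, skip the eager prefetch and rely on a (future) compiler
    # pass to recover the overlap; in eager, prefetch as before.
    if default_prefetch and not torch.compiler.is_compiling():
        fsdp_param_group._backward_prefetch()
```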
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135163
Approved by: https://github.com/awgu