Commit Graph

232 Commits

Author SHA1 Message Date
Will Feng
5a2be192d1 [Traceable FSDP2] Don't register RegisterPostBackwardFunction if user intends to use Traceable FSDP2, and assert that compiled autograd is not used when entering RegisterPostBackwardFunction (#135824)
During enablement of Traceable FSDP2 on internal models, sometimes the user applies torch.compile to only some of the FSDP2 instances but not all of them. Such a mixed usage pattern is not supported by compiled autograd. Here we try to catch this pattern and throw an error so that the user can fix the usage.
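
A hedged sketch of the kind of mixed usage that now fails fast (the toy model, the per-block `torch.compile` calls, and the assumption of an already-initialized process group are illustrative, not from the PR):

```python
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

# Assumes torch.distributed is already initialized (e.g. via torchrun).
class Block(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin(x)

model = nn.Sequential(*[Block() for _ in range(4)])
for block in model:
    fully_shard(block)
fully_shard(model)

# Mixed usage: only some FSDP2 instances get torch.compile applied. Running
# backward under compiled autograd with this mix is unsupported; after this PR
# it raises a clear error instead of failing inside RegisterPostBackwardFunction.
model[0] = torch.compile(model[0])
```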

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135824
Approved by: https://github.com/awgu
2024-09-14 06:30:12 +00:00
Wei Feng
d270e2d240 [FSDP2] better error msg for cpu offloading (#135156)
When CPU offloading is enabled and the user loads a GPU state dict, FSDP2 throws a less obvious error during backward:
```
RuntimeError: attempting to assign a gradient with device type 'cpu' to a tensor with device type 'cuda'. Please ensure that the gradient and the tensor are on the same device
```

This PR throws the error more explicitly by specifying which parameters should be moved to CPU because of CPU offloading:

```
FSDP parameters should be materialized on cpu when enabling cpu offloading. For example, load cpu state dict or call module.to_empty(device="cpu"). Found following parameters on non-cpu device: ['0.weight']
```
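
A hedged usage sketch of the expected setup (the `CPUOffloadPolicy` spelling, the toy model, and the assumption of an initialized process group are not from this PR):

```python
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard

# Assumes torch.distributed is already initialized (e.g. via torchrun).
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
fully_shard(model, offload_policy=CPUOffloadPolicy())

# With CPU offloading enabled, parameters must be materialized on CPU:
model.to_empty(device="cpu")
# ...or load a CPU state dict; loading a GPU state dict would now trigger the
# more explicit error shown above.
```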

`pytest -s test/distributed/_composable/fsdp/test_fully_shard_state_dict.py -k test_dp_state_dict_cpu_offload`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135156
Approved by: https://github.com/awgu
2024-09-12 00:05:07 +00:00
Will Feng
94d2471d1f [Traceable FSDP2] Use .copy_ instead of .set_ for unsharded_param inplace update; Replace unsharded_param graph input usage with graph intermediate; Support FSDP2+LoRA (#133730)
Using `fsdp.set_` for the unsharded_param in-place update causes difficult-to-debug errors when enabling Traceable FSDP2 on TorchTune models. In this PR, we change it to use `fsdp.copy_`, which fixes the error and also strictly follows eager semantics: if the user explicitly stores an alias of the unsharded_param while the module code runs, that alias gets updated correctly when the unsharded_param is copied into via `copy_`, whereas if we just swap out the unsharded_param's storage via `set_`, that user-saved alias does not get updated, which is not good.
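
A minimal plain-tensor illustration (not the FSDP internals) of the aliasing difference between `copy_` and `set_`:

```python
import torch

# copy_ writes into the existing storage, so a user-held view stays in sync.
param = torch.zeros(4)
alias = param.view(-1)            # user code saved an alias of the "unsharded param"
param.copy_(torch.ones(4))
assert torch.equal(alias, torch.ones(4))

# set_ swaps the storage out from under the alias, so the alias goes stale.
param = torch.zeros(4)
alias = param.view(-1)
param.set_(torch.full((4,), 2.0))
assert torch.equal(alias, torch.zeros(4))  # the view still points at the old storage
```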

This PR also implements a graph pass to remove the resizes and copy when there is a resize_(full) -> copy_ -> resize_(0) pattern.

------

Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_trace_fsdp_copy_`
- `pytest -rA test/dynamo/test_repros.py::ReproTests::test_partitioner_cse_respects_mutation_boundaries`
- `pytest -rA test/dynamo/test_repros.py::ReproTests::test_fsdp_set_input_mutation_applied_when_input_gets_no_gradients`
- `pytest -rA test/inductor/test_pattern_matcher.py::TestPatternMatcher::test_mutation_op_matching`
- `python test/inductor/test_distributed_patterns.py DistributedPatternTests.test_fake_distributed_aot_eager`
- `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=1 PYTORCH_TEST_WITH_CROSSREF=1 python test/functorch/test_aotdispatch.py TestEagerFusionOpInfoCPU.test_aot_autograd_exhaustive_norm_cpu_float32`
- `python test/distributed/test_inductor_collectives.py TestCollectivesInductor.test_backwards`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133730
Approved by: https://github.com/bdhirsh
2024-09-11 23:01:05 +00:00
Andrew Gu
c932b39739 [FSDP2] Added _set_unshard_async_op (#135523)
This PR adds a private API `_set_unshard_async_op` that allows for running pre-forward and pre-backward all-gathers using the `async_op=True` path so that all-gather allocations happen in the default stream to avoid inter-stream fragmentation.

If using this option, forward requires explicit prefetching, e.g. via the `unshard(async_op=True)` API, to get overlap. The fp32 -> bf16 casts and the all-gather copy-in will not overlap with compute.
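
A hedged sketch of how this might be combined with explicit prefetching (`blocks` and `x` come from surrounding user code and are assumptions, not part of this PR):

```python
# `blocks` is assumed to be a list of modules already wrapped with fully_shard,
# and `x` an input tensor.
for block in blocks:
    block._set_unshard_async_op(True)  # all-gather allocations stay in the default stream

# With async-op unsharding, overlap needs explicit prefetching: unshard the
# next block while the current one runs forward.
blocks[0].unshard(async_op=True)
for i, block in enumerate(blocks):
    if i + 1 < len(blocks):
        blocks[i + 1].unshard(async_op=True)
    x = block(x)
```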

Differential Revision: [D62401551](https://our.internmc.facebook.com/intern/diff/D62401551)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135523
Approved by: https://github.com/weifengpy
2024-09-10 19:28:02 +00:00
Wanchao Liang
cfc227ad43 [reland][dtensor] move DTensor to public namespace (#134203)
reland of https://github.com/pytorch/pytorch/pull/133113

I had to create a new PR because the previously reverted PR could neither be rebased nor imported successfully :(

----

Moving DTensor into the public namespace, to formally add the documentation page that includes all the public APIs. This includes:

* many path renames and path import fixes
* a dedicated doc page without much content yet (more to be added in the next PRs)
* a shim that redirects old `torch.distributed._tensor` imports to the new module, to preserve BC for users still on the old path

BC preservation is evidenced by the fact that all DTensor tests still pass without changing the public imports, so it is safe to land the changes.
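
A small sketch of the import-level BC (assuming the new public path is `torch.distributed.tensor` and that the shim re-exports the same objects):

```python
# Old private path, still working through the shim:
from torch.distributed._tensor import DTensor as LegacyDTensor

# New public path after this PR:
from torch.distributed.tensor import DTensor

assert DTensor is LegacyDTensor  # the shim redirects to the same class
```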

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134203
Approved by: https://github.com/tianyu-l
2024-09-08 17:08:40 +00:00
Will Feng
8fb1281db9 [Traceable FSDP2] Skip _backward_prefetch under compile, and rely on compiler pass to have prefetching (#135163)
Before this PR, running Traceable FSDP2 + activation checkpointing (AC) would throw an error:
```
  File "/data/users/willfeng/pytorch/torch/_dynamo/variables/builtin.py", line 1449, in call_getitem
    return args[0].call_method(tx, "__getitem__", args[1:], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/willfeng/pytorch/torch/_dynamo/variables/lists.py", line 435, in call_method
    return super().call_method(tx, name, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/willfeng/pytorch/torch/_dynamo/variables/lists.py", line 392, in call_method
    return super().call_method(tx, name, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/willfeng/pytorch/torch/_dynamo/variables/lists.py", line 131, in call_method
    return self.getitem_const(tx, value)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/willfeng/pytorch/torch/_dynamo/variables/lists.py", line 106, in getitem_const
    return self.items[index]
Error: Index out of bound

from user code:
   File "<eval_with_key>.5", line 105, in forward
    aot0_trace_wrapped = torch__dynamo__trace_wrapped_higher_order_op_self_invoke(aot0_tangents_1, bw_state = aot0_primals_34);  aot0_tangents_1 = None
  File "/data/users/willfeng/pytorch/torch/_dynamo/_trace_wrapped_higher_order_op.py", line 74, in self_invoke
    return _trace_wrapped_op(*args, **dyn_kwargs, **kwargs)
  File "/data/users/willfeng/pytorch/torch/_dynamo/external_utils.py", line 132, in call_hook_from_backward_state
    return getattr(bw_state, hook_name)(*args, **kwargs)
  File "/data/users/willfeng/pytorch/torch/distributed/_composable/fsdp/_fsdp_state.py", line 271, in _pre_backward
    self._fsdp_param_group.pre_backward(default_prefetch)
  File "/data/users/willfeng/pytorch/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 332, in pre_backward
    self._backward_prefetch()
  File "/data/users/willfeng/pytorch/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 417, in _backward_prefetch
    target_fsdp_param_group = self.comm_ctx.post_forward_order[target_index]
```

Since it is okay to rely on the compiler to recover the "prefetching" pattern, we skip this `_backward_prefetch()` code path during tracing to avoid the error, and will add a compiler pass (in a future PR) to achieve the equivalent prefetching overlap.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135163
Approved by: https://github.com/awgu
2024-09-05 03:32:04 +00:00
CK Luk
ffd1e214df Back out "[FSDP2] Set ctx.set_materialize_grads(False) for post-backward (#133498)" (#135059)
Summary:
Original commit changeset: 96513cbc425f

Original Phabricator Diff: D61291210

There is some evidence that FB-FM-v4 has better NE with `ctx.set_materialize_grads(False)` set, especially when paired with prefetching.

See https://www.internalfb.com/intern/anp/view/?id=5732259

Test Plan:
```bash
export NUM_WORKERS=128
export BATCH_SIZE=1024
export CONFIG_FILE="mast_joint_arch_exploration_cmf_updated_fbfm_v3_fsdp2.yaml"
export ENTITLEMENT=ads_global_tc_2k_training_large_short

buck2 run mode/opt //aps_models/ads/icvr:icvr_launcher -c fbcode.platform010_cuda_version=12 -c hpc_comms.use_nccl=2.17.1 -- mode=${CONFIG_FILE} launcher.tags='[ads_ranking_taxonomy_monetization_genai]' launcher.data_project=pytorch_at_scale launcher.max_retries=10 launcher.fbl_entitlement=${ENTITLEMENT} launcher.oncall=pytorch_training_enablement launcher.hardware=GRANDTETON launcher.num_workers=${NUM_WORKERS} data_loader.dataset.batch_size=${BATCH_SIZE} training.planner.proposer=dynamic_col_dim training.planner.proposer.optim_target=hbm 2>&1 | tee ~/tmp/log.mast
```

Differential Revision: D62009163

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135059
Approved by: https://github.com/awgu
2024-09-04 04:50:32 +00:00
Will Feng
578b8d75e5 [2nd try][Traceable FSDP2] Allow tracing through FSDP2 impl in trace_rules.py (#134539)
The previous PR https://github.com/pytorch/pytorch/pull/133532 caused stuck compilation issue on internal models. In this 2nd attempt PR, we gate the trace_rules.py changes with `if not torch._dynamo.config.skip_fsdp_hooks:`, so that they don't take effect for current graph-break FSDP2 (which relies on the default config value `skip_fsdp_hooks=True`), and will only take effect when we are using Traceable FSDP2 (in which case the user needs to proactively set `skip_fsdp_hooks=False`).
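
A hedged sketch of how the gating is exercised from the user side (the stand-in module is illustrative):

```python
import torch
import torch._dynamo
import torch.nn as nn

model = nn.Linear(8, 8)  # stands in for a module already wrapped with fully_shard

# Graph-break FSDP2 (current default): Dynamo skips FSDP hooks, so the
# trace_rules.py changes in this PR stay inert.
assert torch._dynamo.config.skip_fsdp_hooks

# Traceable FSDP2: opt in explicitly before compiling.
torch._dynamo.config.skip_fsdp_hooks = False
compiled_model = torch.compile(model)
```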

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134539
Approved by: https://github.com/ckluk2, https://github.com/yanboliang
2024-08-29 06:28:16 +00:00
PyTorch MergeBot
25531eb735 Revert "[2nd try][Traceable FSDP2] Allow tracing through FSDP2 impl in trace_rules.py (#134539)"
This reverts commit 26e392132d.

Reverted https://github.com/pytorch/pytorch/pull/134539 on behalf of https://github.com/ZainRizvi due to Sorry, I had to revert this in order to revert another PR ([comment](https://github.com/pytorch/pytorch/pull/134539#issuecomment-2316568257))
2024-08-29 01:59:02 +00:00
Will Feng
26e392132d [2nd try][Traceable FSDP2] Allow tracing through FSDP2 impl in trace_rules.py (#134539)
The previous PR https://github.com/pytorch/pytorch/pull/133532 caused stuck compilation issue on internal models. In this 2nd attempt PR, we gate the trace_rules.py changes with `if not torch._dynamo.config.skip_fsdp_hooks:`, so that they don't take effect for current graph-break FSDP2 (which relies on the default config value `skip_fsdp_hooks=True`), and will only take effect when we are using Traceable FSDP2 (in which case the user needs to proactively set `skip_fsdp_hooks=False`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134539
Approved by: https://github.com/ckluk2, https://github.com/yanboliang
2024-08-28 08:57:56 +00:00
CK Luk
43bbd781f2 Back out "[Traceable FSDPS] Allow tracing through FSDP2 impl in trace_rules.py (#133532)" (#134478)
Summary:
Original commit changeset: 0215a41433e9

Original Phabricator Diff: D61432583

D61432583 causes FSDP2 to get stuck in PT2 compilation when applied to FB-FM-v4.

With D61432583:
https://www.internalfb.com/mast/job/aps-ckluk-745e763d6a

After backing out D61432583:
https://www.internalfb.com/mast/job/aps-ckluk-f9604ea1f9

Test Plan:
hg graft D61774888
scripts/ckluk/aps/mast_joint_arch_exploration_cmf_updated_fbfm_v3_fsdp2_qps.sh

Differential Revision: D61802689

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134478
Approved by: https://github.com/yf225
2024-08-27 00:07:28 +00:00
Will Feng
6bddfb9546 [FSDP2] Add cache for FSDP wrapper class (#134135)
Currently, `fully_shard` will create a new `FSDPMyModuleClass` class for each `MyModuleClass` module **object**, which causes Dynamo to guard-fail on every module object's type check. This PR fixes the issue by caching and reusing the previously created FSDP wrapper class.
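
A minimal sketch (not the actual implementation) of the caching pattern, with `fsdp_mixin` standing in for the FSDP2 wrapper mixin:

```python
_fsdp_cls_cache: dict[type, type] = {}

def _get_wrapper_cls(orig_cls: type, fsdp_mixin: type) -> type:
    # Reuse one generated class per original module class so that type(module)
    # stays stable across fully_shard calls and Dynamo's type-based guards hit.
    cached = _fsdp_cls_cache.get(orig_cls)
    if cached is None:
        cached = type(f"FSDP{orig_cls.__name__}", (fsdp_mixin, orig_cls), {})
        _fsdp_cls_cache[orig_cls] = cached
    return cached
```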

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134135
Approved by: https://github.com/awgu
2024-08-22 00:41:30 +00:00
PyTorch MergeBot
35f36363ec Revert "[dtensor] move DTensor to public namespace (#133113)"
This reverts commit 2ee6b97464.

Reverted https://github.com/pytorch/pytorch/pull/133113 on behalf of https://github.com/wanchaol due to looks like it break some internal type imports ([comment](https://github.com/pytorch/pytorch/pull/133113#issuecomment-2295670911))
2024-08-19 05:00:19 +00:00
Wanchao Liang
2ee6b97464 [dtensor] move DTensor to public namespace (#133113)
Moving DTensor into the public namespace, to formally add the documentation page that includes all the public APIs. This includes:

* many path renames and path import fixes
* a dedicated doc page without much content yet (more to be added in the next PRs)
* a shim that redirects old `torch.distributed._tensor` imports to the new module, to preserve BC for users still on the old path

BC preservation is evidenced by the fact that all DTensor tests still pass without changing the public imports, so it is safe to land the changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133113
Approved by: https://github.com/XilunWu
ghstack dependencies: #133305, #133306
2024-08-17 05:09:52 +00:00
Will Feng
1df1d00ffc [Traceable FSDP2] Remove usage of tuple() generator and simplify code (#133636)
Dynamo doesn't support passing a generator to `tuple()`, and this change also simplifies the code a bit.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133636
Approved by: https://github.com/awgu
ghstack dependencies: #133532, #133531
2024-08-16 17:47:28 +00:00
Will Feng
6d85077168 [Traceable FSDPS] Allow tracing through FSDP2 impl in trace_rules.py (#133532)
Test commands:
- `python test/distributed/_composable/fsdp/test_fully_shard_training.py TestFullyShard1DTrainingCompose.test_train_parity_with_activation_checkpointing`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133532
Approved by: https://github.com/yanboliang
2024-08-16 17:13:47 +00:00
Andrew Gu
86aa327e4a [FSDP2] Added eager fast-path for fp32->bf16 param cast (#133369)
Some recommendation models have a high number of `nn.Parameter`s. This exacerbates per-tensor CPU overheads in FSDP2 compared to FSDP1.

This PR adds a fast path for the common bf16/fp32 mixed-precision case, casting the parameters from fp32 to bf16 with less CPU overhead and a possibly more efficient copy (a rough sketch follows the list below).
- Old: `for` loop + `.to(torch.bfloat16)`, incurring dispatcher overhead per parameter
- New: `torch.empty` + `torch.split` + `torch._foreach_copy_`, incurring three dispatches
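
A rough sketch of the two approaches on plain tensors (shapes are illustrative):

```python
import torch

params = [torch.randn(64) for _ in range(8)]  # fp32 "parameters"

# Old: one dispatcher call per parameter.
bf16_old = [p.to(torch.bfloat16) for p in params]

# New: one allocation, one split, one foreach copy (the copy casts fp32 -> bf16).
numels = [p.numel() for p in params]
flat = torch.empty(sum(numels), dtype=torch.bfloat16)
bf16_new = list(torch.split(flat, numels))
torch._foreach_copy_(bf16_new, params)

assert all(torch.equal(a, b) for a, b in zip(bf16_old, bf16_new))
```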

---

Example on Llama3-8B which does not have many `nn.Parameter`s (compared to recommendation models):

(Old) on Llama3-8B (0.46 ms CPU overhead for all-gather):
![Screenshot 2024-08-13 at 6 19 39 PM](https://github.com/user-attachments/assets/e6390e9f-ee54-4208-9d60-9451a4142efa)

(New) on Llama3-8B (0.37 ms CPU overhead for all-gather):
![Screenshot 2024-08-13 at 6 20 32 PM](https://github.com/user-attachments/assets/a5dc1d38-53d2-4984-b3cc-85ce5a538ede)

---

Same example as above but now with float8 all-gather:

(Old) on Llama3-8B with float8 (0.996 ms CPU overhead for all-gather):
![Screenshot 2024-08-15 at 11 27 46 AM](https://github.com/user-attachments/assets/2b7e9c9c-56ea-4375-851e-a2a704689d8d)

(New) on Llama3-8B with float8 (1.014 ms CPU overhead for all-gather):
![Screenshot 2024-08-15 at 11 26 33 AM](https://github.com/user-attachments/assets/160cf8f6-bb97-4633-b802-baeae74e3262)

The times are relatively comparable for float8, with the new path possibly slightly slower, but this is mainly because Llama's transformer blocks have only two norm weights that need to be cast to bf16. These screenshots mainly show that the optimization still works in the mixed case.

Differential Revision: [D61236983](https://our.internmc.facebook.com/intern/diff/D61236983)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133369
Approved by: https://github.com/weifengpy
ghstack dependencies: #133498
2024-08-15 22:27:20 +00:00
Andrew Gu
59e33cd1f4 [FSDP2] Set ctx.set_materialize_grads(False) for post-backward (#133498)
https://pytorch.org/docs/stable/generated/torch.autograd.function.FunctionCtx.set_materialize_grads.html
This avoids unnecessary `aten::zeros` calls for the inputs in the post-backward custom autograd function's backward. We do not need the gradient values for the post-backward logic.
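
A minimal sketch of the effect on a custom autograd function (not the actual FSDP2 hook):

```python
import torch

class _PostBackwardHook(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.set_materialize_grads(False)
        return x.clone()

    @staticmethod
    def backward(ctx, grad_out):
        # With materialization off, grad_out can be None rather than a freshly
        # allocated aten::zeros tensor when the output received no gradient;
        # the post-backward logic only needs the hook to fire, not the values.
        return grad_out
```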

Differential Revision: [D61291210](https://our.internmc.facebook.com/intern/diff/D61291210)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133498
Approved by: https://github.com/weifengpy
2024-08-15 14:58:26 +00:00
Chien-Chin Huang
d114fd78bd [FSDP2] Enable HSDP + TP (#133335)
This PR enables HSDP + TP

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133335
Approved by: https://github.com/awgu
2024-08-14 16:34:04 +00:00
Chien-Chin Huang
d53dfa4680 [BE] Raise when the target model has scalar parameters (#132934)
Addresses https://github.com/pytorch/pytorch/issues/130810.

Neither FSDP1 nor FSDP2 supports scalar parameters. For FSDP1, the issue happens during state_dict operations, while FSDP2 fails during initialization. This PR adds exceptions to help users debug the issue and change the scalar parameters to 1D parameters.
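
For illustration, a sketch of the workaround the new error points users toward:

```python
import torch
import torch.nn as nn

class Scale(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # 0-dim (scalar) parameter: neither FSDP1 nor FSDP2 can shard this.
        # self.alpha = nn.Parameter(torch.tensor(1.0))
        # 1-D replacement that both can handle:
        self.alpha = nn.Parameter(torch.tensor([1.0]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.alpha * x
```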

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132934
Approved by: https://github.com/awgu, https://github.com/wz337
2024-08-12 18:28:02 +00:00
Wei Feng
cd307fb0b1 [FSDP2] reset FSDPParam.sharded_param in lazy_init (#132954)
Motivated by FSDP2 + DoRA: https://github.com/pytorch/pytorch/issues/132721

After meta init, we need a user-defined function to move `DoRALinear.magnitude` from device=meta to device=cuda. The problem is how to trigger `reset_sharded_param` or `_apply` to update `FSDPParam`; otherwise lazy_init complains that `DoRALinear.magnitude` is still on device=meta.

Credit to @awgu for chasing down DDP lazy_init to unblock the PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132954
Approved by: https://github.com/awgu
ghstack dependencies: #133059
2024-08-09 20:26:10 +00:00
Wei Feng
472b0daeaa [DDP][FSDP2] keep DTensor params for replicate(fully_shard) (#133059)
Current status: for `replicate(fully_shard)`, DDP lazy_init converts DTensor parameters into local tensors, which breaks FSDP unshard.

This PR keeps FSDP params untouched during DDP lazy_init.
I came across this because of a CI error in FSDP2's unit test #132978.
Thanks @awgu for the fix proposal.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133059
Approved by: https://github.com/Skylion007, https://github.com/fegin
2024-08-09 18:38:05 +00:00
PyTorch MergeBot
50595ecef4 Revert "[BE] Raise when the target model has scalar parameters (#132934)"
This reverts commit ea00036841.

Reverted https://github.com/pytorch/pytorch/pull/132934 on behalf of https://github.com/clee2000 due to I think this broke distributed/_composable/fsdp/test_fully_shard_init.py::TestFullyShardShardedParameterTensor::test_raise_scalar_parameter [GH job link](https://github.com/pytorch/pytorch/actions/runs/10314920655/job/28563430905) [HUD commit link](ea00036841).  Dr CI is wrong, it is not flaky ([comment](https://github.com/pytorch/pytorch/pull/132934#issuecomment-2278208789))
2024-08-09 15:30:34 +00:00
Chien-Chin Huang
ea00036841 [BE] Raise when the target model has scalar parameters (#132934)
Addresses https://github.com/pytorch/pytorch/issues/130810.

Neither FSDP1 nor FSDP2 supports scalar parameters. For FSDP1, the issue happens during state_dict operations, while FSDP2 fails during initialization. This PR adds exceptions to help users debug the issue and change the scalar parameters to 1D parameters.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132934
Approved by: https://github.com/awgu
ghstack dependencies: #132908, #132933
2024-08-09 06:45:48 +00:00
wz337
0ff0bf3d31 [Replicate] Fix replicate with DeviceMesh initialization (#133024)
A follow up on https://github.com/pytorch/pytorch/pull/132339.

`get_parent_mesh` is replaced by `get_root_mesh`. In addition, this modifies a few places in the tests where the parent mesh is mentioned.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133024
Approved by: https://github.com/Skylion007, https://github.com/fegin
2024-08-09 00:45:47 +00:00
Xilun Wu
322c9d03a0 [FSDP][dtensor] use _StridedShard to represent nested sharding for correct full_tensor() result (#130760)
Fixes issue #129229 #129206
**Summary**

1. Have `FSDP` choose the `_StridedShard` placement for FSDP+TP sharding
2. Added a parity test to FSDP to ensure that FSDP+TP sharding (i.e. strided) and plain TP sharding (i.e. non-strided) have the same `full_tensor()` result
3. Re-enabled the tests that were disabled in #129519

**test**
`pytest test/distributed/_composable/fsdp/`
`pytest test/distributed/_composable/test_composability/test_2d_composability.py`
`pytest test/distributed/checkpoint/fsdp/test_fsdp_dsd.py`

Differential Revision: [D60606114](https://our.internmc.facebook.com/intern/diff/D60606114)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130760
Approved by: https://github.com/wanchaol, https://github.com/fegin, https://github.com/wz337
ghstack dependencies: #126697, #130239, #132391, #131408
2024-08-08 18:15:29 +00:00
Xilun Wu
40ce0a53bb [FSDP][dtensor] add FSDP2+TP distributed state dict test (#131408)
**Test**
`pytest test/distributed/_composable/fsdp/test_fully_shard_training.py`
`pytest test/distributed/_composable/fsdp/test_fully_shard_state_dict.py`
`pytest test/distributed/checkpoint/fsdp/test_fsdp_dsd.py`
`pytest test/distributed/_composable/fsdp/test_fully_shard_init.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131408
Approved by: https://github.com/fegin
ghstack dependencies: #126697, #130239, #132391
2024-08-07 18:17:12 +00:00
wz337
87053132ea [DeviceMesh] Remove parent mesh concept from _MeshEnv and replace by root mesh (#132339)
Previously, when we slice out a submesh from a mesh, we assign the mesh as the parent mesh of the submesh. In this case, when we have a 3D mesh topology, the parent mesh of a 1D mesh sliced out from the 3D mesh is different from the parent mesh of the same 1D mesh sliced out from the 2D submesh of the 3D mesh. For example:
```
mesh_3d = init_device_mesh("cuda", (2,2,2), ("dim0", "dim1", "dim2"))
mesh_dim0 = mesh_3d["dim0"]

mesh_2d = mesh_3d["dim0", "dim1"]
mesh_dim0_2 = mesh_2d["dim0"]

# This would evaluate to True
print(_mesh_resources.get_parent_mesh(mesh_dim0) != _mesh_resources.get_parent_mesh(mesh_dim0_2))
```

We can always reconstruct the mesh needed from the mesh dim names, as long as the dims come from the same root. For simplicity, we do not see the necessity of building a tree structure to represent the child-parent relationship. Therefore, we are replacing the parent mesh concept with a root mesh concept in `_MeshEnv`, so we would have:

```
mesh_3d = init_device_mesh("cuda", (2,2,2), ("dim0", "dim1", "dim2"))
mesh_dim0 = mesh_3d["dim0"]

mesh_2d = mesh_3d["dim0", "dim1"]
mesh_dim0_2 = mesh_2d["dim0"]

# This would evaluate to True
print(_mesh_resources.get_root_mesh(mesh_dim0) == _mesh_resources.get_root_mesh(mesh_dim0_2))
```
With this change, we will have two types of meshes in an environment.
1. `device_mesh != _mesh_resources.get_root_mesh(device_mesh)` means that the device_mesh is created by slicing.
2. `device_mesh == _mesh_resources.get_root_mesh(device_mesh)` means that the device_mesh is a root mesh not created through slicing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132339
Approved by: https://github.com/wanchaol
ghstack dependencies: #132310, #132311
2024-08-07 07:01:12 +00:00
PyTorch MergeBot
cbee9c1fd2 Revert "Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)"
This reverts commit 0e7e61f7ce.

Reverted https://github.com/pytorch/pytorch/pull/127690 on behalf of https://github.com/kit1980 due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/127690#issuecomment-2272370386))
2024-08-07 00:05:20 +00:00
Brian Hirsh
af8b8a47cb fsdp.set_: convey to functionalization that it mutates storage (#132322)
Fixes https://github.com/pytorch/pytorch/issues/132197

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132322
Approved by: https://github.com/albanD, https://github.com/yf225
ghstack dependencies: #132243, #132337
2024-08-05 21:28:59 +00:00
Andrew Gu
b30d0916d9 [FSDP2] Added missing event wait (for future) (#132568)
Nothing is actually wrong currently, but we should add this in case we land https://github.com/pytorch/pytorch/pull/127032 in the future.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132568
Approved by: https://github.com/weifengpy, https://github.com/Skylion007
2024-08-05 12:44:46 +00:00
Xuehai Pan
0e7e61f7ce Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)
This PR is split from PR #126898.

- #126898

------
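
A short migration sketch (the replacement API is assumed to be `torch.compiler.is_compiling()`):

```python
import torch

def maybe_log(msg: str) -> None:
    # Deprecated: torch._utils.is_compiling(), torch._dynamo.external_utils.is_compiling()
    # Assumed public replacement:
    if not torch.compiler.is_compiling():
        print(msg)  # side effects only on the eager path
```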

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127690
Approved by: https://github.com/Skylion007, https://github.com/malfet
2024-08-03 09:43:38 +00:00
Oguz Ulgen
72d2dba992 Add None return type to init (#132335)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132335
Approved by: https://github.com/albanD
2024-08-01 15:26:45 +00:00
Xuehai Pan
b25ef91bf1 [BE][Easy][18/19] enforce style for empty lines in import segments in torch/d*/ (#129770)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129770
Approved by: https://github.com/wconstab
2024-08-01 04:22:50 +00:00
Wei Feng
bc7ed1fbdc [FSDP2] add __repr__ to FSDPParamGroup and FSDPParam (#132350)
In pdb, it's pretty common to print `FSDPParamGroup` and `FSDPParam`, so this makes sure they are human-readable.

Printing `FSDPParam` in pdb:
```
FSDPParam(fqn=layers.6._checkpoint_wrapped_module.attention.wq.weight, orig_size=torch.Size([128, 256]))
```
Printing `FSDPParamGroup` in pdb:
```
FSDPParamGroup(fqn=layers.6)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132350
Approved by: https://github.com/awgu
2024-08-01 04:21:57 +00:00
PyTorch MergeBot
a3ba405871 Revert "[BE] typing for decorators - library (#131570)"
This reverts commit 5731b486c8.

Reverted https://github.com/pytorch/pytorch/pull/131570 on behalf of https://github.com/clee2000 due to same as https://github.com/pytorch/pytorch/pull/131572#issuecomment-2254328359 but I clicked the wrong link by accident.  This is where it actually starts ([comment](https://github.com/pytorch/pytorch/pull/131568#issuecomment-2254330781))
2024-07-28 03:43:39 +00:00
Will Feng
236d055330 [Traceable FSDP2] Add partial-graph (graph-break) unit tests (#131747)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131747
Approved by: https://github.com/bdhirsh
2024-07-26 02:51:57 +00:00
Aaron Orenstein
5731b486c8 [BE] typing for decorators - library (#131570)
See #131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131570
Approved by: https://github.com/oulgen, https://github.com/zou3519
ghstack dependencies: #131568, #131569
2024-07-25 22:24:19 +00:00
Andrew Gu
bc938184de [FSDP2] Added set_reduce_scatter_divide_factor (#129286)
This PR adds an API `FSDPModule.set_reduce_scatter_divide_factor` to allow setting a custom gradient divide factor for reduce-scatter. This can be useful when using parallelisms in combination with FSDP (e.g. expert parallelism), where gradients need to be divided by a custom factor (e.g. an extra `EP` factor).
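
A hedged usage sketch (`moe_block`, the expert-parallel degree, and the exact factor are illustrative assumptions that depend on the parallelism layout):

```python
import torch.distributed as dist

# `moe_block` is assumed to already be wrapped with fully_shard.
ep_degree = 4
moe_block.set_reduce_scatter_divide_factor(dist.get_world_size() * ep_degree)
```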

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129286
Approved by: https://github.com/weifengpy
2024-07-24 12:42:35 +00:00
Aaron Orenstein
5a0068cc69 [BE] mypy: disallow untyped decorators (#131428)
Untyped decorators strip the types from the functions they decorate, so even if the underlying function is fully typed, callers get no benefit from its type annotations.
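
A general-Python illustration (not from this PR) of a decorator typed with `ParamSpec`, which preserves the wrapped function's signature for callers; that is exactly what an untyped decorator loses:

```python
import functools
from typing import Callable, ParamSpec, TypeVar

_P = ParamSpec("_P")
_T = TypeVar("_T")

def log_calls(fn: Callable[_P, _T]) -> Callable[_P, _T]:
    @functools.wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        print(f"calling {fn.__name__}")
        return fn(*args, **kwargs)
    return wrapper

@log_calls
def scale(x: float, factor: float = 2.0) -> float:
    return x * factor

# A type checker still sees scale's real signature; with an untyped decorator
# it would degrade to Callable[..., Any].
```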

Step 1 - Enable the error and override in all the offending files.

#131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131428
Approved by: https://github.com/justinchuby, https://github.com/oulgen
2024-07-23 21:50:55 +00:00
Andrew Gu
23ae6e2eb3 [FSDP2] Removed state dict error for HSDP (#131320)
Fixes https://github.com/pytorch/torchtitan/issues/441#issuecomment-2241288906.

This PR avoids raising the 2D state dict error for HSDP, which does not depend on strided sharding.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131320
Approved by: https://github.com/wanchaol, https://github.com/weifengpy
2024-07-22 19:23:17 +00:00
Will Feng
d77af49380 [Traceable FSDP2] Preserve fsdp.set_ op through lowering; Add unit test for multiple .set_ into same primal; Add unit test for FSDP2 module layer reuse (#130786)
Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_fullgraph_backend_inductor`
- `pytest -rA test/functorch/test_aotdispatch.py::TestAOTAutograd::test_input_mutation_fsdp_set__into_same_input`
- `PYTORCH_TEST_WITH_CROSSREF=1 python test/functorch/test_aotdispatch.py -k TestAOTAutogradWithCache.test_input_mutation_fsdp_set__into_same_input`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130786
Approved by: https://github.com/bdhirsh
ghstack dependencies: #129773
2024-07-17 23:25:42 +00:00
Andrew Gu
31e3330040 [Reland][FSDP2] Allowed List[nn.Module] as arg (#130949)
This PR allows `fully_shard`'s first argument to be `List[nn.Module]` instead of strictly `nn.Module`. This allows more flexible grouping of modules/parameters for communication, which can lead to memory savings and/or more efficient communication.

**Approach**
At a high level, we can think of a model as a tree of modules. Previously, we could only select specific module nodes in this tree as representing one FSDP parameter group. With this PR, we can select a group of module nodes, effectively becoming a single super node.

To implement the runtime schedule, we define new forward hooks that run based on the following semantics:
- If a module is the first to run the pre-hook, actually run the given pre-hook. Otherwise, the pre-hook is a no-op.
- If a module is the last to run the post-hook, actually run the given post-hook. Otherwise, the post-hook is a no-op.
- First and last are determined by scoreboarding against a set of the modules.
- This set must be cleared at the end of backward to handle the case where >=1 module in the list never ran forward, since we still want the forward hooks to run in the next forward after this backward.

Beyond these new forward hooks, everything else is some simple generalization from `Module` to `List[Module]` or `Tuple[Module, ...]`.

**Examples**
This PR enables wrapping Llama models more efficiently by grouping the final norm and output linear together: https://github.com/pytorch/torchtitan/pull/382.
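
A hedged usage sketch following that Llama-style grouping (attribute names are illustrative, and a process group is assumed to be initialized):

```python
from torch.distributed._composable.fsdp import fully_shard

# `model` is assumed to be a Llama-like transformer with .layers, .norm, .output.
for layer in model.layers:
    fully_shard(layer)
# Group the final norm and the output projection into one FSDP parameter group:
fully_shard([model.norm, model.output])
fully_shard(model)
```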

If at least one of the modules in the list does not run forward before backward, then there will be a warning message like:
```
1 of the 2 modules passed to fully_shard did not run forward before backward, which is error-prone since FSDP post-forward/pre-backward logic will not run for these modules. We recommend passing only modules that run forward together. Modules that did not run forward: [FSDPLinear(in_features=1, out_features=1, bias=True)]
```

---

**Changes for reland:** none since breakage was from PR below

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130949
Approved by: https://github.com/weifengpy
ghstack dependencies: #130947
2024-07-17 22:40:14 +00:00
Andrew Gu
ff7e021e94 [Reland][PT-D] Relaxed contract to allow Sequence[nn.Module] (#127773) (#130947)
This PR relaxes `@contract` to allow the 1st argument to be `Sequence[nn.Module]` instead of strictly `nn.Module`. This is required for the next PR, which allows `fully_shard` to take in `List[nn.Module]`.

---

**Changes for reland:**
- The previous PR assumed that any `func` decorated with `@contract` would return the same input `module` as output (which is true for PT-D composable APIs).
- However, TorchRec `shard` returns a different module as output (though that module _does_ satisfy the `@contract` FQN check).
- This PR removes the assumption and instead only enforces the FQN check following the input module order. In other words, if calling `func([x1, ..., xN])` for `N` modules `x1, ..., xN` returns `[y1, ..., yM]` for `M` modules, we require that `N = M` and that FQNs are preserved coordinate-wise: `xi` and `yi` have the same FQNs for all `i = 1, ..., N`.

Differential Revision: [D59863438](https://our.internmc.facebook.com/intern/diff/D59863438)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130947
Approved by: https://github.com/weifengpy, https://github.com/atalman
2024-07-17 22:40:13 +00:00
PyTorch MergeBot
d7a8e8f7c5 Revert "[PT-D] Relaxed contract to allow Sequence[nn.Module] (#127773)"
This reverts commit b27695791e.

Reverted https://github.com/pytorch/pytorch/pull/127773 on behalf of https://github.com/atalman due to failing internally ([comment](https://github.com/pytorch/pytorch/pull/127773#issuecomment-2232004006))
2024-07-16 23:48:09 +00:00
PyTorch MergeBot
4f40a7078e Revert "[FSDP2] Allowed List[nn.Module] as arg (#127786)"
This reverts commit d3ab8ceced.

Reverted https://github.com/pytorch/pytorch/pull/127786 on behalf of https://github.com/atalman due to bottom pr from the stack is failing on internal error ([comment](https://github.com/pytorch/pytorch/pull/127786#issuecomment-2231999178))
2024-07-16 23:45:17 +00:00
Andrew Gu
d3ab8ceced [FSDP2] Allowed List[nn.Module] as arg (#127786)
This PR allows `fully_shard`'s first argument to be `List[nn.Module]` instead of strictly `nn.Module`. This allows more flexible grouping of modules/parameters for communication, which can lead to memory savings and/or more efficient communication.

**Approach**
At a high level, we can think of a model as a tree of modules. Previously, we could only select specific module nodes in this tree as representing one FSDP parameter group. With this PR, we can select a group of module nodes, effectively becoming a single super node.

To implement the runtime schedule, we define new forward hooks that run based on the following semantics:
- If a module is the first to run the pre-hook, actually run the given pre-hook. Otherwise, the pre-hook is a no-op.
- If a module is the last to run the post-hook, actually run the given post-hook. Otherwise, the post-hook is a no-op.
- First and last are determined by scoreboarding against a set of the modules.
- This set must be cleared at the end of backward to handle the case where >=1 module in the list never ran forward, since we still want the forward hooks to run in the next forward after this backward.

Beyond these new forward hooks, everything else is some simple generalization from `Module` to `List[Module]` or `Tuple[Module, ...]`.

**Examples**
This PR enables wrapping Llama models more efficiently by grouping the final norm and output linear together: https://github.com/pytorch/torchtitan/pull/382.

If at least one of the modules in the list does not run forward before backward, then there will be a warning message like:
```
1 of the 2 modules passed to fully_shard did not run forward before backward, which is error-prone since FSDP post-forward/pre-backward logic will not run for these modules. We recommend passing only modules that run forward together. Modules that did not run forward: [FSDPLinear(in_features=1, out_features=1, bias=True)]
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127786
Approved by: https://github.com/yf225, https://github.com/weifengpy
ghstack dependencies: #127773
2024-07-15 23:54:10 +00:00
Andrew Gu
b27695791e [PT-D] Relaxed contract to allow Sequence[nn.Module] (#127773)
This PR relaxes `@contract` to allow the 1st argument to be `Sequence[nn.Module]` instead of strictly `nn.Module`. This is required for the next PR, which allows `fully_shard` to take in `List[nn.Module]`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127773
Approved by: https://github.com/weifengpy
2024-07-15 23:54:10 +00:00
Aaron Orenstein
567482973d typing fake_tensor.py (#128041)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128041
Approved by: https://github.com/eellison
ghstack dependencies: #129182
2024-07-13 06:07:40 +00:00
Xuehai Pan
973037be6a [BE][Easy] apply autofix for ruff rules unnecessary-collection-call (C408): list() / tuple() / dict() (#130199)
This PR changes the empty collection factory call to Python literals:

- `list()` -> `[]`
- `tuple()` -> `()`
- `dict()` -> `{}`

The Python literals are more performant and safer. For example, the bytecode for building an empty dictionary:

```bash
$ python3 -m dis - <<EOS
import collections

d1 = {}
d2 = dict()

dict = collections.OrderedDict
d3 = dict()
EOS
```

```text
  0           0 RESUME                   0

  1           2 LOAD_CONST               0 (0)
              4 LOAD_CONST               1 (None)
              6 IMPORT_NAME              0 (collections)
              8 STORE_NAME               0 (collections)

  3          10 BUILD_MAP                0
             12 STORE_NAME               1 (d1)

  4          14 PUSH_NULL
             16 LOAD_NAME                2 (dict)
             18 CALL                     0
             26 STORE_NAME               3 (d2)

  6          28 LOAD_NAME                0 (collections)
             30 LOAD_ATTR                8 (OrderedDict)
             50 STORE_NAME               2 (dict)

  7          52 PUSH_NULL
             54 LOAD_NAME                2 (dict)
             56 CALL                     0
             64 STORE_NAME               5 (d3)
             66 RETURN_CONST             1 (None)
```

The dict literal `{}` only has one bytecode `BUILD_MAP`, while the factory call `dict()` has three `PUSH_NULL + LOAD_NAME + CALL`. Also, the factory call is not safe if users override the `dict` name in `locals` or `globals` (see the example of replacing with `OrderedDict` above).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130199
Approved by: https://github.com/malfet
2024-07-11 17:30:28 +00:00