Commit Graph

289 Commits

Author SHA1 Message Date
Maggie Moss
8f80892359 Use correct pyrefly syntax in suppressions distributed/... (#166241)
Updates the pyrefy-ignores in the torch/distributed directory to use the correct syntax. No functional changes.

pyrefly check
lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166241
Approved by: https://github.com/oulgen
2025-10-26 04:16:41 +00:00
fduwjj
7406d2e665 [DeviceMesh] Clean up the call into mesh_resouces to get root mesh (#165787)
We moved the method to get root mesh into class in https://github.com/pytorch/pytorch/pull/164510. This is to further clean code up.

Differential Revision: [D85090191](https://our.internmc.facebook.com/intern/diff/D85090191)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165787
Approved by: https://github.com/fegin
2025-10-21 02:54:04 +00:00
Yuanyuan Chen
fbe0d20a17 [2/N] More ruff SIM fixes (#165031)
This is follow-up of #164695 to apply ruff SIM rules to more files. Most changes are about simplifying dict.get because None is already the default value.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165031
Approved by: https://github.com/mlazos
2025-10-14 14:22:54 +00:00
PyTorch MergeBot
b8be796a57 Revert "[2/N] More ruff SIM fixes (#165031)"
This reverts commit 38095fbd13.

Reverted https://github.com/pytorch/pytorch/pull/165031 on behalf of https://github.com/albanD due to One of the changed line started to fail on trunk ([comment](https://github.com/pytorch/pytorch/pull/165031#issuecomment-3390190870))
2025-10-10 13:42:14 +00:00
Yuanyuan Chen
38095fbd13 [2/N] More ruff SIM fixes (#165031)
This is follow-up of #164695 to apply ruff SIM rules to more files. Most changes are about simplifying dict.get because None is already the default value.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165031
Approved by: https://github.com/mlazos
2025-10-10 05:37:46 +00:00
Maggie Moss
7457d139c5 Add pyrefly suppressions to torch/distributed (7/n) (#165002)
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

One more PR after this one.

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165002
Approved by: https://github.com/oulgen
2025-10-09 04:08:25 +00:00
PyTorch MergeBot
5d7360bb03 Revert "Enable all SIM rules except disabled ones (#164645)"
This reverts commit 321e602692.

Reverted https://github.com/pytorch/pytorch/pull/164645 on behalf of https://github.com/izaitsevfb due to causes lint failures ([comment](https://github.com/pytorch/pytorch/pull/164645#issuecomment-3369274351))
2025-10-05 19:32:21 +00:00
Yuanyuan Chen
321e602692 Enable all SIM rules except disabled ones (#164645)
`SIM` rules are useful for simplifying boolean expressions and enhances code readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164645
Approved by: https://github.com/ezyang
2025-10-05 07:38:25 +00:00
Anshul Sinha
3dab36bdb4 [FSDP][Replicate] created ReplicateModule and changed replicate to use it instead of FSDPModule (#163897)
**Summary:** In order to minimize the code copied from FSDP to make replicate work, I made all replicated modules FSDPModule. While this was sufficient originally, there are changes to codebase like below that require us to differentiate between a FSDPModule and a ReplicateModule so that we can access replicate_state or fsdp_state: https://www.internalfb.com/code/fbsource/[a9a8e5102052]/fbcode/caffe2/torch/distributed/pipelining/stage.py?lines=629-666.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163897
Approved by: https://github.com/weifengpy
2025-10-01 17:30:10 +00:00
Yuanyuan Chen
da003d7b95 [3/N] Import Callable from collections.abc in torch/distributed (#164104)
This is the result of applying the ruff `UP035` check.
`Callable` is imported from `collections.abc` instead of `typing`.
This PR is the follow-up of #164054.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164104
Approved by: https://github.com/Skylion007
2025-09-30 00:28:53 +00:00
soulitzer
1e4dfeeb06 Add early_stop kwarg to torch.utils.checkpoint (#160781)
We already have a context manager "set_checkpoint_early_stop". This PR adds a kwarg that toggles the same setting.

It is also useful to have a kwarg version of the setting in addition to the context manager because is annoying to apply a context manager when the AC is being applied via CheckpointWrapper.

Similar to the "debug" kwarg and the corresponding "set_checkpoint_debug_enabled" context manager, the context manager defaults to None and overrides the local setting when non-None.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160781
Approved by: https://github.com/tianyu-l
2025-08-26 22:32:35 +00:00
Anshul Sinha
72009ec6be [replicate][be] improved readability and cleaned up remaining DDP code (#160133)
**Summary**
As much of ReplicateState functionality is copied from FSDPState, I fixed any remaining comments that incorrectly used FSDP instead of Replicate. In addition, instead of labeling modules FSDPModule or FSDPLinear, I have changed it so that is now uses Replicate____. Finally, I have removed some leftover code from the DDP implementation. I have included test cases to verify correctness.

**Test Case**
1. pytest test/distributed/_composable/test_replicate_with_fsdp.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160133
Approved by: https://github.com/mori360
ghstack dependencies: #160128
2025-08-08 19:42:23 +00:00
Anshul Sinha
5cdb3d896e [FSDP][Replicate] added replicate function that uses FSDP instead of DDP (#158207)
**Summary**
Users would like to use Replicate with TP. Currently, the replicate function uses DDP, which has not been maintained resulting in a lack of integration options. Since users can use FSDP with TP, we will make the replicate function use FSDP so that users can use replicate with FSDP. To that end I have created a replicate function that uses FSDP instead of DDP. One blocker that I ran into is that the replicate function has a contract which assigns a module "replicate" attribute in registry. This would mean that fully_shards is_composable requirement would not be satisfied making it impossible to apply fully_shard to a replicate module. The solution to this was to copy the fully_shard function and state and modify it for replicate. In the future, it should be explored making the replicate_state inherit from FSDP_state to get rid of code duplicity. I have attached below the profile tracing of a replicated Net Module.

https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/anshulsi_270fcc36-194a-42f5-9841-cace984c2132_devgpu263.prn2.facebook.com_1792146.1753232748025155780.pt.trace.json

**Test Case**
1.  pytest test/distributed/_composable/test_replicate_with_fsdp.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158207
Approved by: https://github.com/weifengpy

Co-authored-by: Anshul Sinha <50644008+sinhaanshul@users.noreply.github.com>
2025-07-23 22:53:06 +00:00
Xuehai Pan
995df34b19 [BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144547
Approved by: https://github.com/kwen2501
2025-02-28 07:35:56 +00:00
Shawn Xu
9da250aada type fully_shard so that the return value can be chained with typing enabled (#147489)
This allows for

```
fsdped = fully_shard(model)
fsdped.set_xyz()
```

same applies if `model` is actually a list of modules

Differential Revision: [D69888119](https://our.internmc.facebook.com/intern/diff/D69888119)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147489
Approved by: https://github.com/Skylion007
ghstack dependencies: #147488
2025-02-20 08:43:16 +00:00
Aaron Orenstein
db4ce78d46 PEP585: More UP006 fixes (#146392)
This should be the final PR before we can enable RUFF UP006.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146392
Approved by: https://github.com/justinchuby, https://github.com/albanD, https://github.com/Skylion007
2025-02-20 06:18:13 +00:00
Shawn Xu
de1cb0f351 capture the return value in the contract typing (#147488)
----

* the existing typing makes the return type `Optional[nn.Module]`
* this doesn't seem to be what the decorator actually does as it does
  not alter the original return type
* This PR aims to fix the typing

Differential Revision: [D69888120](https://our.internmc.facebook.com/intern/diff/D69888120)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147488
Approved by: https://github.com/Skylion007
2025-02-20 03:32:34 +00:00
Aaron Orenstein
00ffeca1b1 PEP585 update - torch/distributed (#145164)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145164
Approved by: https://github.com/bobrenjc93
2025-01-21 04:23:29 +00:00
PyTorch MergeBot
6374332d33 Revert "PEP585 update - torch/distributed (#145164)"
This reverts commit 6cb186e279.

Reverted https://github.com/pytorch/pytorch/pull/145164 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing an inductor test ([comment](https://github.com/pytorch/pytorch/pull/145164#issuecomment-2602875679))
2025-01-20 16:46:46 +00:00
Aaron Orenstein
6cb186e279 PEP585 update - torch/distributed (#145164)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145164
Approved by: https://github.com/bobrenjc93
2025-01-20 00:19:01 +00:00
bobrenjc93
fbad833538 Migrate from Tuple -> tuple in test/distributed/_composable (#144254)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144254
Approved by: https://github.com/aorenste
2025-01-10 06:38:05 +00:00
Aaron Orenstein
45ef3309e3 [BE] typing for decorators (#144161)
Summary:
Untyped decorators strip annotations from the decorated items.

- _compile
- _inductor/fx_passes/post_grad
- _inductor/lowering
- _library/custom_ops
- _meta_registrations
- _ops
- _refs/nn/functional
- ao/quantization/quantizer/xnnpack_quantizer_utils
- distributed/_composable/contract
- fx/experimental/graph_gradual_typechecker
- fx/experimental/migrate_gradual_types/constraint_generator
- optim/optimizer
- signal/windows/windows
- testing/_internal/common_device_type
- torch/_inductor/decomposition
- utils/flop_counter

Test Plan: unit tests

Differential Revision: D62302684

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144161
Approved by: https://github.com/Skylion007, https://github.com/albanD
2025-01-04 16:40:09 +00:00
Andrew Gu
bd867d691b [FSDP2] Fix backward-compatible imports (#142419)
Internal only: the before way meant that `from torch.distributed._composable.fsdp import fully_shard` was importing `fully_shard.py` not the function `fully_shard`. For some reason, the resolution order is different from open source.

To fix this, we match the old import as closely as possible. Namely, we import `fully_shard.py` contents from `.fully_shard`. This should force that import to take precedence.

@diff-train-skip-merge

Differential Revision: [D66990327](https://our.internmc.facebook.com/intern/diff/D66990327)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142419
Approved by: https://github.com/weifengpy
2024-12-09 23:56:32 +00:00
Andrew Gu
78425bff30 [FSDP2] Move to public torch.distributed.fsdp (#141868)
**Overview**
This PR moves `torch/distributed/_composable/fsdp` to `torch/distributed/fsdp/_fully_shard` and makes public APIs available from `torch.distributed.fsdp`, e.g.:
```
from torch.distributed.fsdp import fully_shard
```
This is targeting 2.6 release. I rewrote some of the documentation with (hopefully) improved phrasing.

**Changes for Reland**
- Preserved the public objects from `torch/distributed/_composable/fsdp/fully_shard.py` so that the import path still works internally
- Added a unit test that we can do `from torch.distributed._composable.fsdp.fully_shard import FSDPModule`

Differential Revision: [D66890387](https://our.internmc.facebook.com/intern/diff/D66890387)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141868
Approved by: https://github.com/kwen2501, https://github.com/wconstab, https://github.com/weifengpy, https://github.com/fegin, https://github.com/XilunWu

Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
2024-12-07 01:24:28 +00:00
PyTorch MergeBot
bab15df40a Revert "[FSDP2] Move to public torch.distributed.fsdp (#141868)"
This reverts commit 45583a5df9.

Reverted https://github.com/pytorch/pytorch/pull/141868 on behalf of https://github.com/atalman due to failing internally ([comment](https://github.com/pytorch/pytorch/pull/141868#issuecomment-2523925180))
2024-12-06 18:38:12 +00:00
Andrew Gu
45583a5df9 [FSDP2] Move to public torch.distributed.fsdp (#141868)
**Overview**
This PR moves `torch/distributed/_composable/fsdp` to `torch/distributed/fsdp/_fully_shard` and makes public APIs available from `torch.distributed.fsdp`, e.g.:
```
from torch.distributed.fsdp import fully_shard
```
This is targeting 2.6 release. I rewrote some of the documentation with (hopefully) improved phrasing.

**Follow-Ups**
- [x] Add some explanation in the docs about FSDP1 vs. FSDP2
- [ ] Move unit tests from `test/distributed/_composable/fsdp` to `test/distributed/fsdp/fully_shard/`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141868
Approved by: https://github.com/kwen2501, https://github.com/wconstab, https://github.com/weifengpy

Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
2024-12-05 03:04:01 +00:00
Andrew Gu
5c59f4a55a Remove old FSDP1 fully_shard (#141875)
FSDP1's `fully_shard` frontend was an exploration at the end of 2022 H2 as part of the `torch/distributed/_composable` APIs to avoid `nn.Module` wrappers. It calls into the same backend code as FSDP1's `FullyShardedDataParallel`.

The API did not gain traction internally, so we instead reused the name `fully_shard` for FSDP2, which similarly is not an `nn.Module` wrapper and follows similar design principles as FSDP1's `fully_shard`.

To the best of our knowledge, we have removed all instances of FSDP1's `fully_shard` internally, and we put the deprecation warning in open source in 2.4 saying it will be removed after 2.5 (which is now):
4959784dac/torch/distributed/_composable/fully_shard.py (L40-L48)

We are skipping the PR sanity check because this PR is only removing code, not adding new code, and should not require this sanity check.

Differential Revision: [D66664988](https://our.internmc.facebook.com/intern/diff/D66664988)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141875
Approved by: https://github.com/weifengpy
2024-12-03 17:00:47 +00:00
Edward Z. Yang
612122af8f Fix type-safety of torch.nn.Module instances (#141240)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141240
Approved by: https://github.com/Skylion007, https://github.com/malfet
2024-11-22 00:05:05 +00:00
wangyicheng
ee3a4f068c [FSDP2] privateuse1 support fsdp2. (#139539)
We are looking forward to supporting FSDP2 with devices other than CUDA. Please give me some coding suggestions. Thank you very much.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139539
Approved by: https://github.com/kwen2501
2024-11-15 06:34:35 +00:00
Andrew Gu
78a8f7f5c3 [FSDP2] Fix CUDA sync for bf16 HSDP AR, fp32 params (#140044)
Differential Revision: [D65621037](https://our.internmc.facebook.com/intern/diff/D65621037)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140044
Approved by: https://github.com/weifengpy
2024-11-12 13:31:40 +00:00
Andrew Gu
39ede99a33 Add current FSDP2 path to old composable FSDP1 warning (#139759)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139759
Approved by: https://github.com/weifengpy, https://github.com/wz337
ghstack dependencies: #139650
2024-11-06 01:43:04 +00:00
Will Feng
6a30c14a0a [Traceable FSDP2] Run any unexecuted post_backward at beginning of pre_backward hook (#139671)
Assuming the forward pass user code looks like:
```
for _ in range(2):
    x = layer(x)
```
and we have `fully_shard(layer)`, then:
- the forward pass will be like: "unshard layer -> call layer 1st time -> reshard layer -> unshard layer -> call layer 2nd time-> reshard layer" (currently same for both eager and compile)
- the backward pass will be like: "unshard layer -> call layer 1st time -> reshard layer -> unshard layer -> call layer 2nd time-> reshard layer" in eager, but currently it's "unshard layer -> call layer 1st time -> call layer 2nd time -> reshard layer" in compile

The behavior in the backward pass is different between eager and compile, which is not ideal.

 I am currently trying to look for a way to fix this non-ideal behavior of compile - tried a few things:
1. Tracing the RegisterPostBackwardFunction custom autograd function - this stills seems to be a no-go, due to HOP not supporting side-effects.
2. Instead of custom autograd function, do a "multi-grad hook" to wait for all gradients to be ready before triggering post_backward. However, this approach seems to have bad interaction with register_hook of pre_backward, in the sense that it's unclear which of them will be triggered first in practice.
3. Force execute any pending post_backward before unshard in pre_backward hook, and rely on compiler to move the reshard to the right place to optimize peak memory. -> This PR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139671
Approved by: https://github.com/awgu
2024-11-06 00:19:06 +00:00
Yuanhao Ji
e52ccb3ca6 [Device] Replace hardcoded devices with 'torch._C._get_accelerator()' (#139032)
I noticed that some hard-code like `"cuda" if torch.cuda.is_available() else "cpu"` which can be replaced with `torch._C._get_accelerator()`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139032
Approved by: https://github.com/ezyang
2024-10-29 04:51:47 +00:00
Simon Fan
fd9f4e6770 Back out "[compiled autograd] tls access helpers (#138061)" and Back out "[compiled autograd] Compiled autograd configs in TLS (#137821)" (#139086)
Summary:
Original commit changeset: 9bf80c1492d7

Original Phabricator Diff: D64796226

Original commit changeset: aa1d9ef8f6e6

Original Phabricator Diff: D64796212

Differential Revision: D65072644

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139086
Approved by: https://github.com/malfet
2024-10-28 23:37:05 +00:00
Simon Fan
5a13282c75 [compiled autograd] tls access helpers (#138061)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138061
Approved by: https://github.com/yf225
ghstack dependencies: #137953, #137821
2024-10-22 08:03:52 +00:00
Simon Fan
49fa437097 [compiled autograd] Compiled autograd configs in TLS (#137821)
Multithreaded doesn't work yet, this adds python side TLS only for the python side state

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137821
Approved by: https://github.com/jansel, https://github.com/yf225
ghstack dependencies: #137953
2024-10-22 08:03:52 +00:00
Will Feng
fcedf93d1e [Traceable FSDP2] Add _compiled_autograd_enabled global state variable (#138187)
After https://github.com/pytorch/pytorch/pull/137821, we will no longer be able to call the Compiled Autograd state getter under Dynamo tracing. One solution is to cache the "Compiled Autograd enabled" state outside of compile for FSDP2, and just read from the cache when we need the check. This is implemented by this PR.

Fixes https://github.com/pytorch/pytorch/issues/138177.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138187
Approved by: https://github.com/xmfan, https://github.com/awgu
2024-10-19 19:10:31 +00:00
Tom Ritchford
c0582fd0f8 Remove unused Python variables in torch/[b-z]* (#136963)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136963
Approved by: https://github.com/ezyang
2024-10-19 16:45:22 +00:00
PyTorch MergeBot
795255a7c8 Revert "[Traceable FSDP2] Add _compiled_autograd_enabled global state variable (#138187)"
This reverts commit 0c913b35aa.

Reverted https://github.com/pytorch/pytorch/pull/138187 on behalf of https://github.com/yf225 due to linux-focal-rocm6.2-py3.10 / test (distributed, 1, 3, linux.rocm.gpu) test_compiled_autograd_ctx failed ([comment](https://github.com/pytorch/pytorch/pull/138187#issuecomment-2423609108))
2024-10-19 06:12:47 +00:00
Will Feng
0c913b35aa [Traceable FSDP2] Add _compiled_autograd_enabled global state variable (#138187)
After https://github.com/pytorch/pytorch/pull/137821, we will no longer be able to call the Compiled Autograd state getter under Dynamo tracing. One solution is to cache the "Compiled Autograd enabled" state outside of compile for FSDP2, and just read from the cache when we need the check. This is implemented by this PR.

Fixes https://github.com/pytorch/pytorch/issues/138177.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138187
Approved by: https://github.com/xmfan, https://github.com/awgu
ghstack dependencies: #138245, #138174
2024-10-19 04:33:35 +00:00
Will Feng
504904c9c6 [Traceable FSDP2] Add compiled_autograd_enabled helper function (#138105)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138105
Approved by: https://github.com/awgu, https://github.com/xmfan
2024-10-17 00:04:06 +00:00
PyTorch MergeBot
361f42bc42 Revert "[compiled autograd] Compiled autograd configs in TLS (#137821)"
This reverts commit 9aba0b91c8.

Reverted https://github.com/pytorch/pytorch/pull/137821 on behalf of https://github.com/wdvr due to Reverting this for now, it is failing test_public_bindings in trunk ([comment](https://github.com/pytorch/pytorch/pull/137821#issuecomment-2417351788))
2024-10-16 16:38:29 +00:00
Simon Fan
9aba0b91c8 [compiled autograd] Compiled autograd configs in TLS (#137821)
Multithreaded doesn't work yet, this adds python side TLS only for the python side state

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137821
Approved by: https://github.com/jansel, https://github.com/yf225
ghstack dependencies: #137953
2024-10-16 09:28:32 +00:00
Andrew Gu
3cc8c8b944 [FSDP2] Add set_unshard_in_backward(bool) (#137922)
For some expert use cases, the user knows some parameters are not required for backward, so we can skip the unshard in backward. One example is the embedding weight.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137922
Approved by: https://github.com/weifengpy
2024-10-15 19:11:14 +00:00
Andrew Gu
5835b1af10 [FSDP2] Gated dynamo import for torch deploy (#137203)
Differential Revision: [D63777335](https://our.internmc.facebook.com/intern/diff/D63777335)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137203
Approved by: https://github.com/wz337
2024-10-11 16:38:19 +00:00
Andrew Gu
a93ea617b5 [FSDP2] Required mesh_dim_names for HSDP (#137436)
Two changes:
1. Require `mesh_dim_names` if using HSDP
2. Pass only the shard mesh to `fsdp_pre_all_gather`

Change 1 is technically BC breaking, but it should not be hard to fix on the user side.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137436
Approved by: https://github.com/weifengpy, https://github.com/wz337
2024-10-09 20:35:09 +00:00
Andrew Gu
aa61e251d4 [FSDP2] Added shard_placement_fn arg (#137496)
## Overview
This PR adds a `shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]` arg to `fully_shard` that allows users to specify FSDP sharding on a nonzero tensor dim. If doing so, then the tensor dim size must be divisible by the FSDP shard world size.

```
# Example:
def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
    largest_dim = largest_dim_size = -1
    for dim, dim_size in enumerate(param.shape):
        if dim_size > largest_dim_size:
            largest_dim = dim
            largest_dim_size = dim_size
    return Shard(largest_dim)

fully_shard(module, shard_placement_fn=shard_placement_fn)
```

## Follow-Ups
- **Copy kernels:** For all-gather copy-out, we currently copy-out to temporaries and then chunk-dim-0 -> cat-shard-dim, incurring an extra copy for parameters sharded on nonzero tensor dim. Similarly, for reduce-scatter copy-in, we currently chunk-shard-dim -> cat-dim-0, incurring an extra copy for gradients sharded on nonzero tensor dim. @yifuwang  has ideas for adding additional split size args to the copy ops that allows fusing these extra copies into the existing all-gather copy-out and reduce-scatter copy-in.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137496
Approved by: https://github.com/weifengpy
ghstack dependencies: #137593
2024-10-09 19:13:32 +00:00
Andrew Gu
ceb2fcc5db [FSDP2] Fixed incorrect tensor meta after .to(dtype) (#137593)
This fixes https://github.com/pytorch/pytorch/issues/137522. After a method that changes to module parameters (like `.to(torch.float64)`), we need to update the `DTensorSpec`, whose `TensorMeta`'s dtype may have changed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137593
Approved by: https://github.com/Skylion007
2024-10-09 17:57:11 +00:00
PyTorch MergeBot
5e3e1c0151 Revert "[FSDP2] Required mesh_dim_names for HSDP (#137436)"
This reverts commit 5fb30df7d6.

Reverted https://github.com/pytorch/pytorch/pull/137436 on behalf of https://github.com/malfet due to Looks like it broke distributed testing, see https://github.com/pytorch/pytorch/actions/runs/11239761070/job/31249854217 ([comment](https://github.com/pytorch/pytorch/pull/137436#issuecomment-2400794929))
2024-10-08 20:50:49 +00:00
Andrew Gu
5fb30df7d6 [FSDP2] Required mesh_dim_names for HSDP (#137436)
Two changes:
1. Require `mesh_dim_names` if using HSDP
2. Pass only the shard mesh to `fsdp_pre_all_gather`

Change 1 is technically BC breaking, but it should not be hard to fix on the user side.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137436
Approved by: https://github.com/weifengpy, https://github.com/wz337
2024-10-08 16:31:18 +00:00