Commit Graph

81 Commits

Author SHA1 Message Date
Matthew Hoffman
68b0db1274 Define the public API for torch.distributed.fsdp (#109922)
Related: https://github.com/pytorch/pytorch/wiki/Public-API-definition-and-documentation
Related: https://github.com/microsoft/pylance-release/issues/2953

This fixes pylance issues for these classes:

```
"FullyShardedDataParallel" is not exported from module "torch.distributed.fsdp"
```

These classes all have public docs:

* [`BackwardPrefetch`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.BackwardPrefetch)
* [`CPUOffload`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.CPUOffload)
* [`FullyShardedDataParallel`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel)
* [`MixedPrecision`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision)
* [`ShardingStrategy`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.ShardingStrategy)

And it seems like all the newly added classes will have docs once they are released.
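
The fix follows the public-API convention from the wiki link above: re-export the public names from `torch/distributed/fsdp/__init__.py` and list them in `__all__`. A hedged sketch (the import paths and exact export list here are illustrative, not the PR's actual diff):

```python
# Sketch of the __init__.py pattern; the actual import paths and export list
# in the PR may differ.
from .fully_sharded_data_parallel import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel,
    MixedPrecision,
    ShardingStrategy,
)

__all__ = [
    "BackwardPrefetch",
    "CPUOffload",
    "FullyShardedDataParallel",
    "MixedPrecision",
    "ShardingStrategy",
]
```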

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109922
Approved by: https://github.com/wanchaol
2023-09-28 02:15:58 +00:00
Andrew Gu
57fba6fd86 [FSDP][9/N] Introduce CustomPolicy (#104986)
This PR adds a new `CustomPolicy` that acts like the existing `lambda_auto_wrap_policy` except it (1) leverages the new auto wrapping infrastructure and (2) allows overriding FSDP kwargs for particular instances. (1) gives it access to the validation checks (like for frozen parameters), and (2) makes it as expressive as manual wrapping. This should allow us to effectively deprecate manual wrapping if desired.

The API is as follows:
```
def lambda_fn(module: nn.Module) -> Union[bool, Dict[str, Any]]:
    ...
policy = CustomPolicy(lambda_fn)
```
The `lambda_fn` can return:
- `False` or `{}` to indicate no wrapping
- `True` to indicate wrapping while inheriting the root's FSDP kwargs
- Non-empty `dict` to indicate wrapping while overriding the specified FSDP kwargs and inheriting the rest from the root
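
For illustration, a hedged sketch of a `lambda_fn` exercising all three return types (the module types, the overridden kwarg, and the `torch.distributed.fsdp.wrap` import path are assumptions, not taken from the PR):

```python
from typing import Any, Dict, Union

import torch.nn as nn
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.wrap import CustomPolicy


def lambda_fn(module: nn.Module) -> Union[bool, Dict[str, Any]]:
    if isinstance(module, nn.Embedding):
        # Wrap this instance and override its sharding strategy.
        return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}
    if isinstance(module, nn.TransformerEncoderLayer):
        # Wrap this instance, inheriting the root's FSDP kwargs.
        return True
    # No wrapping for anything else.
    return False


policy = CustomPolicy(lambda_fn)
```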

---

After this PR, the follow-up work items for auto wrapping are:
1. Add shared parameter validation
2. (Longer-term / exploratory) Add a policy that provides a reasonable auto wrapping with "minimal" user input

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104986
Approved by: https://github.com/ezyang
ghstack dependencies: #104427, #104967, #104999, #104969
2023-08-03 12:46:36 +00:00
Andrew Gu
15953fdf35 [FSDP][8/N] Replace _FSDPPolicy.policy with _Policy._run_policy (#104969)
This does some code organization improvement.
- It renames `_FSDPPolicy` to `_Policy` to show that it is not only for FSDP but for any module-level API.
- It formalizes the contract that such a policy should return something like `target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]]` that maps each module to wrap to its kwargs. It does so by requiring a `_run_policy` abstract method (this time private since users do not need to care about it). Then, our auto wrapping can just call `_run_policy()` to generate the dict and do any validation or post-processing.
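
A hedged sketch of that contract (the parameter names are assumptions; only the return type comes from the description above):

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Set

import torch.nn as nn


class _Policy(ABC):
    """Module-level wrapping policy (sketch)."""

    @abstractmethod
    def _run_policy(
        self,
        root_module: nn.Module,
        ignored_modules: Set[nn.Module],
        root_kwargs: Dict[str, Any],
    ) -> Dict[nn.Module, Dict[str, Any]]:
        """Map each module to wrap to the kwargs it should be wrapped with."""
        ...
```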

This PR is technically BC-breaking because it removes the public `ModuleWrapPolicy.policy`. However, I do not think anyone was using that anyway, so this is a pretty safe breakage.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104969
Approved by: https://github.com/rohan-varma
ghstack dependencies: #104427, #104967, #104999
2023-08-03 12:42:14 +00:00
Rohan Varma
43b3632215 [Composable] Add hybrid shard AC compile test (#105207)
This was requested to ensure that hybrid shard + AC + compile works.

Differential Revision: [D47462393](https://our.internmc.facebook.com/intern/diff/D47462393/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105207
Approved by: https://github.com/awgu, https://github.com/fegin
2023-07-26 21:03:55 +00:00
Rohan Varma
4137d6e499 [Composable FSDP] Enable HSDP (#105206)
Need to pass the sharding strategy to `_init_process_group_state` to enable HSDP
for composable FSDP.

Differential Revision: [D47462394](https://our.internmc.facebook.com/intern/diff/D47462394/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105206
Approved by: https://github.com/awgu, https://github.com/fegin
2023-07-26 21:03:55 +00:00
Rohan Varma
03e2ca9d9c [Composable] Add more sharding strategies to runtime test (#105205)
Add more sharding strategies to ensure equivalence

Differential Revision: [D47462392](https://our.internmc.facebook.com/intern/diff/D47462392/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105205
Approved by: https://github.com/awgu
2023-07-26 07:03:09 +00:00
Rohan Varma
a326f5621e composable fsdp, checkpoint, + compile test (#105180)
Test to ensure that composable FSDP, checkpoint, and compile all work
together. Includes a change from https://github.com/pytorch/pytorch/pull/105090
which we can land in that PR first.

Differential Revision: [D47452973](https://our.internmc.facebook.com/intern/diff/D47452973/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105180
Approved by: https://github.com/awgu
2023-07-26 07:03:09 +00:00
Rohan Varma
5d70fe0165 [Composable] Use non-reentrant generator, remove reentrant (#105176)
Removes reentrant support from the composable checkpoint API, as
non-reentrant is the recommended approach and the one we should use when rolling
out the composable checkpoint API.

Also removes the standalone non-reentrant implementation and instead uses
the generator from the diff below to reuse the original implementation.
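
For reference, a hedged sketch of applying the composable checkpoint API (assuming the `torch.distributed._composable.checkpoint` entry point used elsewhere in this history):

```python
import torch
import torch.nn as nn
from torch.distributed._composable import checkpoint

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
# Non-reentrant checkpointing: activations of the first linear layer are
# discarded in forward and recomputed during backward.
checkpoint(model[0])

out = model(torch.randn(4, 16))
out.sum().backward()
```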

Differential Revision: [D47451375](https://our.internmc.facebook.com/intern/diff/D47451375/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105176
Approved by: https://github.com/awgu, https://github.com/fegin
2023-07-26 07:03:03 +00:00
Michael Voznesensky
a832967627 Migrate tuple(handle) -> handle (#104488)
We strengthen the invariant that one FSDP-managed module has one flat parameter, and remove unused code that would have supported a 1:many module-to-flat-parameter mapping.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104488
Approved by: https://github.com/awgu
2023-07-19 22:33:35 +00:00
Chien-Chin Huang
a10f93f606 [composable API] Fix the replicate_device_id test case to avoid copy replicated models. (#105503)
We should not `replicate` a `copy.deepcopy` of an already replicated model.

Differential Revision: [D47566678](https://our.internmc.facebook.com/intern/diff/D47566678/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105503
Approved by: https://github.com/awgu, https://github.com/rohan-varma
2023-07-19 16:20:43 +00:00
Andrew Gu
954bae8e53 [FSDP][Easy] Rename streams; add back stream sharing test (#104966)
Purely out of preference, this PR renames the streams to `_unshard_stream` instead of `_streams_unshard` etc. since the former reads more naturally. The PR also removes some duplicated comments and adds back a unit test that streams are shared.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104966
Approved by: https://github.com/rohan-varma
2023-07-13 00:24:41 +00:00
Andrew Gu
d9be0366d3 [FSDP][3/N] Unify fully_shard auto wrap (#104408)
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules.
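
For context, a hedged sketch of the auto wrapping usage this unification targets (the `policy` kwarg name for the composable `fully_shard` of this era is an assumption):

```python
import torch.nn as nn
from torch.distributed._composable import fully_shard
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

# Assumes torch.distributed is already initialized (e.g., via torchrun).
layer = nn.TransformerEncoderLayer(d_model=32, nhead=4)
model = nn.TransformerEncoder(layer, num_layers=2)
# One fully_shard application per matched submodule, mirroring
# FullyShardedDataParallel's auto_wrap_policy behavior.
fully_shard(model, policy=ModuleWrapPolicy({nn.TransformerEncoderLayer}))
```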

This includes several important fixes:
- We should register the pre/post-forward hooks on the module regardless of whether it has managed parameters.
- We can permit `_module_handles` to return `[]` in the composable path (for when the module has no managed parameters).
- We should unify the paths for `_get_buffers_and_dtypes_for_computation()` (previously, the composable path was buggy in some cases).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104408
Approved by: https://github.com/rohan-varma
2023-07-08 12:40:12 +00:00
Andrew Gu
6c1d959889 [FSDP] Annotate modules for fully_shard (#104363)
This annotates modules managed by `fully_shard` for TorchDynamo to treat them specially.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104363
Approved by: https://github.com/fegin
2023-07-06 16:56:59 +00:00
Michael Voznesensky
02f28de408 [dynamo x fsdp] Simplify stream logic handling (#103902)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103902
Approved by: https://github.com/awgu
2023-06-21 01:34:19 +00:00
Andrew Gu
2eea3cb19d Fix composable checkpoint(use_reentrant=True) with multi args (#103590)
The `_ModuleHookCheckpointFunction.backward()` should take in `*output_grads` instead of `output_grads`. Otherwise, we may see an error like:
```
TypeError: backward() takes 2 positional arguments but 5 were given
```
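
This standalone example (not the PR's actual `_ModuleHookCheckpointFunction` code) illustrates why the `*` is needed once the checkpointed region produces multiple outputs:

```python
import torch


class _Demo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        return x * 2, y * 3  # multiple outputs

    @staticmethod
    def backward(ctx, *output_grads):  # `*` is required: one grad per forward output
        gx, gy = output_grads
        return gx * 2, gy * 3


a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
out1, out2 = _Demo.apply(a, b)
(out1.sum() + out2.sum()).backward()
```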
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103590
Approved by: https://github.com/rohan-varma, https://github.com/fduwjj, https://github.com/fegin
2023-06-14 21:53:30 +00:00
Rohan Varma
5b623d6c6a [Composable] fully_shard load_optim test (#102692)
Closes https://github.com/pytorch/pytorch/issues/93280 and adds tests
for this.

Differential Revision: [D46343364](https://our.internmc.facebook.com/intern/diff/D46343364/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102692
Approved by: https://github.com/awgu
2023-06-04 18:31:22 +00:00
Rohan Varma
0ecca122e7 [Replicate] Add unit test with replicate param names (#102401)
This attribute wasn't actually used in tests; add a test ensuring that
if `replicate` is used on top of FSDP, the replicated parameter names are as
expected.

TODO: there are a few ways to check whether a module is managed by a composable
API, such as the replicated parameter names for `replicate`, the `_get_module_state`
API, the `_get_registry` API, etc. We should unify how all composable APIs perform
this check (filed an issue).

Differential Revision: [D46236377](https://our.internmc.facebook.com/intern/diff/D46236377/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102401
Approved by: https://github.com/awgu
2023-05-31 18:41:03 +00:00
Yanli Zhao
5ac48eb353 [FSDP]Skip unshard call during checkpointing for NO_SHARD sharding strategy (#101095)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101095
Approved by: https://github.com/fegin
2023-05-12 18:19:18 +00:00
Daniel Dale
2c29149109 Enhance Composable FSDP cast forward input tests (#100349)
The fix for https://github.com/pytorch/pytorch/pull/99545 (https://github.com/pytorch/pytorch/pull/99546) explicitly required users to set `cast_forward_inputs=False` if they wanted to avoid hitting #99545 while using an FSDP root module with no direct parameters.

After further consideration, [the team believes](https://github.com/pytorch/pytorch/pull/99546#discussion_r1180898687) it is sufficiently common for the default `cast_forward_inputs=False` to be used with an FSDP root module possessing no direct parameters that a solution to #99545 accommodating this use case is desired.

This PR builds on @zhaojuanmao's https://github.com/pytorch/pytorch/pull/100290 (nice!) to enhance the FSDP cast forward inputs testing to include a broader range of scenarios and to include `model.eval()` testing as well as training mode validation. (I unfortunately don't have permissions that would allow me to use ghstack directly but I can rebase this PR however the team desires, once #100290 lands etc.)

Currently, the evaluation mode testing is commented out while the team decides on the best approach to implementing the broader solution to https://github.com/pytorch/pytorch/pull/99545. Once an implementation is decided, the evaluation mode validation function in the new tests added in this PR can be uncommented and should continue to pass. I also include one potential evaluation mode solution suggestion in this PR but leave the existing code unchanged since I know the team is intending to consider a range of solutions this week.

Test notes:
1. The 8 tests added here are a superset of the current `test_float16_on_one_submodule` tests, including validation of the following configurations: (`cast_root_forward_inputs_submodule` = True/False, `cast_forward_inputs_submodule` = True/False, `use_root_no_params` = True/False) across both training and evaluation modes.
2. The `float16_on_one_submodule` model configuration is currently only tested in the FSDP root module with parameters scenarios (as was the existing case) but this test can be easily extended to test it in the FSDP root module with no parameters scenarios as well if the team thinks the additional test resource usage is justified.
3. Since this test amortizes the cost of test setup across the aforementioned range of scenarios, the loop-based implementation of `dtype` validation (below) would have been undesirably complex IMHO[^1]:
```python
        ############### Logical equivalent of current test result matrix ############
        if self.cast_root_forward_inputs_submodule or self.cast_forward_inputs_submodule:
            self.assertEqual(self.forward_inputs[self.c2].dtype, torch.float16)
            if use_root_no_params:
                if self.cast_root_forward_inputs_submodule:
                    self.assertEqual(self.forward_inputs[self.model].dtype, torch.float16)
                else:
                    self.assertEqual(self.forward_inputs[self.model].dtype, torch.float32)
                self.assertEqual(self.forward_inputs[self.c1].dtype, torch.float16)
            else:
                self.assertEqual(self.forward_inputs[self.c1].dtype, torch.float32)
        else:
            self.assertEqual(self.forward_inputs[self.model].dtype, torch.float32)
            self.assertEqual(self.forward_inputs[self.c1].dtype, torch.float32)
            if not use_root_no_params: # this input will only exist in the root with params case until eval fix is applied
                self.assertEqual(self.forward_inputs[self.c2].dtype, torch.float32)
```
so I implemented the validation function as an expected result lookup that provides the added benefit of explicitly specifying the failed subtest upon failed `dtype` assertions, e.g.:
```python
AssertionError: None mismatch: torch.float32 is not None
Subtest `no_cast_root_no_cast_child_no_root_params` failed.
```
The potential solution to https://github.com/pytorch/pytorch/pull/99545 that I added as a suggestion in the file conversation passes this test set but I know there are a lot of different ways that it could be resolved so I'll assume that change will be tackled in a separate PR unless the team wants to include it in this one.

As mentioned, I've currently based this PR off of https://github.com/pytorch/pytorch/pull/100290 so am happy to either wait for that to land first or rebase this PR however the team wants.

[^1]: Batching the scenarios into different tests is also possible of course but would involve unnecessary test setup overhead, happy to switch to that approach if the team prefers that though.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100349
Approved by: https://github.com/awgu
2023-05-12 04:23:18 +00:00
Chien-Chin Huang
0fbe55ea8f [FSDP][state_dict] Make sharded_state_dict work with composable fully_shard (#100856)
The current implementation of sharded_state_dict only works with wrapper based FSDP (both use_orig_params and not use_orig_params work) but not with fully_shard. This PR changes the implementation of sharded_state_dict when loading to fix the incompatibility.

Differential Revision: [D45626856](https://our.internmc.facebook.com/intern/diff/D45626856/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D45626856/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100856
Approved by: https://github.com/awgu, https://github.com/zhaojuanmao
2023-05-10 15:32:45 +00:00
Chien-Chin Huang
55844dfdbc [FSDP][state_dict] Restore the state_dict_config for NO_SHARD (#100855)
Any change to the user configurations should be temporary. This PR fixes an issue where, when NO_SHARD state_dict/load_state_dict is called, the state_dict_config and state_dict_type are changed permanently.

Differential Revision: [D45593313](https://our.internmc.facebook.com/intern/diff/D45593313/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D45593313/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100855
Approved by: https://github.com/awgu, https://github.com/zhaojuanmao, https://github.com/rohan-varma
2023-05-10 10:01:21 +00:00
Rohan Varma
8869897ebe [replicate] support simpler device_id (#100217)
Allow passing in `device_id=[device]` regardless of CPU or GPU. We
modify the kwarg as needed before passing it to DDP.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100217
Approved by: https://github.com/awgu, https://github.com/zhaojuanmao
2023-05-04 21:06:04 +00:00
Rohan Varma
253b9d3247 [replicate] input casting support (#100216)
Supports input casting by doing this in the pre hook.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100216
Approved by: https://github.com/awgu
2023-05-04 01:46:15 +00:00
Yanli Zhao
dc9c79d3cf Allow each fully_shard unit to cast forward inputs for mixed precision config (#100290)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100290
Approved by: https://github.com/rohan-varma
2023-05-02 00:03:48 +00:00
Iris
a23365885f [FSDP] Make set_state_type to SHARDED_STATE_DICT compatible with NO_SHARD sharding_strategy (#100208)
Currently, if we use the NO_SHARD strategy for fully_shard and set state_dict_type to SHARDED_STATE_DICT, a runtime error is raised ("``sharded_state_dict`` can only be used when parameters are flatten and sharded.").

This PR updates pre_state_dict_hook, post_state_dict_hook, pre_load_state_dict_hook, and post_load_state_dict_hook to set state_dict_type and state_dict_config to full state when using NO_SHARD, even if the state_dict_type and state_dict_config of the root module are set to sharded state.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100208
Approved by: https://github.com/rohan-varma
2023-04-28 04:37:58 +00:00
Yanli Zhao
6ca991cacf [Composable API] Add fully_shard debug function to print sharded tree structure, module names and managed param fqns (#99133)
Adds a `fully_shard` debug function that prints the sharded tree structure in the following format and also returns the module names and their managed parameter FQNs.

![Screenshot 2023-04-18 at 5 14 54 PM](https://user-images.githubusercontent.com/48731194/232931628-169a63a9-b4d5-4902-9cfd-f40113f3ec98.png)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99133
Approved by: https://github.com/rohan-varma
2023-04-19 19:27:43 +00:00
Rohan Varma
ef11966aff [composable] Enable replicate + trec_shard overall (#98890)
replicate + trec_shard works if we shard / replicate individual submodules, as follows:

```
m = TestSparseNN()
shard(m.sparse)
replicate(m.dense)
```

but does not work if users do the following:
```
m = TestSparseNN()
shard(m, sharders=[...])
replicate(m)
```

Many upstream trainers use the latter pattern, as sharding is not done at the individual-module level but rather on the overall module, by specifying planners that contain the logic for how to shard different embedding table types.

This diff enables the latter approach (while keeping the former intact), but users need to specify `ignored_modules` to ignore embedding tables in replicate(). This is similar to FSDP (class based and composable) and DDP today.

Differential Revision: [D44899155](https://our.internmc.facebook.com/intern/diff/D44899155/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98890
Approved by: https://github.com/mrshenli, https://github.com/yhcharles
2023-04-15 01:09:00 +00:00
Yanli Zhao
cfd1b4df94 [Composable] add checking key for check_fqn function (#98961)
Add key checking to the `check_fqn` function.

ghstack-source-id: d856f560f1fc449a316135e3844609d0baaf6d66
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96705

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98961
Approved by: https://github.com/awgu
2023-04-13 03:16:14 +00:00
BowenBao
60a68477a6 Bump black version to 23.1.0 (#96578)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96578
Approved by: https://github.com/ezyang
2023-03-15 06:27:59 +00:00
Xuehai Pan
046e88a291 [BE] [3/3] Rewrite super() calls in test (#94592)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```

Some cases where the rewrite would change the semantics are kept unchanged. E.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94592
Approved by: https://github.com/ezyang, https://github.com/seemethere
2023-02-12 22:20:53 +00:00
Chien-Chin Huang
4b0f1cc1ee [FSDP][optim_state_dict][10/N] Make optim_state_dict and optim_state_dict_to_load public (#92118)
Make optim_state_dict and optim_state_dict_to_load public APIs and consolidate them with state_dict by using the same state_dict_type to decide how to perform the optimizer state_dict save and load.
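
A hedged sketch of the save/load flow with the now-public APIs; the argument order follows the current `torch.distributed.fsdp` docs and may differ slightly from the version in this PR:

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

# Assumes `model` is an FSDP-wrapped module and `optim` its optimizer, with the
# process group already initialized (e.g., via torchrun).
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    osd = FSDP.optim_state_dict(model, optim)  # save path
    # ... checkpoint `osd`, then later on load:
    flattened_osd = FSDP.optim_state_dict_to_load(model, optim, osd)
    optim.load_state_dict(flattened_osd)
```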

Differential Revision: [D42488022](https://our.internmc.facebook.com/intern/diff/D42488022/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92118
Approved by: https://github.com/rohan-varma
2023-02-02 08:04:20 +00:00
Andrew Gu
3e4d0e8d82 [Reland][FSDP] Do not clean FQNs for use_orig_params=True (#92662)
The last PR (https://github.com/pytorch/pytorch/pull/91767/) had a land race relating to `_NamedOptimizer` + FSDP and got reverted. This is a re-land.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92662
Approved by: https://github.com/rohan-varma
2023-01-30 16:07:44 +00:00
Andrew Gu
f659452009 [FSDP][1/N] Split fully_shard unit tests (#92296)
This PR splits `test_fully_shard.py` into `fully_shard/test_fully_shard<...>.py`. This should help improve readability and avoid some future rebase conflicts.

The only other real change is resolving a `TODO` for using `run_subtests` in the model checkpointing unit tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92296
Approved by: https://github.com/mrshenli
2023-01-20 02:02:59 +00:00
Andrew Gu
0d4bbd1996 [Lint] Add FSDP/composable API files to ufmt include (#90873)
This PR adds FSDP and composable API files to `.lintrunner.toml` so that (1) lintrunner enforces that those files are formatted and (2) `lintrunner f` formats those files for you.

There are two requirements here (see https://github.com/pytorch/pytorch/wiki/lintrunner for details):
1. Install lintrunner:
```
pip install lintrunner
lintrunner init
```
2. Run `lintrunner f` before you finalize your PR; this is now enforced by CI after this PR.

The code changes in this PR outside of `.lintrunner.toml` are the result of `lintrunner f`.

---

I only plan to land this PR if all of the composable API developers agree that this is something that makes sense and is not too intrusive to the workflow.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90873
Approved by: https://github.com/yhcharles, https://github.com/mrshenli, https://github.com/rohan-varma
2023-01-18 05:33:34 +00:00
PyTorch MergeBot
88942a3199 Revert "[FSDP] Do not clean FQNs even for use_orig_params=True (#91767)"
This reverts commit d6f3265e1a.

Reverted https://github.com/pytorch/pytorch/pull/91767 on behalf of https://github.com/malfet due to Looks like it broke `test_compatible_with_named_optimizer` distributed tests, see d6f3265e1a
2023-01-17 20:04:52 +00:00
Andrew Gu
d6f3265e1a [FSDP] Do not clean FQNs even for use_orig_params=True (#91767)
Cleaning FQN for `FullyShardedDataParallel(use_orig_params=True)` can cause some discrepancies with respect to the FQN compared to manually looping over `named_modules()` and `named_parameters()` together.

There is no requirement for the FQNs to be clean when using wrapper FSDP + `use_orig_params=True`. We can leave clean FQNs to `fully_shard`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91767
Approved by: https://github.com/zhaojuanmao
2023-01-17 17:41:28 +00:00
Chien-Chin Huang
1439cb0314 [FSDP][optim_state_dict][9/N] Rewrite the all-gather flow of optimizer state to support older GPUs (#91343)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91343
Approved by: https://github.com/rohan-varma
2023-01-17 17:21:19 +00:00
PyTorch MergeBot
7bdcf6d4f0 Revert "[FSDP] Do not clean FQNs even for use_orig_params=True (#91767)"
This reverts commit a383789f4d.

Reverted https://github.com/pytorch/pytorch/pull/91767 on behalf of https://github.com/huydhn due to This breaks inductor_distributed workflow a383789f4d
2023-01-12 19:07:50 +00:00
Andrew Gu
a383789f4d [FSDP] Do not clean FQNs even for use_orig_params=True (#91767)
Cleaning FQN for `FullyShardedDataParallel(use_orig_params=True)` can cause some discrepancies with respect to the FQN compared to manually looping over `named_modules()` and `named_parameters()` together.

There is no requirement for the FQNs to be clean when using wrapper FSDP + `use_orig_params=True`. We can leave clean FQNs to `fully_shard`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91767
Approved by: https://github.com/zhaojuanmao
2023-01-12 15:14:14 +00:00
Chien-Chin Huang
0e8565d1d5 [FSDP][optim_state_dict][8/N] Enable fully_shard optim state_dict save and load (#91234)
**What does this PR do?**
This PR refactors `_optim_utils.py` to use `_FSDPState` instead of the `FullyShardedDataParallel` class. This change enables optim state_dict support for `fully_shard`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91234
Approved by: https://github.com/rohan-varma
2022-12-30 06:56:44 +00:00
Yanli Zhao
9b144ddbe4 Make input casting in root module only in default (#91365)
Make input casting happen only in the root module by default, while allowing different mixed precision settings to be set for different submodules.
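
A hedged illustration of the resulting behavior using the wrapper API (the `cast_forward_inputs` field and its defaults are taken from the current `MixedPrecision` dataclass and may postdate this PR):

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

# Assumes torch.distributed is already initialized (e.g., via torchrun).
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

# Submodule opts in to casting its own forward inputs to fp16.
model[1] = FSDP(
    model[1],
    mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
)
# Root uses a different mixed precision config; by default only the root
# casts the top-level forward inputs.
model = FSDP(model, mixed_precision=MixedPrecision(param_dtype=torch.bfloat16))
```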
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91365
Approved by: https://github.com/awgu
2022-12-29 03:20:32 +00:00
Chien-Chin Huang
d08e3d2304 [Composable API] Apply ufmt to _composable and the corresponding test folders (#91255)
This PR applies ufmt to format `_composable`-related code. This is a request from https://github.com/pytorch/pytorch/pull/91234 to separate the formatting changes into a new PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91255
Approved by: https://github.com/awgu
2022-12-23 16:08:27 +00:00
Shen Li
a0554261a1 Restore RNG states for composable reentrant activation checkpointing (#91265)
This allows ops like randperm to behave the same during re-computation.

Differential Revision: [D42196758](https://our.internmc.facebook.com/intern/diff/D42196758/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91265
Approved by: https://github.com/awgu
2022-12-22 03:15:55 +00:00
Andrew Gu
aec09eeb3a [FSDP][7/N] Support replicate in fully_shard (#91044)
This PR supports nesting `replicate` in `fully_shard`.
- The PR achieves this by treating `replicate`-annotated modules as ignored modules. This means that all submodules in the `replicate`-annotated module's subtree are ignored, including nested `fully_shard`-annotated modules, which is the desired behavior.

---

This PR reworks some tree traversal.

One end goal is for `state._handles` to follow the same order for both the wrapper and composable paths. This implies that `_get_fsdp_handles()` returns the same value for both paths.
- The helper function `_get_fully_sharded_module_to_states()` now follows a left-to-right DFS from each fully sharded module instead of a BFS. The left-to-right DFS follows `.modules()` order.
- The composable auto "wrap" initialization function `_init_param_handles_from_module()` follows the reverse left-to-right DFS order. As noted in the code comments, this initialization order is a valid reverse topological sort, but it differs from the wrapper path. This is the _only_ difference with respect to initialization order through the entire process.
```
mod: Module(
    submod1: Submodule()
    submod2: Submodule(
        subsubmod: Subsubmodule(),
    ),
)
```
For left-to-right DFS, the order is `mod`, `submod1`, `submod2`, `subsubmod`. (For context, right-to-left DFS would be `mod`, `submod2`, `subsubmod`, `submod1`. In other words, the left-to-right vs. right-to-left corresponds to `.children()` vs. `reversed(.children())` respectively.) Then, reverse left-to-right DFS is `subsubmod`, `submod2`, `submod1`, `mod`, which is a valid initialization order. However, the wrapper auto wrap initialization order would be `submod1`, `subsubmod`, `submod2`, `mod` since it directly follows a left-to-right DFS and initializes as a part of the recursive DFS logic.
- At the end of `_init_param_handles_from_module()`, we reverse the newly populated `state._handles`, so this is the reverse reverse left-to-right DFS order, which is equivalent to the left-to-right DFS order. Thus, `state._handles` has the same order for both paths.

Another goal is for `_get_fsdp_states()` to not traverse into any submodule that is annotated with an API that is not compatible with `fully_shard` (e.g. `replicate`). To achieve this while preserving that `_get_fsdp_states()` follows `.modules()` order, we again use a left-to-right DFS.

The reason the DFSs may look strange is because I implemented them non-recursively, which requires a stack.
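
For reference, a minimal sketch of the non-recursive left-to-right DFS described above (an illustration matching `.modules()` order, not the actual `_traversal_utils` code):

```python
from typing import List

import torch.nn as nn


def left_to_right_dfs(root: nn.Module) -> List[nn.Module]:
    order: List[nn.Module] = []
    stack: List[nn.Module] = [root]
    while stack:
        module = stack.pop()
        order.append(module)
        # Push children in reverse so the leftmost child is popped (visited) first.
        stack.extend(reversed(list(module.children())))
    return order


mod = nn.Sequential(nn.Linear(2, 2), nn.Sequential(nn.Linear(2, 2)))
assert left_to_right_dfs(mod) == list(mod.modules())
```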

- `test_get_fully_sharded_module_to_states()` in `test_utils.py` checks the traversal order of `_get_fully_sharded_module_to_states()`.
- `test_policy()` in `test_fully_shard.py` checks the traversal order returned by `_get_fsdp_handles()`.

---

Due to a circular dependency issue, we must move the graph/tree traversal helpers to their own file `_traversal_utils.py`, and any usages must import the entire file like `import torch.distributed.fsdp._traversal_utils as traversal_utils` instead of `from torch.distributed.fsdp._traversal_utils import ...`.

The cycle comes from the fact that the traversals require `_composable()`, which requires `_get_registry()` from `composable/contract.py`, which when imported, imports `composable/fully_shard.py`, which requires the traversals.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91044
Approved by: https://github.com/mrshenli
2022-12-20 16:49:18 +00:00
Andrew Gu
32fde53713 [FSDP][5/N] Add manual "wrapping" support for fully_shard (#90874)
This PR adds manual "wrapping" support for `fully_shard`. For example, for
```
fully_shard(mod.sub)
fully_shard(mod)
```
`mod.sub` and `mod` will share the same FSDP data structures.

To have parity with wrapper FSDP, this PR only checks support for the case where each manual application of `fully_shard` passes `policy=None`. Hybrid auto / manual wrapping is not in scope for this PR since it is not supported for wrapper FSDP either. I can follow up to either add support properly or raise an error early.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90874
Approved by: https://github.com/mrshenli
2022-12-20 16:49:15 +00:00
Rohan Varma
7330eabe36 fully_shard load state_dict (#90945)
Ensures that load_state_dict for fully_shard works:
- Don't add back the FSDP prefix
- Small fix to ensure the mixed precision check for buffers works

Follow ups:
- state_dict_type does not work, blocking rank0_only and CPU offload as well as other state dict implementations
- No testing when wrapped with AC, using mixed precision, integration with distributed checkpoint, etc.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90945
Approved by: https://github.com/awgu
2022-12-20 07:26:43 +00:00
Charlie Yan
a1a2f548a9 [Composable API] Enable composable fully_shard submodules in replicate parent module (#90711)
To make sure `fully_shard` and `replicate` can work together, we need to check for each other in the implementation. This change adds the check in `replicate()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90711
Approved by: https://github.com/mrshenli
2022-12-17 09:28:38 +00:00
Rohan Varma
b92975a6f3 replicate state_dict tests (#90868)
Simple tests for replicate() state_dict. Ensuring that composition with
FSDP works will come as a follow-up.

Differential Revision: [D42048131](https://our.internmc.facebook.com/intern/diff/D42048131/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90868
Approved by: https://github.com/awgu
2022-12-15 14:53:24 +00:00
Chien-Chin Huang
d52f121dba [Composable API]Common _State parent class for composable and wrapper FSDP (#89147)
**Why this PR?**

For the composable APIs implementation, sometimes the internal APIs may not have the application (FSDP, DDP) root module but only the local module. One example is the state_dict/optimizer_state_dict implementation of FSDP. These APIs are designed to start with the root module of the model. It is tricky for these APIs to tell whether a random submodule is managed by either DDP or FSDP.

It will be useful to have APIs like:
`_get_module_state(module)`: return the composable state if this module is managed by composable API.
`_get_module_fsdp_state(module)`: return the FSDP state if this module is managed by FSDP.

**What does this PR propose?**
1. Move `_State` out of the `_composable` module so that `FullyShardedDataParallel` can inherit from it.
2. A global `_module_state_mapping: Dict[nn.Module, _State]` that keeps the mapping of all submodules (not just root module) to the state.
3. Create `_get_module_state(module)` to look up `_module_state_mapping`.
4. Create `_get_module_fsdp_state(module)` that uses `_get_module_state(module)` to get the state then verifies if the state is `_FSDPState`.
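
A hedged sketch of the proposed lookup helpers; the `_State` and `_FSDPState` classes below are placeholders for illustration:

```python
from typing import Dict, Optional

import torch.nn as nn


class _State:  # placeholder for the composable state base class
    pass


class _FSDPState(_State):  # placeholder for FSDP's state subclass
    pass


# Maps every managed submodule (not just the root) to its composable state.
_module_state_mapping: Dict[nn.Module, _State] = {}


def _get_module_state(module: nn.Module) -> Optional[_State]:
    """Return the composable state if this module is managed by a composable API."""
    return _module_state_mapping.get(module)


def _get_module_fsdp_state(module: nn.Module) -> Optional[_FSDPState]:
    """Return the FSDP state if this module is managed by FSDP."""
    state = _get_module_state(module)
    return state if isinstance(state, _FSDPState) else None
```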

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89147
Approved by: https://github.com/awgu
2022-12-13 23:58:01 +00:00
Andrew Gu
b3d49c2fb8 [FSDP][1/N] fully_shard state dict (#90767)
Co-authored with @rohan-varma.

**Overview**
This adds preliminary `state_dict()` support for `fully_shard`.
- The only explicit branching between composable and wrapper code paths happens in the state dict hook registration, which is inevitable.
- We introduce a `_comm_module_prefix` to match the FQNs between the two code paths. This is needed since for composable, the FQNs are prefixed from the local FSDP root, whereas for state dict purposes, we want them to be prefixed from the comm. module. Thus, we need this `_comm_module_prefix` to be stripped during state dict.
    - In my understanding, the alternative of not using the `prefix` argument in `state_dict()` does not support the case where `fully_shard` is applied to a submodule (i.e., not the global root module), since we still need _part_ of `prefix` then.

**Follow-Ups**
- We can retire the `functools.partial` usage once @fegin's PR lands.
- We should add more thorough testing (e.g. sharded state dict, save and load together etc.).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90767
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2022-12-13 20:05:40 +00:00