Commit Graph

85 Commits

Author SHA1 Message Date
Nikita Shulga
634659e262 Update mypy to 1.4.1 (#91983)
Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export. deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  -
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91983
Approved by: https://github.com/kit1980, https://github.com/ZainRizvi, https://github.com/huydhn, https://github.com/thiagocrepaldi, https://github.com/aaronenyeshi
2023-07-13 16:30:36 +00:00
Rohan Varma
242fc29c96 [FSDP] Refactor optimizer in backward (#104813)
1) Use zero_grad(set_to_none=True) to set grad to None, 2) call
prepare_grad_for_optim() before call to .step, 3) use
_reset_flat_param_grad_info to set flat param gradient back to None. These
changes should just be refactors and equivalent to how gradient memory was
managed  before.

Differential Revision: [D47310761](https://our.internmc.facebook.com/intern/diff/D47310761/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104813
Approved by: https://github.com/awgu
2023-07-13 06:42:53 +00:00
Rohan Varma
f2eed129c4 FSDP optimizer overlap (#98667)
constraints:

1. No support for gradient accumulation
2. CPU offload runs step() on CPU. In future PRs ideally we'd run this on GPU.
3. When CPU offload + optimizer overlap, we have to copy the flat_param grad to CPU with non_blocking=False, otherwise step() might run on invalid data.
4. Step is waited on in post backward final cb, when in theory it can wait until the next forward.

Differential Revision: [D44809582](https://our.internmc.facebook.com/intern/diff/D44809582/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98667
Approved by: https://github.com/awgu, https://github.com/fegin
2023-07-13 06:42:53 +00:00
Andrew Gu
954bae8e53 [FSDP][Easy] Rename streams; add back stream sharing test (#104966)
Purely out of preference, this PR renames the streams to `_unshard_stream` instead of `_streams_unshard` etc. since the former reads more naturally. The PR also removes some duplicated comments and adds back a unit test that streams are shared.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104966
Approved by: https://github.com/rohan-varma
2023-07-13 00:24:41 +00:00
Iris
4f8ba6f8f6 [DeviceMesh]Add validate mesh flag to DeviceMesh (#104807)
When creating DeviceMesh, _init_process_group() would validate that all calling ranks pass in the same `mesh` argument. In FSDP, we are currently creating the DeviceMesh based on the pg of the root state so the mesh will always be valid. Adding the flag to DeviceMesh, so we can skip the all_gather_tensor of the validation during construction time.

_validate_mesh is default to True, but we manually flip it to False when initializing device mesh in FSDP's  _runtime_utils.py.

Will modify skipping pg creation if existed for both 1D and 2D cases and then delete _init_process_groups flag in a follow up PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104807
Approved by: https://github.com/wanchaol
2023-07-12 23:42:13 +00:00
Andrew Gu
63d1fb21f5 [FSDP] Default limit_all_gathers=True (#104900)
This PR defaults to `limit_all_gathers=True`.

I included a `record_function()` for the rate limiter synchronization to help with user confusion on the gap in the pre-forward:
<img width="874" alt="Screenshot 2023-07-10 at 3 28 18 PM" src="https://github.com/pytorch/pytorch/assets/31054793/61f55e0e-58d7-4162-9395-bea06d3e8d8a">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104900
Approved by: https://github.com/fegin
2023-07-11 01:04:29 +00:00
Andrew Gu
d9be0366d3 [FSDP][3/N] Unify fully_shard auto wrap (#104408)
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules.

This includes several important fixes:
- We should register the pre/post-forward hooks on the module regardless of it has managed parameters.
- We can permit `_module_handles` to return `[]` in the composable path (for when the module has no managed parameters).
- We should unify the paths for `_get_buffers_and_dtypes_for_computation()` (previously, composable path was buggy in some cases).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104408
Approved by: https://github.com/rohan-varma
2023-07-08 12:40:12 +00:00
Rohan Varma
0bf39d5663 [FSDP] Option for eval in fp32/bf16 (#104682)
In https://github.com/pytorch/pytorch/pull/97645 and some follow up diffs, we made FSDP run in full precision in eval mode, even if mixed precision was specified.

However, this is probably not the best idea and we should provide a flag for users to have control over this a bit more. Adding an env var FSDP_FULL_PREC_IN_EVAL and defaulting it to off, users who want to run eval in fp32 can toggle this before wrapping model in FSDP:

os.environ["FSDP_FULL_PREC_IN_EVAL"] = "1"

Verified that unittests, APS workflow, TNT workloads can run eval appropriately with this change.

Differential Revision: [D47246556](https://our.internmc.facebook.com/intern/diff/D47246556/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104682
Approved by: https://github.com/awgu
2023-07-07 08:14:23 +00:00
Iris
434fcffa21 [6/n][FSDP] Update _sharded_pre_load_state_dict_hook to use DTensor when use_dtensor=True in ShardedStateDictConfig (#104087)
This allows us use use_dtensor=True for ShardedStateDictConfig() before calling model.load_state_dict(). It only works for offload_to_cpu=False now.

Next PR will make use_dtensor=True work with offload_to_cpu=True for load_state_dict().

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104087
Approved by: https://github.com/fegin
2023-07-06 05:36:19 +00:00
Wanchao Liang
8457703e8d lazy init device mesh in fsdp (#104447)
since fsdp state is lazy init, we also need to lazy init device mesh
otherwise devicemesh allgather check would trigger some mismatch in
allgather counts in fsdp tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104447
Approved by: https://github.com/wconstab
2023-06-30 04:40:16 +00:00
Rohan Varma
c866446d6c [FSDP] Check module.training for _root_cast_forward_inputs (#104223)
We might erroneously cast forward inputs for the root if it doesn't
manage any handles (FSDP parameters). As a fix, pass in the module and check
its training attribute to ensure we don't cast inputs in eval mode.

Differential Revision: [D47041673](https://our.internmc.facebook.com/intern/diff/D47041673/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104223
Approved by: https://github.com/fegin
2023-06-28 16:38:01 +00:00
Andrew Gu
9db8ad7f1d [FSDP] Support unfreezing params for reshard-only hook (#104186)
This fixes https://github.com/pytorch/pytorch/issues/104148 (unfreezing parameters after `n` steps).

- This fixes a bug where we did not delete the post-backward hook state properly for the `requires_grad=False` case.
- This makes the `already_resharded` correct for `SHARD_GRAD_OP`.
- This generalizes `_clear_grads_if_needed()` to `_reset_flat_param_grad_info_if_needed()` to additionally include propagating the original parameters' `requires_grad` to the flat parameter.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104186
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2023-06-28 11:04:57 +00:00
shibo19
c2095af3f8 make funcs argument type from torch.cuda.stream as torch.Stream (#104156)
Fixes #ISSUE_NUMBER
1. we want to support fsdp for custom device, so we make funcs argument type from torch.cuda.stream as torch.Stream
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104156
Approved by: https://github.com/awgu
2023-06-28 06:02:56 +00:00
Michael Voznesensky
02f28de408 [dynamo x fsdp] Simplify stream logic handling (#103902)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103902
Approved by: https://github.com/awgu
2023-06-21 01:34:19 +00:00
Andrew Gu
48056b168f [FSDP] Reshard frozen params in backward (#101982)
This PR makes a first attempt at improving FSDP's fine-tuning support by adding hooks to reshard frozen parameters in the backward pass.
- Without this, frozen parameters involved in gradient computation are kept as unsharded through the entire backward pass.
- The approach is to register a multi-grad ~~post~~-hook on the _input_ activations to the FSDP module, where the hook performs the resharding after all gradients for the FSDP module must have been computed (meaning that we are safe to reshard).

~~This PR relies on adding a "multi-grad post-hook" that differs from the existing "multi-grad hook" from `register_multi_grad_hook()`. I find that with `register_multi_grad_hook()`, sometimes the unit test counting the number of times `_post_backward_reshard()` is called fails (due to it not being called).~~ This was resolved in https://github.com/pytorch/pytorch/pull/102859.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101982
Approved by: https://github.com/rohan-varma
2023-06-08 21:12:45 +00:00
Iris
d5142c52d3 [FSDP]Remove dim_group from device_mesh init (#103218)
1) remove dim_group
2) don't init device_mesh if not using default_pg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103218
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2023-06-08 03:29:19 +00:00
Iris
a02a58d862 [FSDP][1/N]Add device_mesh to FSDPstate (#102317) (#102551)
This PR creates a device_mesh and share it across all FSDP state. The device_mesh will later be used to test out dtensor state_dict (1d device_mesh).
Approved by: https://github.com/awgu

Add device mesh to fsdp state
skip dist.get_world_size(pg) != dist.get_world_size()
address test_fake_pg.py test failure
fix test_fake_py.py failure

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102551
Approved by: https://github.com/fegin
2023-06-07 04:14:00 +00:00
Rohan Varma
88ce6215f5 [FSDP/DDP] Unify _cast_forward_inputs (#102680)
Closes https://github.com/pytorch/pytorch/issues/96380

Differential Revision: [D46342814](https://our.internmc.facebook.com/intern/diff/D46342814/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102680
Approved by: https://github.com/awgu
2023-06-04 18:31:21 +00:00
Rohan Varma
e66c498d2d Log modules FSDP hooks fire for (#102508)
Under torch_distributed_debug >= INFO and use_orig_params=True, log post backward hook firing to debug things like FSDP + AC integration.

Differential Revision: [D46172916](https://our.internmc.facebook.com/intern/diff/D46172916/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102508
Approved by: https://github.com/awgu, https://github.com/fegin
2023-06-04 18:31:12 +00:00
PyTorch MergeBot
81ac076bce Revert "[FSDP]Add device_mesh to FSDPstate (#102317)"
This reverts commit 4c584acc5d.

Reverted https://github.com/pytorch/pytorch/pull/102317 on behalf of https://github.com/malfet due to Broke test_fake_pg, see https://github.com/pytorch/pytorch/actions/runs/5100633726/jobs/9173277369  ([comment](https://github.com/pytorch/pytorch/pull/102317#issuecomment-1566129496))
2023-05-28 12:53:28 +00:00
Iris
4c584acc5d [FSDP]Add device_mesh to FSDPstate (#102317)
This PR creates a device_mesh and share it across all FSDP state. The device_mesh will later be used to test out dtensor state_dict (1d device_mesh).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102317
Approved by: https://github.com/awgu
2023-05-27 20:25:30 +00:00
Yanli Zhao
dc9c79d3cf Allow each fully_shard unit to cast foward inputs for mixed precision config (#100290)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100290
Approved by: https://github.com/rohan-varma
2023-05-02 00:03:48 +00:00
medivh-xp
859e82a7a9 Making fsdp device-agnostic for custom-backend which implement cuda-semantics (#99024)
Custom backend implementation based on privateuse1 with semantics identical to CUDA (CUDA is so popular), named for example 'my_device', and registered as the same module name torch.my_device.

This PR aims to satisfy the constraints of such a backend, which can be directly integrated into the current FSDP implementation.

The main issues addressed are:

#### 1. Device decision for FSDP wrapping of Modules without Parameters

Users typically organize FSDP code as follows:
```python
m = Module().to('my_device:0')
fsdp_m = FSDP(m)
```
or like this:
```python
m = Module()
fsdp_m = FSDP(m, device_id=torch.device('my_device', 0))
```
If the model has Parameters, everything works fine because FSDP will prioritize the device where the Parameters are located. However, for Modules without Parameters, the to() call has no side effects, and FSDP will assume the current CUDA device, which prevents the use of devices other than the current CUDA device for Modules without Parameters. Therefore, when FSDP is called with a device_id argument, this configuration takes top priority.

#### 2. Abstraction of a cuda-like device

Now, in addition to compute_device, _FSDPState includes a device_handler member. In fact, this device_handler is now just a reference to either torch.cuda or torch.my_device. From now on, code that works based on _FSDPState should use state.device_handler to operate streams create, wait or sync, just like using torch.cuda previously.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99024
Approved by: https://github.com/awgu
2023-04-27 04:13:28 +00:00
Daniel Dale
363d530035 Fix decision logic for should_cast_forward_inputs in _root_pre_forward() and _pre_forward() (#99546)
Fixes #99545

There is currently no topological constraint dictating FSDP instances own ``FlatParamHandle`` s directly. If all parameters are managed by descendant FSDP instances leaving an FSDP instance with no direct ``state._handles``, the  ``should_cast_forward_inputs`` decisions below in both ``_root_pre_forward()`` and ``_pre_forward()`` respectively can return incorrect decisions [^1].

For [``_root_pre_forward()``](436edc5ac3/torch/distributed/fsdp/_runtime_utils.py (L514)):

436edc5ac3/torch/distributed/fsdp/_runtime_utils.py (L602-L604)

For [``_pre_forward``](436edc5ac3/torch/distributed/fsdp/_runtime_utils.py (L384)):

436edc5ac3/torch/distributed/fsdp/_runtime_utils.py (L420-L422)

See the [related issue](https://github.com/pytorch/pytorch/issues/99545) for reproduction.

### Remediation

In this PR, I amend the two decision statements referenced above (in both `_root_pre_forward()` and `_pre_forward()`) to account for FSDP instances without direct handles:
```python
should_cast_forward_inputs = len(state._handles) > 0 and all(
    not handle._force_full_precision for handle in state._handles
)
```

If one configures ``MixedPrecision`` in the example above with ``cast_forward_inputs=True`` and the ``should_cast_forward_inputs`` adjustment above, FSDP returns to the expected behavior and produces no error.

Though the check is the same in both ``_root_pre_forward()`` and ``_pre_forward()`` and hence could be refactored into a separate function, I figured it may make sense to retain separate statements to preserve the ability for root-specific behavior in the future. Whichever approach the team prefers I can update this PR with.

### Implementation considerations and questions:

1. Rather than write a test that would arguably have a poor utility/resource usage profile, I have not added any tests associated with this PR. The new decision logic is exercised by all existing tests (which continue to pass after this PR of course) so I think the utility of new tests is fairly modest. Let me know if you think new tests should be added and I'm happy to do so.
2. As discussed above, the decision statement shared among ``_pre_forward()`` and ``_root_pre_forward()`` could be factored out into a separate function. Given the simplicity of the statement and to retain current flexibility for root-specific decisions it might not be worth the refactor so I haven't done it yet. Let me know if you'd like me to do so.
3. The note below could be updated to indicate the utility of setting ``cast_forward_inputs=True`` for the situations addressed with this PR but I haven't done so since I'm not sure it's worth complicating the current usage guidance. I'd be happy to add verbiage describing the use case if the team wants it.
cde35b4069/torch/distributed/fsdp/api.py (L175-L181)

Thanks again to the PyTorch distributed team for your immensely valuable contributions to the open-source ML community!

[^1]: Though one could keep the existing decision logic and impose a new topological constraint requiring all FSDP instances have direct `_handles`, I think retaining the current wrapping flexibility is both convenient and useful enough (e.g. programmatic wrapping of modules that may or may not already have all parameters handled by descendant FSDP instances) to update the decision logic as discussed here instead.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99546
Approved by: https://github.com/awgu
2023-04-21 22:49:50 +00:00
Rohan Varma
d8b09b0139 [FSDP] Full precision in eval mode (#97645)
If model.eval() is true, then runs the model in full precision.

Changes:
- Changed _force_full_precision to check self.is_training
- Check for _force_full_precision when casting gradients to reduced dtype
- Small change when accessing _full_prec_param_padded
- tests for class based and fully_shard APIs

Differential Revision: [D43933690](https://our.internmc.facebook.com/intern/diff/D43933690/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97645
Approved by: https://github.com/awgu
2023-04-13 18:38:22 +00:00
feifan
d95ee64b58 ddp forward support custom backend. (#98283)
Currently DDP only considers CUDA backend,DDP forward will transfer tensor to CUDA. We want ddp to run on custom backend.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98283
Approved by: https://github.com/ezyang
2023-04-09 01:30:42 +00:00
Rohan Varma
428c531d00 [FSDP] records for composable (#98428)
Add some function recording since composable API does record a FSDP.forward

Differential Revision: [D44715137](https://our.internmc.facebook.com/intern/diff/D44715137/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98428
Approved by: https://github.com/awgu
2023-04-06 06:40:48 +00:00
Andrew Gu
10271a60a8 [FSDP] Skip _use_sharded_views() for SHARD_GRAD_OP (#98250)
This PR has `SHARD_GRAD_OP` (and `_HYBRID_SHARD_ZERO2`) skip `_use_sharded_views()` in the post-forward reshard since the strategy does not free the unsharded flat parameter and can preserve the unsharded views. This saves nontrivial CPU overhead both in the post-forward reshard (`_use_sharded_views()`) and the pre-backward unshard (`_use_unsharded_views()`).

<details>
<summary>(Before) Pre-backward hook: 4.356 ms</summary>

<img width="812" alt="Screenshot 2023-04-03 at 6 32 19 PM" src="https://user-images.githubusercontent.com/31054793/229641309-778cf1f9-4b5b-42ec-b2d8-0a1e6e7ce330.png">

</details>

<details>
<summary>(After) Pre-backward hook: 1.044 ms</summary>

![Screenshot 2023-04-04 at 9 05 53 AM](https://user-images.githubusercontent.com/31054793/229800917-9580ce6b-3721-469a-9212-f0cbfd8cbb52.png)

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98250
Approved by: https://github.com/rohan-varma
2023-04-04 17:07:28 +00:00
Andrew Gu
0b31f87c18 [FSDP] Use correct handle training state when prefetching (#98249)
This PR ensures that when prefetching a `FlatParamHandle.unshard()`, we temporarily set the `FlatParamHandle._training_state` to the expected training state as if the `unshard()` were not prefetched since the `as_params` argument to `_use_unsharded_views()` depends on the handle's training state.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98249
Approved by: https://github.com/rohan-varma
2023-04-04 13:34:02 +00:00
Andrew Gu
fb7b398479 [FSDP] Do not _unshard if already prefetched (#97981)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97981
Approved by: https://github.com/fegin
2023-03-31 18:47:03 +00:00
Andrew Gu
195b92ab01 [FSDP][Easy] Minor cleanups to _runtime_utils.py (#97980)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97980
Approved by: https://github.com/H-Huang
2023-03-31 18:47:03 +00:00
Kazuaki Ishizaki
35fd5c548e Fix typos under torch/distributed directory (#95638)
This PR fixes typos in comments and messages of `.py` files under torch/distributed directory

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95638
Approved by: https://github.com/usamah1, https://github.com/H-Huang, https://github.com/kit1980
2023-03-27 21:13:44 +00:00
Rohan Varma
308a58ebca [FSDP] Rename to _get_orig_buffer_dtypes (#96790)
Reland this PR

Differential Revision: [D44078430](https://our.internmc.facebook.com/intern/diff/D44078430/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96790
Approved by: https://github.com/awgu
2023-03-16 00:31:29 +00:00
Xuehai Pan
80e8e41ca7 Fix type hint for torch.Tensor.grad_fn (#96804)
Fix type hint for `torch.Tensor.grad_fn`, which can be a `torch.autograd.graph.Node` or `None`.

This is a regression in `torch` 2.0. It makes `mypy` failure in downstream projects.

Ref:

- https://github.com/pytorch/pytorch/issues/94937#issuecomment-1469344993
- metaopt/torchopt#149
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96804
Approved by: https://github.com/Skylion007
2023-03-15 17:14:05 +00:00
Andrew Gu
6c30dc6cee [FSDP] Save _all_handles; _all_fsdp_states to root (#95465)
- The previous PR addressed one tree traversal in `_root_pre_forward()` but not the main one from `_get_fsdp_handles()` that runs for all settings.
- This PR saves `_all_handles` to cache `_get_fsdp_handles()` and `_all_fsdp_states` to cache `_get_fsdp_states()` (renamed from `_fsdp_states` compared to last PR) on the root state.
- This PR introduces a dummy `_RootFSDPState` class that inherits from `_FSDPState` to be used only for type checking since some attributes are only defined for root states.
    - I found this approach to be better than adding `_p_assert(state.root_only_attr is not None, ...)` upon each usage of `root_only_attr`.
    - This hopefully also helps readers to quickly see which attributes are defined only on root states.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95465
Approved by: https://github.com/fduwjj
2023-02-26 13:59:53 +00:00
Andrew Gu
9c45f47bbe [FSDP] Save _fsdp_states on root (#95343)
This saves an attribute `_fsdp_states: Optional[_FSDPState]`. For root, it is populated with all `_FSDPState`s in the root's tree. For non-root, it is `None`.

This is used to avoid doing the tree traversal during `_root_pre_forward()` when `forward_prefetch=True`.

Differential Revision: [D43536895](https://our.internmc.facebook.com/intern/diff/D43536895)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95343
Approved by: https://github.com/fegin
2023-02-23 21:18:05 +00:00
Andrew Gu
78175ceeab [FSDP][Docs] Re-add why reg. post-bwd hook on 1st forward (#95326)
This PR adds back some explanation for why we have the heuristic to only register the post-backward hook on the first forward in the case of multiple forwards.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95326
Approved by: https://github.com/fegin
2023-02-23 01:50:25 +00:00
Rohan Varma
c43e88665a [Resubmit] helpers to torch.dist.utils (#95025)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95025
Approved by: https://github.com/fegin
2023-02-17 18:24:20 +00:00
Aaron Gokaslan
67d9790985 [BE] Apply almost all remaining flake8-comprehension checks (#94676)
Applies the remaining flake8-comprehension fixes and checks. This changes replace all remaining unnecessary generator expressions with list/dict/set comprehensions which are more succinct, performant, and better supported by our torch.jit compiler. It also removes useless generators such as 'set(a for a in b)`, resolving it into just the set call.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
2023-02-12 01:01:25 +00:00
Andrew Gu
10990734ce [FSDP][2/N] _summon_full_params -> _unshard_params (#92297)
**Overview**
This PR stack will add support for unsharding FSDP's sharded parameters for `fully_shard`. This PR takes the first step by doing some internal refactoring.
- The existing API for wrapper FSDP is the static method `summon_full_params()`, which calls into the helper `_summon_full_params()`.
- This PR refactors:
    - `summon_full_params()` core logic to `_unshard_params()`
    - `_summon_full_params()` to `_unshard_params_recurse()`, which has a `recurse: bool` argument
    - Previous `_unshard_params()` to `_unshard_fsdp_state_params()`, which applies to a single FSDP state

**Details**
- This PR introduces `_get_fsdp_states_with_modules()` and `_get_root_fsdp_states_with_modules()`, which additionally return the modules along with the FSDP states. The modules are needed for handling `FlatParameter` registration.
    - We may be able to remove this if we clean up the `use_orig_params=True` vs. `False` code paths because for `True`, the `FlatParameter` is not registered, meaning that it does not need to be de-registered.
    - Since `fully_shard` requires `use_orig_params=True`, we may not need `_get_fsdp_states_with_modules()` and `_get_root_fsdp_root_modules()`; however, I prefer to make the separation of FSDP state and module explicit for now for clarity.

**Follow-Ups**
- `writeback=True` and `rank0_only=True` raises an error. The previous explanation was:
> is not supported, as model parameter shapes will be different across ranks, and writing to them can lead to inconsistencies across ranks when the context is exited.

I am not exactly sure what the different model parameter shapes refers to. However, I believe that we can support `writeback=True` and `rank0_only=True` by broadcasting the `FlatParameter` from rank 0 in the `finally`, writing back, and freeing. This should not increase the peak memory since rank 0 already holds the unsharded `FlatParameter` in GPU memory before writing back and nonzero ranks do not have any other unsharded `FlatParameter`s in GPU memory.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92297
Approved by: https://github.com/rohan-varma
2023-02-02 15:10:14 +00:00
Andrew Gu
0d4bbd1996 [Lint] Add FSDP/composable API files to ufmt include (#90873)
This PR adds FSDP and composable API files to `.lintrunner.toml` so that (1) lintrunner enforces that those files are formatted and (2) `lintrunner f` formats those files for you.

There are two requirements here (see https://github.com/pytorch/pytorch/wiki/lintrunner for details):
1. Install lintrunner:
```
pip install lintrunner
lintrunner init
```
2. `lintrunner f` before you finalize your PR, which would now be enforced by CI after this PR.

The code changes in this PR outside of `.lintrunner.toml` are the result of `lintrunner f`.

---

I only plan to land this PR if all of the composable API developers agree that this is something that makes sense and is not too intrusive to the workflow.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90873
Approved by: https://github.com/yhcharles, https://github.com/mrshenli, https://github.com/rohan-varma
2023-01-18 05:33:34 +00:00
soulitzer
388b245d54 Expose autograd.graph.Node as an abstract base class (#91475)
This PR:
- registers all of the codegened Nodes to the torch._C._functions module, this is where special nodes like AccumulateGrad are already registered.
- creates a autograd.graph.Node abstract base class that all of the newly registered nodes subclass from. We make the subclassing happen by implementing the ``__subclasshook__`` method
- enables static type checking to work and also enables Sphinx to generate documentation for the Node and its methods
- handles both the custom Function and codegened cases

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91475
Approved by: https://github.com/albanD
2023-01-18 00:20:13 +00:00
Andrew Gu
b0888cce0f [FSDP][BE] Better error msg for incorrect device for training (#92027)
Closes https://github.com/pytorch/pytorch/issues/90541.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92027
Approved by: https://github.com/zhaojuanmao
2023-01-16 02:38:07 +00:00
Yanli Zhao
9b144ddbe4 Make input casting in root module only in default (#91365)
Make input casting in root module only in default, meanwhile allowing to set different mixed precisions for different submodules
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91365
Approved by: https://github.com/awgu
2022-12-29 03:20:32 +00:00
Andrew Gu
aec09eeb3a [FSDP][7/N] Support replicate in fully_shard (#91044)
This PR supports nesting `replicate` in `fully_shard`.
- The PR achieves this by treating `replicate`-annotated modules are ignored modules. This means that all submodules in the `replicate`-annotated module's subtree are ignored, including nested `fully_shard`-annotated modules, which is the desired behavior.

---

This PR reworks some tree traversal.

One end goal is for `state._handles` to follow the same order for both the wrapper and composable paths. This implies that `_get_fsdp_handles()` returns the same value for both paths.
- The helper function `_get_fully_sharded_module_to_states()` now follows a left-to-right DFS from each fully sharded module instead of a BFS. The left-to-right DFS follows `.modules()` order.
- The composable auto "wrap" initialization function `_init_param_handles_from_module()` follows the reverse left-to-right DFS order. As noted in the code comments, this initialization order is a valid reverse topological sort, but it differs from the wrapper path. This is the _only_ difference with respect to initialization order through the entire process.
```
mod: Module(
    submod1: Submodule()
    submod2: Submodule(
        subsubmod: Subsubmodule(),
    ),
)
```
For left-to-right DFS, the order is `mod`, `submod1`, `submod2`, `subsubmod`. (For context, right-to-left DFS would be `mod`, `submod2`, `subsubmod`, `submod1`. In other words, the left-to-right vs. right-to-left corresponds to `.children()` vs. `reversed(.children())` respectively.) Then, reverse left-to-right DFS is `subsubmod`, `submod2`, `submod1`, `mod`, which is a valid initialization order. However, the wrapper auto wrap initialization order would be `submod1`, `subsubmod`, `submod2`, `mod` since it directly follows a left-to-right DFS and initializes as a part of the recursive DFS logic.
- At the end of `_init_param_handles_from_module()`, we reverse the newly populated `state._handles`, so this is the reverse reverse left-to-right DFS order, which is equivalent to the left-to-right DFS order. Thus, `state._handles` has the same order for both paths.

Another goal is for `_get_fsdp_states()` to not traverse into any submodule that is annotated with an API that is not compatible with `fully_shard` (e.g. `replicate`). To achieve this while preserving that `_get_fsdp_states()` follows `.modules()` order, we again use a left-to-right DFS.

The reason the DFSs may look strange is because I implemented them non-recursively, which requires a stack.

- `test_get_fully_sharded_module_to_states()` in `test_utils.py` checks the traversal order of `_get_fully_sharded_module_to_states()`.
- `test_policy()` in `test_fully_shard.py` checks the traversal order returned by `_get_fsdp_handles()`.

---

Due to a circular dependency issue, we must move the graph/tree traversal helpers to their own file `_traversal_utils.py`, and any usages must import the entire file like `import torch.distributed.fsdp._traversal_utils as traversal_utils` instead of `from torch.distributed.fsdp._traversal_utils import ...`.

The cycle comes from the fact that the traversals require `_composable()`, which requires `_get_registry()` from `composable/contract.py`, which when imported, imports `composable/fully_shard.py`, which requires the traversals.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91044
Approved by: https://github.com/mrshenli
2022-12-20 16:49:18 +00:00
Andrew Gu
32fde53713 [FSDP][5/N] Add manual "wrapping" support for fully_shard (#90874)
This PR adds manual "wrapping" support for `fully_shard`. For example, for
```
fully_shard(mod.sub)
fully_shard(mod)
```
`mod.sub` and `mod` will share the same FSDP data structures.

To have parity with wrapper FSDP, this PR only checks support for when each manual application of `fully_shard` passes `policy=None`. Hybrid auto / manual wrapping is not in scope for this PR since it is not supported for wrapper FSDP either. I can follow up to either add support properly or raise and error early.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90874
Approved by: https://github.com/mrshenli
2022-12-20 16:49:15 +00:00
Andrew Gu
da9af9868e [FSDP][4/N] Refactor func to share state/init handle attrs (#90871)
For `limit_all_gathers`, if we do not enforce that they all have the same value, then the entire semantics guaranteed by the `bool` can be violated. It could be as if none of them set that value to be `True`.

For `use_orig_params`, optimizer state dict assumes that the value is the same for all FSDP instances.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90871
Approved by: https://github.com/mrshenli
2022-12-20 16:49:13 +00:00
Yanli Zhao
50ab2b702f move inputs to device on root module only (#91078)
1. No need to move inputs/activations to devices for every nested FSDP instance
2. it also breaks the case when some nested FSDP instances have newly added inputs/activations in the signatures of submodules wrapped by nested FSDP instances, args_tuple[0] and kargs_tuple[0] are not correct to get the inputs/activations for these nested instances

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91078
Approved by: https://github.com/mrshenli, https://github.com/rohan-varma
2022-12-19 17:49:05 +00:00
Andrew Gu
5ea418bf63 [FSDP][3/N] Move fsdp_modules(root_only=True) -> _get_fsdp_root_states() (#90862)
- This PR introduces `_get_fsdp_root_states(state: _FSDPState, module: nn.Module)` to return all states that are FSDP root in the module tree rooted at `module`.
   - This requires passing in both `state` and `module` because it must call `_lazy_init()` to check for root-ness, which requires that signature.
- This PR moves the one internal usage of `FullyShardedDataParallel.fsdp_modules(root_only=True)` to use `_get_fsdp_root_states()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90862
Approved by: https://github.com/rohan-varma
2022-12-16 21:27:27 +00:00
Andrew Gu
8cd1808dbf [FSDP] Introduce "fully sharded module"; remove comm. module (#90933)
This PR removes the "communication module" (comm. module / `comm_module`) concept from the FSDP code base since it causes disproportionate confusion compared to its benefit for now.

Instead, we introduce the term "fully sharded module" as the single concept to unify the wrapper and non-wrapper code paths. The definition is presented in a note at the top of `flat_param.py`. I reproduce it here:

---
We define the **"fully sharded module"** to be the original `nn.Module` that owns a `FlatParamHandle`. It is the *single* module logically responsible for the *single* unshard/reshard pair for the handle's `FlatParameter` for a given forward or backward pass. The fully sharded module should be passed to the `FlatParamHandle` constructor.

For the wrapper code path:
- The `FullyShardedDataParallel` module wrapping the fully sharded module runs the unshard/reshard on behalf of the fully sharded module by overriding `nn.Module.forward`.
- The fully sharded module is exactly the module passed to the `FullyShardedDataParallel` constructor's `module` argument and is saved in `_fsdp_wrapped_module`.

For the non-wrapper code path:
- Hooks registered on the fully sharded module run the unshard/reshard.
- The fully sharded module may either be the direct argument to `fully_shard` or a submodule chosen by the provided wrapping policy.
---

After this PR, `handle.flat_param._fqns`, `_param_infos`, and `_shared_param_infos` all prefix names from the same module, namely the fully sharded module. This should make state dict less confusing.

---
As an example, consider:
```
mod: Module(
  sub1: Submodule(
    subsub1: Subsubmodule(),
    subsub2: Subsubmodule(),
  ),
  sub2: Submodule(
    subsub1: Subsubmodule(),
    subsub2: Subsubmodule(),
  ),
)
```
For wrapper FSDP manual wrap:
```
mod.sub1 = FSDP(mod.sub1)
mod.sub2 = FSDP(mod.sub2)
mod = FSDP(mod)
```
For wrapper FSDP auto wrap:
```
mod = FSDP(mod, auto_wrap_policy=ModuleWrapPolicy({Submodule}))
```
(WIP) For non-wrapper FSDP manual wrap:
```
fully_shard(mod.sub1)
fully_shard(mod.sub2)
fully_shard(mod)
```
For non-wrapper FSDP auto wrap:
```
fully_shard(mod, policy=ModuleWrapPolicy({Submodule}))
```
The fully sharded module **in all cases** are `mod`, `mod.sub1`, `mod.sub2`, and notably, `subsub1` and `subsub2`s are not fully sharded modules.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90933
Approved by: https://github.com/rohan-varma
2022-12-16 18:45:52 +00:00