100 Commits

Author SHA1 Message Date
Michael Voznesensky
d1a99a083f Reland Simplify handle indexing (#105006) (#106357)
This reverts commit a9a3c45649.

This PR changes the following:
- `_ExecOrderData.handle_to_handle_index` -> `FlatParamHandle._handle_index`
- `_ExecOrderData.handles_to_pre_forward_order_index` -> `FlatParamHandle._pre_forward_order_index`
- `_ExecOrderData.handles_to_post_forward_order_index` -> `FlatParamHandle._post_forward_index`
- `_FSDPState._needs_pre_forward_unshard` -> `FlatParamHandle._needs_pre_forward_unshard`
- `_FSDPState._needs_pre_backward_unshard` -> `FlatParamHandle._needs_pre_backward_unshard`
- `_FSDPState._handles_prefetched` -> `FlatParamHandle._prefetched`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106357
Approved by: https://github.com/awgu
2023-08-03 19:17:32 +00:00
Andrew Gu
800287fb56 [FSDP] Optimize away intermediate div_ for HSDP (#106034)
### Background: Gradient Pre-Divide
Consider $N$ data parallel workers. Define $g_i$ to be the $i$ th worker's local unsharded gradient. Data parallel gradient reduction computes $\overline g = \frac{1}{N} \sum_{i \in [N]} g_i$.

$\sum_{i \in [N]} g_i$ increases the magnitude by a factor of $N$, which may overflow for fp16. However, if we pre-divide and compute $\sum_{i \in [N]} \frac{g_i}{N}$, then the $\frac{g_i}{N}$ may underflow. The current solution from Myle for FSDP is to pre-divide by $\sqrt{N}$ and post-divide by $\sqrt{N}$:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{i \in [N]} \frac{g_i}{\sqrt{N}}.$$
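
To make the overflow/underflow trade-off concrete, here is a minimal sketch that simulates fp16 rounding in pure Python through the `struct` module's half-precision `"e"` format. The helper names (`to_fp16`, `reduce_fp16`), the world size, and the gradient values are illustrative assumptions and are not taken from FSDP's code.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round x to IEEE half precision; values too large for fp16 become inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def reduce_fp16(grads, pre: float, post: float) -> float:
    """Compute (1/post) * sum(g / pre), rounding to fp16 after every operation."""
    total = 0.0
    for g in grads:
        total = to_fp16(total + to_fp16(g / pre))
    return to_fp16(total / post)


N = 1024                   # hypothetical number of data parallel workers
root_n = math.sqrt(N)      # 32.0
large = [1024.0] * N       # the plain sum 2**20 exceeds the fp16 maximum (65504)
small = [2.0 ** -16] * N   # g / N = 2**-26 is below the smallest fp16 subnormal

print(reduce_fp16(large, 1.0, float(N)))   # no pre-divide: the running sum overflows
print(reduce_fp16(small, float(N), 1.0))   # full pre-divide: every term underflows
print(reduce_fp16(large, root_n, root_n))  # sqrt(N) split: mean of the large grads
print(reduce_fp16(small, root_n, root_n))  # sqrt(N) split: mean of the small grads
```

With these particular values, only the $\sqrt{N}$ split returns the true mean for both inputs; the other two choices each fail on one of them.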

Now, consider HSDP with $N = S \cdot R$ data parallel workers, sharding over $S$ workers and replicating over $R$ workers. Define $g_{i,j}$ to be the $i \cdot S + j$ th worker's local unsharded gradient (so sharding indexes with $i$ and replication indexes with $j$). The existing implementation computes
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}} } \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{S}},$$
where the $\frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}}$ involves two separate `aten::div_` kernels.

### Revisiting Pre-Divide for HSDP
A minor optimization that we can do is with this intermediate `div_`. There are two options:
1. Compute $\overline{g}$ in the same way as FSDP:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{j \in [R]} \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{N}}.$$
2. Compute $\overline{g}$ still with an intermediate division for rescaling but coalescing the two `div_` kernels into one:
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{N}} } \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{S}}.$$

This PR goes with the 1st approach, prioritizing performance, because (1) it matches the existing FSDP behavior and (2) it avoids a memory-bandwidth-bound `div_` kernel that blocks the all-reduce launch.
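
As a sanity check on the algebra, the sketch below evaluates the existing formula and both options on scalar stand-ins for the gradients and confirms that all three equal the plain mean. The group sizes, the values in `g`, and the helper `inner` are illustrative assumptions; this checks only the arithmetic, not the collectives.

```python
import math

S, R = 4, 2  # hypothetical shard and replica group sizes
N = S * R
# g[i][j] stands in for the scalar gradient g_{i,j}.
g = [[float(i * R + j + 1) for j in range(R)] for i in range(S)]

mean = sum(sum(row) for row in g) / N


def inner(j: int, scale: float) -> float:
    """Sum over the index i, with each term pre-divided by scale."""
    return sum(g[i][j] / scale for i in range(S))


# Existing implementation: two intermediate divisions between the two sums.
existing = sum(
    inner(j, math.sqrt(S)) / math.sqrt(R) / math.sqrt(S) for j in range(R)
) / math.sqrt(R)

# Option 1: pre-divide and post-divide by sqrt(N), no intermediate division.
option_1 = sum(inner(j, math.sqrt(N)) for j in range(R)) / math.sqrt(N)

# Option 2: one coalesced intermediate division by sqrt(N).
option_2 = sum(
    inner(j, math.sqrt(S)) / math.sqrt(N) for j in range(R)
) / math.sqrt(R)

print(mean, existing, option_1, option_2)
```

The three variants differ only in where the scaling is applied, which is why the choice can be made on performance grounds.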

### Implementation Details
In order to accommodate this, we need to refactor the communication hook logic that baked the gradient pre/post-division into the default hook.
- We raise an error if registering a communication hook for HSDP since the current implementation would only apply the hook to the reduce-scatter, not the all-reduce, which may be unexpected.
- We change it so that `state._comm_hook is not None` iff a communication hook is registered. This makes the collectives and the pre/post-division in the default no-communication-hook path more visible in the code.

Differential Revision: [D47852459](https://our.internmc.facebook.com/intern/diff/D47852459)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106034
Approved by: https://github.com/rohan-varma
2023-07-28 18:36:26 +00:00
Albert Chen
7c8efc9049 [PT][FSDP] Combine _utils.py into _common_utils.py [2/2] (#106181)
Summary:
https://github.com/pytorch/pytorch/issues/97813
This diff moves `_no_dispatch_record_stream` and `_same_storage_as_data_ptr`

Test Plan: CI

Differential Revision: D47706114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106181
Approved by: https://github.com/awgu
2023-07-28 17:15:25 +00:00
Andrew Gu
841b4acf1e [FSDP][Easy] Rename to _comm_hook, _comm_hook_state (#106033)
This is just out of preference to make the naming convention consistent with `register_comm_hook()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106033
Approved by: https://github.com/fegin
2023-07-26 19:59:11 +00:00
Andrew Gu
035704e88d [FSDP][Easy] Move post-bwd hook logging to own func (#106032)
This is to help make `_post_backward_hook()` easier to read. I plan to refactor some other parts in future PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106032
Approved by: https://github.com/fegin
2023-07-26 19:59:11 +00:00
Daniel Dale
6b6702f506 Enhance no_grad-context FSDP backward handling (#105374)
Fixes #105369
Fixes #105371

Addressing two somewhat distinct issues that involve the same test in this PR:

1. To fix #105369:
    - Add a `no_grad` guard to [`_register_post_backward_reshard_only_hooks`](93f852f201/torch/distributed/fsdp/_runtime_utils.py (L1406)) to avoid registering post-backward hooks that would not be removed in that context.

2. To fix #105371:
    - Add a `grad` context condition to [`_use_sharded_flat_param`](93f852f201/torch/distributed/fsdp/flat_param.py (L1645C9-L1645C32)) logic to trigger post-forward `_use_sharded_views` in a `no_grad` context for `NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105374
Approved by: https://github.com/awgu
2023-07-26 14:12:13 +00:00
Andrew Gu
c099b80073 [FSDP] Add record_function for explicit prefetching (#105985)
Example:
<img width="568" alt="Screenshot 2023-07-25 at 7 41 43 PM" src="https://github.com/pytorch/pytorch/assets/31054793/5f3f07b3-97f4-4493-9cab-5619484e2f6d">

This can be particularly helpful when `with_stack=False`, in which case it is harder to identify the prefetch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105985
Approved by: https://github.com/fegin
2023-07-26 12:16:35 +00:00
Andrew Gu
a9a3c45649 Revert "Simplify handle indexing (#105006)" (#105984)
This reverts commit 429d45f91a.

Unfortunately, https://github.com/pytorch/pytorch/pull/105006 broke backward prefetching, and our unit tests did not catch this because they do not check that backward prefetching works correctly.

I need more time to dig into this (tomorrow), but I think the issue is related to:
429d45f91a (diff-9a6937168d232432c34c2c4605b96f3147afa2786e287f74b6074b20aa5980e6R143-R146)

Follow-ups:
1. Investigate this thoroughly
2. Add unit tests to capture backward prefetch functionality
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105984
Approved by: https://github.com/fegin
2023-07-26 12:12:14 +00:00
Rohan Varma
a326f5621e composable fsdp, checkpoint, + compile test (#105180)
Test to ensure that composable FSDP, checkpoint, and compile all work
together. Includes a change from https://github.com/pytorch/pytorch/pull/105090
which we can land in that PR first.

Differential Revision: [D47452973](https://our.internmc.facebook.com/intern/diff/D47452973/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105180
Approved by: https://github.com/awgu
2023-07-26 07:03:09 +00:00
Michael Voznesensky
429d45f91a Simplify handle indexing (#105006)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105006
Approved by: https://github.com/awgu
2023-07-21 05:53:23 +00:00
Michael Voznesensky
a832967627 Migrate tuple(handle) -> handle (#104488)
We strengthen the invariant that one FSDP-managed module has one flat parameter, and remove unused code that would have supported a 1:many module-to-flatparam mapping.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104488
Approved by: https://github.com/awgu
2023-07-19 22:33:35 +00:00
Nikita Shulga
5837e95d30 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

These were reverted due to a conflict with the internal source repo.

Mostly fixes for PEP 484 violations (i.e., when a default arg is set to None but the type is not annotated as Optional),
plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export. deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert in `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`

Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add hack to squash older libstdc++ from conda environment in favor of the one from the OS to `.ci/docker/install_conda.sh`
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds lose the ability to catch exceptions (probably because they link with cupti statically, but I could not find where it is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-15 20:30:20 +00:00
PyTorch MergeBot
15fd1ea118 Revert "[Reland] Update mypy to 1.4.1 (#105227)"
This reverts commit c9c4f8efc3.

Reverted https://github.com/pytorch/pytorch/pull/105227 on behalf of https://github.com/atalman due to trying to mitigate ci sev #105248 ([comment](https://github.com/pytorch/pytorch/pull/105227#issuecomment-1636510935))
2023-07-14 22:28:35 +00:00
Nikita Shulga
c9c4f8efc3 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

These were reverted due to a conflict with the internal source repo.

Mostly fixes for PEP 484 violations (i.e., when a default arg is set to None but the type is not annotated as Optional),
plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export. deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert in `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-14 20:45:12 +00:00
PyTorch MergeBot
3c5a494d7a Revert "Update mypy to 1.4.1 (#91983)"
This reverts commit 634659e262.

Reverted https://github.com/pytorch/pytorch/pull/91983 on behalf of https://github.com/malfet due to Its dependent change was reverted, so reverting this one as well, to keep CI clean ([comment](https://github.com/pytorch/pytorch/pull/91983#issuecomment-1636059709))
2023-07-14 15:59:16 +00:00
Nikita Shulga
634659e262 Update mypy to 1.4.1 (#91983)
Mostly fixes for PEP 484 violations (i.e., when a default arg is set to None but the type is not annotated as Optional),
plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export. deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91983
Approved by: https://github.com/kit1980, https://github.com/ZainRizvi, https://github.com/huydhn, https://github.com/thiagocrepaldi, https://github.com/aaronenyeshi
2023-07-13 16:30:36 +00:00
Rohan Varma
242fc29c96 [FSDP] Refactor optimizer in backward (#104813)
1) Use zero_grad(set_to_none=True) to set grad to None, 2) call
prepare_grad_for_optim() before call to .step, 3) use
_reset_flat_param_grad_info to set flat param gradient back to None. These
changes should just be refactors and equivalent to how gradient memory was
managed  before.

Differential Revision: [D47310761](https://our.internmc.facebook.com/intern/diff/D47310761/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104813
Approved by: https://github.com/awgu
2023-07-13 06:42:53 +00:00
Rohan Varma
f2eed129c4 FSDP optimizer overlap (#98667)
constraints:

1. No support for gradient accumulation
2. CPU offload runs step() on CPU. In future PRs ideally we'd run this on GPU.
3. When CPU offload is combined with optimizer overlap, we have to copy the flat_param grad to CPU with non_blocking=False; otherwise, step() might run on invalid data.
4. Step is waited on in post backward final cb, when in theory it can wait until the next forward.

Differential Revision: [D44809582](https://our.internmc.facebook.com/intern/diff/D44809582/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98667
Approved by: https://github.com/awgu, https://github.com/fegin
2023-07-13 06:42:53 +00:00
Andrew Gu
954bae8e53 [FSDP][Easy] Rename streams; add back stream sharing test (#104966)
Purely out of preference, this PR renames the streams to `_unshard_stream` instead of `_streams_unshard` etc. since the former reads more naturally. The PR also removes some duplicated comments and adds back a unit test that streams are shared.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104966
Approved by: https://github.com/rohan-varma
2023-07-13 00:24:41 +00:00
Iris
4f8ba6f8f6 [DeviceMesh]Add validate mesh flag to DeviceMesh (#104807)
When creating a DeviceMesh, _init_process_group() validates that all calling ranks pass in the same `mesh` argument. In FSDP, we currently create the DeviceMesh based on the pg of the root state, so the mesh is always valid. This PR adds a flag to DeviceMesh so that we can skip the all_gather_tensor used for validation at construction time.

_validate_mesh defaults to True, but we manually set it to False when initializing the device mesh in FSDP's _runtime_utils.py.

A follow-up PR will skip pg creation when the pg already exists, for both the 1D and 2D cases, and then delete the _init_process_groups flag.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104807
Approved by: https://github.com/wanchaol
2023-07-12 23:42:13 +00:00
Andrew Gu
63d1fb21f5 [FSDP] Default limit_all_gathers=True (#104900)
This PR defaults to `limit_all_gathers=True`.

I included a `record_function()` for the rate limiter synchronization to help with user confusion on the gap in the pre-forward:
<img width="874" alt="Screenshot 2023-07-10 at 3 28 18 PM" src="https://github.com/pytorch/pytorch/assets/31054793/61f55e0e-58d7-4162-9395-bea06d3e8d8a">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104900
Approved by: https://github.com/fegin
2023-07-11 01:04:29 +00:00
Andrew Gu
d9be0366d3 [FSDP][3/N] Unify fully_shard auto wrap (#104408)
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules.

This includes several important fixes:
- We should register the pre/post-forward hooks on the module regardless of whether it has managed parameters.
- We can permit `_module_handles` to return `[]` in the composable path (for when the module has no managed parameters).
- We should unify the paths for `_get_buffers_and_dtypes_for_computation()` (previously, composable path was buggy in some cases).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104408
Approved by: https://github.com/rohan-varma
2023-07-08 12:40:12 +00:00
Rohan Varma
0bf39d5663 [FSDP] Option for eval in fp32/bf16 (#104682)
In https://github.com/pytorch/pytorch/pull/97645 and some follow up diffs, we made FSDP run in full precision in eval mode, even if mixed precision was specified.

However, this is probably not the best idea, and we should provide a flag that gives users more control. This PR adds an env var FSDP_FULL_PREC_IN_EVAL that defaults to off; users who want to run eval in fp32 can set it before wrapping the model in FSDP:

os.environ["FSDP_FULL_PREC_IN_EVAL"] = "1"

Verified that unittests, APS workflow, TNT workloads can run eval appropriately with this change.
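
A minimal sketch of how such a flag could be set and read follows. The environment variable name comes from the message above; the helper `full_prec_in_eval_requested` and the rule that only the string `"1"` enables the behavior are assumptions made for illustration, not a description of FSDP's internals.

```python
import os

# Flag name taken from the commit message; set it before wrapping the model.
os.environ["FSDP_FULL_PREC_IN_EVAL"] = "1"


def full_prec_in_eval_requested() -> bool:
    """Hypothetical helper: True only when the flag is set to the string "1"."""
    return os.environ.get("FSDP_FULL_PREC_IN_EVAL", "") == "1"


print(full_prec_in_eval_requested())
```

Because the flag is read from the process environment, it has to be set in every rank's process, not just on one worker.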

Differential Revision: [D47246556](https://our.internmc.facebook.com/intern/diff/D47246556/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104682
Approved by: https://github.com/awgu
2023-07-07 08:14:23 +00:00
Iris
434fcffa21 [6/n][FSDP] Update _sharded_pre_load_state_dict_hook to use DTensor when use_dtensor=True in ShardedStateDictConfig (#104087)
This allows us to use use_dtensor=True for ShardedStateDictConfig() before calling model.load_state_dict(). It currently only works with offload_to_cpu=False.

The next PR will make use_dtensor=True work with offload_to_cpu=True for load_state_dict().

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104087
Approved by: https://github.com/fegin
2023-07-06 05:36:19 +00:00
Wanchao Liang
8457703e8d lazy init device mesh in fsdp (#104447)
Since the FSDP state is lazily initialized, we also need to lazily initialize the device mesh; otherwise, the DeviceMesh allgather check would trigger a mismatch in allgather counts in FSDP tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104447
Approved by: https://github.com/wconstab
2023-06-30 04:40:16 +00:00
Rohan Varma
c866446d6c [FSDP] Check module.training for _root_cast_forward_inputs (#104223)
We might erroneously cast forward inputs for the root if it doesn't
manage any handles (FSDP parameters). As a fix, pass in the module and check
its training attribute to ensure we don't cast inputs in eval mode.

Differential Revision: [D47041673](https://our.internmc.facebook.com/intern/diff/D47041673/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104223
Approved by: https://github.com/fegin
2023-06-28 16:38:01 +00:00
Andrew Gu
9db8ad7f1d [FSDP] Support unfreezing params for reshard-only hook (#104186)
This fixes https://github.com/pytorch/pytorch/issues/104148 (unfreezing parameters after `n` steps).

- This fixes a bug where we did not delete the post-backward hook state properly for the `requires_grad=False` case.
- This makes the `already_resharded` correct for `SHARD_GRAD_OP`.
- This generalizes `_clear_grads_if_needed()` to `_reset_flat_param_grad_info_if_needed()` to additionally include propagating the original parameters' `requires_grad` to the flat parameter.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104186
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2023-06-28 11:04:57 +00:00
shibo19
c2095af3f8 make funcs argument type from torch.cuda.stream as torch.Stream (#104156)
Fixes #ISSUE_NUMBER
1. We want to support FSDP for custom devices, so we change the functions' argument type from torch.cuda.stream to torch.Stream.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104156
Approved by: https://github.com/awgu
2023-06-28 06:02:56 +00:00
Michael Voznesensky
02f28de408 [dynamo x fsdp] Simplify stream logic handling (#103902)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103902
Approved by: https://github.com/awgu
2023-06-21 01:34:19 +00:00
Andrew Gu
48056b168f [FSDP] Reshard frozen params in backward (#101982)
This PR makes a first attempt at improving FSDP's fine-tuning support by adding hooks to reshard frozen parameters in the backward pass.
- Without this, frozen parameters involved in gradient computation are kept as unsharded through the entire backward pass.
- The approach is to register a multi-grad ~~post~~-hook on the _input_ activations to the FSDP module, where the hook performs the resharding after all gradients for the FSDP module must have been computed (meaning that we are safe to reshard).

~~This PR relies on adding a "multi-grad post-hook" that differs from the existing "multi-grad hook" from `register_multi_grad_hook()`. I find that with `register_multi_grad_hook()`, sometimes the unit test counting the number of times `_post_backward_reshard()` is called fails (due to it not being called).~~ This was resolved in https://github.com/pytorch/pytorch/pull/102859.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101982
Approved by: https://github.com/rohan-varma
2023-06-08 21:12:45 +00:00
Iris
d5142c52d3 [FSDP]Remove dim_group from device_mesh init (#103218)
1) remove dim_group
2) don't init device_mesh if not using default_pg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103218
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2023-06-08 03:29:19 +00:00
Iris
a02a58d862 [FSDP][1/N]Add device_mesh to FSDPstate (#102317) (#102551)
This PR creates a device_mesh and shares it across all FSDP states. The device_mesh will later be used to test out dtensor state_dict (1d device_mesh).
Approved by: https://github.com/awgu

Add device mesh to fsdp state
skip dist.get_world_size(pg) != dist.get_world_size()
address test_fake_pg.py test failure
fix test_fake_py.py failure

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102551
Approved by: https://github.com/fegin
2023-06-07 04:14:00 +00:00
Rohan Varma
88ce6215f5 [FSDP/DDP] Unify _cast_forward_inputs (#102680)
Closes https://github.com/pytorch/pytorch/issues/96380

Differential Revision: [D46342814](https://our.internmc.facebook.com/intern/diff/D46342814/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102680
Approved by: https://github.com/awgu
2023-06-04 18:31:21 +00:00
Rohan Varma
e66c498d2d Log modules FSDP hooks fire for (#102508)
Under torch_distributed_debug >= INFO and use_orig_params=True, log post backward hook firing to debug things like FSDP + AC integration.

Differential Revision: [D46172916](https://our.internmc.facebook.com/intern/diff/D46172916/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102508
Approved by: https://github.com/awgu, https://github.com/fegin
2023-06-04 18:31:12 +00:00
PyTorch MergeBot
81ac076bce Revert "[FSDP]Add device_mesh to FSDPstate (#102317)"
This reverts commit 4c584acc5d.

Reverted https://github.com/pytorch/pytorch/pull/102317 on behalf of https://github.com/malfet due to Broke test_fake_pg, see https://github.com/pytorch/pytorch/actions/runs/5100633726/jobs/9173277369  ([comment](https://github.com/pytorch/pytorch/pull/102317#issuecomment-1566129496))
2023-05-28 12:53:28 +00:00
Iris
4c584acc5d [FSDP]Add device_mesh to FSDPstate (#102317)
This PR creates a device_mesh and shares it across all FSDP states. The device_mesh will later be used to test out dtensor state_dict (1d device_mesh).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102317
Approved by: https://github.com/awgu
2023-05-27 20:25:30 +00:00
Yanli Zhao
dc9c79d3cf Allow each fully_shard unit to cast foward inputs for mixed precision config (#100290)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100290
Approved by: https://github.com/rohan-varma
2023-05-02 00:03:48 +00:00
medivh-xp
859e82a7a9 Making fsdp device-agnostic for custom-backend which implement cuda-semantics (#99024)
Consider a custom backend implemented on top of privateuse1 with semantics identical to CUDA (given how popular CUDA is), named, for example, 'my_device', and registered under the same module name, torch.my_device.

This PR aims to satisfy the constraints of such a backend, which can be directly integrated into the current FSDP implementation.

The main issues addressed are:

#### 1. Device decision for FSDP wrapping of Modules without Parameters

Users typically organize FSDP code as follows:
```python
m = Module().to('my_device:0')
fsdp_m = FSDP(m)
```
or like this:
```python
m = Module()
fsdp_m = FSDP(m, device_id=torch.device('my_device', 0))
```
If the model has Parameters, everything works fine because FSDP will prioritize the device where the Parameters are located. However, for Modules without Parameters, the to() call has no side effects, and FSDP will assume the current CUDA device, which prevents the use of devices other than the current CUDA device for Modules without Parameters. Therefore, when FSDP is called with a device_id argument, this configuration takes top priority.
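
The priority order described above can be sketched as a small pure-Python function. The name `pick_compute_device`, its string-based arguments, and the choice of the first parameter's device are hypothetical simplifications for illustration; the actual FSDP logic works on `torch.device` objects and is not reproduced here.

```python
from typing import Optional, Sequence


def pick_compute_device(
    device_id: Optional[str],
    param_devices: Sequence[str],
    current_device: str,
) -> str:
    """Hypothetical resolution order: device_id, then parameters, then current device."""
    if device_id is not None:
        # An explicit device_id takes top priority.
        return device_id
    if param_devices:
        # Otherwise use the device where the parameters live.
        return param_devices[0]
    # A module without parameters falls back to the current device.
    return current_device


# Module without parameters: only an explicit device_id selects the custom backend.
print(pick_compute_device(None, [], "cuda:0"))
print(pick_compute_device("my_device:0", [], "cuda:0"))
```

The first call illustrates the problem case (the fallback ignores the custom device), and the second shows how passing `device_id` resolves it.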

#### 2. Abstraction of a cuda-like device

Now, in addition to compute_device, _FSDPState includes a device_handler member. This device_handler is simply a reference to either torch.cuda or torch.my_device. From now on, code that works with _FSDPState should use state.device_handler to create, wait on, or synchronize streams, just as it previously used torch.cuda.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99024
Approved by: https://github.com/awgu
2023-04-27 04:13:28 +00:00
Daniel Dale
363d530035 Fix decision logic for should_cast_forward_inputs in _root_pre_forward() and _pre_forward() (#99546)
Fixes #99545

There is currently no topological constraint dictating FSDP instances own ``FlatParamHandle`` s directly. If all parameters are managed by descendant FSDP instances leaving an FSDP instance with no direct ``state._handles``, the  ``should_cast_forward_inputs`` decisions below in both ``_root_pre_forward()`` and ``_pre_forward()`` respectively can return incorrect decisions [^1].

For [``_root_pre_forward()``](436edc5ac3/torch/distributed/fsdp/_runtime_utils.py (L514)):

436edc5ac3/torch/distributed/fsdp/_runtime_utils.py (L602-L604)

For [``_pre_forward``](436edc5ac3/torch/distributed/fsdp/_runtime_utils.py (L384)):

436edc5ac3/torch/distributed/fsdp/_runtime_utils.py (L420-L422)

See the [related issue](https://github.com/pytorch/pytorch/issues/99545) for reproduction.

### Remediation

In this PR, I amend the two decision statements referenced above (in both `_root_pre_forward()` and `_pre_forward()`) to account for FSDP instances without direct handles:
```python
should_cast_forward_inputs = len(state._handles) > 0 and all(
    not handle._force_full_precision for handle in state._handles
)
```

If one configures ``MixedPrecision`` in the example above with ``cast_forward_inputs=True`` and the ``should_cast_forward_inputs`` adjustment above, FSDP returns to the expected behavior and produces no error.
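
The length check matters because Python's `all()` returns `True` for an empty iterable. The sketch below contrasts a check built from `all(...)` alone with the amended statement, using a hypothetical `StubHandle` in place of `FlatParamHandle`; that the earlier check had exactly this form is an assumption made for illustration.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class StubHandle:
    """Hypothetical stand-in for a FlatParamHandle; only the flag used here."""

    _force_full_precision: bool = False


def cast_without_length_check(handles: List[StubHandle]) -> bool:
    # all() over an empty iterable is True (vacuous truth).
    return all(not h._force_full_precision for h in handles)


def cast_with_length_check(handles: List[StubHandle]) -> bool:
    # Mirrors the amended statement: require at least one direct handle.
    return len(handles) > 0 and all(not h._force_full_precision for h in handles)


no_handles: List[StubHandle] = []
print(cast_without_length_check(no_handles), cast_with_length_check(no_handles))
```

For an instance that owns at least one handle, both functions agree; they differ only when the handle list is empty.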

Though the check is the same in both ``_root_pre_forward()`` and ``_pre_forward()`` and hence could be refactored into a separate function, I figured it may make sense to retain separate statements to preserve the ability for root-specific behavior in the future. Whichever approach the team prefers I can update this PR with.

### Implementation considerations and questions:

1. Rather than write a test that would arguably have a poor utility/resource usage profile, I have not added any tests associated with this PR. The new decision logic is exercised by all existing tests (which continue to pass after this PR of course) so I think the utility of new tests is fairly modest. Let me know if you think new tests should be added and I'm happy to do so.
2. As discussed above, the decision statement shared among ``_pre_forward()`` and ``_root_pre_forward()`` could be factored out into a separate function. Given the simplicity of the statement and to retain current flexibility for root-specific decisions it might not be worth the refactor so I haven't done it yet. Let me know if you'd like me to do so.
3. The note below could be updated to indicate the utility of setting ``cast_forward_inputs=True`` for the situations addressed with this PR but I haven't done so since I'm not sure it's worth complicating the current usage guidance. I'd be happy to add verbiage describing the use case if the team wants it.
cde35b4069/torch/distributed/fsdp/api.py (L175-L181)

Thanks again to the PyTorch distributed team for your immensely valuable contributions to the open-source ML community!

[^1]: Though one could keep the existing decision logic and impose a new topological constraint requiring all FSDP instances have direct `_handles`, I think retaining the current wrapping flexibility is both convenient and useful enough (e.g. programmatic wrapping of modules that may or may not already have all parameters handled by descendant FSDP instances) to update the decision logic as discussed here instead.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99546
Approved by: https://github.com/awgu
2023-04-21 22:49:50 +00:00
Rohan Varma
d8b09b0139 [FSDP] Full precision in eval mode (#97645)
If the model is in eval mode (i.e., after model.eval()), FSDP runs it in full precision.

Changes:
- Changed _force_full_precision to check self.is_training
- Check for _force_full_precision when casting gradients to reduced dtype
- Small change when accessing _full_prec_param_padded
- tests for class based and fully_shard APIs

Differential Revision: [D43933690](https://our.internmc.facebook.com/intern/diff/D43933690/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97645
Approved by: https://github.com/awgu
2023-04-13 18:38:22 +00:00
feifan
d95ee64b58 ddp forward support custom backend. (#98283)
Currently, DDP only considers the CUDA backend: DDP forward transfers tensors to CUDA. We want DDP to run on custom backends as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98283
Approved by: https://github.com/ezyang
2023-04-09 01:30:42 +00:00
Rohan Varma
428c531d00 [FSDP] records for composable (#98428)
Add some function recording since composable API does record a FSDP.forward

Differential Revision: [D44715137](https://our.internmc.facebook.com/intern/diff/D44715137/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98428
Approved by: https://github.com/awgu
2023-04-06 06:40:48 +00:00
Andrew Gu
10271a60a8 [FSDP] Skip _use_sharded_views() for SHARD_GRAD_OP (#98250)
This PR has `SHARD_GRAD_OP` (and `_HYBRID_SHARD_ZERO2`) skip `_use_sharded_views()` in the post-forward reshard since the strategy does not free the unsharded flat parameter and can preserve the unsharded views. This saves nontrivial CPU overhead both in the post-forward reshard (`_use_sharded_views()`) and the pre-backward unshard (`_use_unsharded_views()`).

<details>
<summary>(Before) Pre-backward hook: 4.356 ms</summary>

<img width="812" alt="Screenshot 2023-04-03 at 6 32 19 PM" src="https://user-images.githubusercontent.com/31054793/229641309-778cf1f9-4b5b-42ec-b2d8-0a1e6e7ce330.png">

</details>

<details>
<summary>(After) Pre-backward hook: 1.044 ms</summary>

![Screenshot 2023-04-04 at 9 05 53 AM](https://user-images.githubusercontent.com/31054793/229800917-9580ce6b-3721-469a-9212-f0cbfd8cbb52.png)

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98250
Approved by: https://github.com/rohan-varma
2023-04-04 17:07:28 +00:00
Andrew Gu
0b31f87c18 [FSDP] Use correct handle training state when prefetching (#98249)
This PR ensures that when prefetching a `FlatParamHandle.unshard()`, we temporarily set the `FlatParamHandle._training_state` to the expected training state as if the `unshard()` were not prefetched since the `as_params` argument to `_use_unsharded_views()` depends on the handle's training state.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98249
Approved by: https://github.com/rohan-varma
2023-04-04 13:34:02 +00:00
Andrew Gu
fb7b398479 [FSDP] Do not _unshard if already prefetched (#97981)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97981
Approved by: https://github.com/fegin
2023-03-31 18:47:03 +00:00
Andrew Gu
195b92ab01 [FSDP][Easy] Minor cleanups to _runtime_utils.py (#97980)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97980
Approved by: https://github.com/H-Huang
2023-03-31 18:47:03 +00:00
Kazuaki Ishizaki
35fd5c548e Fix typos under torch/distributed directory (#95638)
This PR fixes typos in comments and messages of `.py` files under torch/distributed directory

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95638
Approved by: https://github.com/usamah1, https://github.com/H-Huang, https://github.com/kit1980
2023-03-27 21:13:44 +00:00
Rohan Varma
308a58ebca [FSDP] Rename to _get_orig_buffer_dtypes (#96790)
Reland this PR

Differential Revision: [D44078430](https://our.internmc.facebook.com/intern/diff/D44078430/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96790
Approved by: https://github.com/awgu
2023-03-16 00:31:29 +00:00
Xuehai Pan
80e8e41ca7 Fix type hint for torch.Tensor.grad_fn (#96804)
Fix type hint for `torch.Tensor.grad_fn`, which can be a `torch.autograd.graph.Node` or `None`.

This is a regression in `torch` 2.0 that causes `mypy` failures in downstream projects.

Ref:

- https://github.com/pytorch/pytorch/issues/94937#issuecomment-1469344993
- metaopt/torchopt#149
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96804
Approved by: https://github.com/Skylion007
2023-03-15 17:14:05 +00:00
Andrew Gu
6c30dc6cee [FSDP] Save _all_handles; _all_fsdp_states to root (#95465)
- The previous PR addressed one tree traversal in `_root_pre_forward()` but not the main one from `_get_fsdp_handles()` that runs for all settings.
- This PR saves `_all_handles` to cache `_get_fsdp_handles()` and `_all_fsdp_states` to cache `_get_fsdp_states()` (renamed from `_fsdp_states` compared to last PR) on the root state.
- This PR introduces a dummy `_RootFSDPState` class that inherits from `_FSDPState` to be used only for type checking since some attributes are only defined for root states.
    - I found this approach to be better than adding `_p_assert(state.root_only_attr is not None, ...)` upon each usage of `root_only_attr`.
    - This hopefully also helps readers to quickly see which attributes are defined only on root states.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95465
Approved by: https://github.com/fduwjj
2023-02-26 13:59:53 +00:00