Commit Graph

118 Commits

Author SHA1 Message Date
PyTorch MergeBot
40ece2e579 Revert "Enable possibly-undefined error code (#118533)"
This reverts commit 4f13f69a45.

Reverted https://github.com/pytorch/pytorch/pull/118533 on behalf of https://github.com/clee2000 due to sorry i'm trying to figure out a codev merge conflict, if this works i'll be back to rebase and merge ([comment](https://github.com/pytorch/pytorch/pull/118533#issuecomment-1917695185))
2024-01-30 19:00:34 +00:00
Edward Z. Yang
4f13f69a45 Enable possibly-undefined error code (#118533)
Fixes https://github.com/pytorch/pytorch/issues/118129

Suppressions automatically added with

```
import re

with open("error_file.txt", "r") as f:
    errors = f.readlines()

error_lines = {}
for error in errors:
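    # Parse lines of the form "path/to/file.py:123:4: error: ... [error-code]"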
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

for file_path, lines in error_lines.items():
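    # Append a matching "# type: ignore[<code>]" suppression to each flagged line.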
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2024-01-30 05:08:10 +00:00
Chien-Chin Huang
2ea38498b0 [FSDP][BE] Only show state_dict log when the debug level is detail (#118196)
As title
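
A minimal sketch of the kind of gating the title describes (illustrative, not the actual FSDP code):

```python
import logging

import torch.distributed as dist

logger = logging.getLogger(__name__)

def log_state_dict_event(msg: str) -> None:
    # Only emit the verbose state_dict log when TORCH_DISTRIBUTED_DEBUG=DETAIL.
    if dist.get_debug_level() == dist.DebugLevel.DETAIL:
        logger.info(msg)
```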

Differential Revision: [D53038704](https://our.internmc.facebook.com/intern/diff/D53038704/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118196
Approved by: https://github.com/rohan-varma, https://github.com/wz337
ghstack dependencies: #118197, #118195
2024-01-26 09:52:36 +00:00
Mihir Patel
84cfe6d8b2 Drop all gather stats to debug not warning (#117669)
The logger's default level results in these all-gather stats being spammed into every run, which is very annoying.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117669
Approved by: https://github.com/Skylion007, https://github.com/awgu
2024-01-17 21:44:59 +00:00
Andrew Gu
92cc8ae172 [FSDP] Cloned unsharded tensor slice in optim state dict load (#117261)
This takes the fix from https://github.com/pytorch/pytorch/issues/116553. Cloning the slice allows the base (much larger) tensor to be freed.
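
A minimal illustration of why the clone matters (toy sizes, not the FSDP internals):

```python
import torch

flat_param = torch.randn(1_000_000)  # large unsharded buffer
shard = flat_param[:1_000]           # a view: keeps flat_param's storage alive
shard = shard.clone()                # now owns its own 1_000-element storage
del flat_param                       # the large buffer can be freed
```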

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117261
Approved by: https://github.com/wz337
2024-01-11 20:21:12 +00:00
Aaron Gokaslan
ee5d981249 [BE]: Enable RUFF PERF402 and apply fixes (#115505)
* Enable PERF402. Makes code more efficient and succinct by replacing manual list-copy loops with a list constructor or an extend() call. All test cases have noqa added since performance is not as sensitive in that folder.
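
An illustrative before/after for the kind of code PERF402 flags (not taken from the PR itself):

```python
items = range(10)

# before: manual copy loop
result = []
for item in items:
    result.append(item)

# after: list constructor, or extend() when adding to an existing list
result = list(items)
existing = [0, 1]
existing.extend(items)
```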

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115505
Approved by: https://github.com/malfet
2023-12-20 18:01:24 +00:00
Chien-Chin Huang
cc28f61fa3 [DCP][BE] Move DCP._state_dict_utils out from DCP (#115523)
DCP._state_dict_utils is also used by FSDP, which can sometimes cause a circular import. Move it out of DCP to avoid the circular import.

Differential Revision: [D52022440](https://our.internmc.facebook.com/intern/diff/D52022440/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115523
Approved by: https://github.com/wz337
2023-12-13 08:59:48 +00:00
Yue Dong
ab120e65fb Fix FSDP + TP state dict in param unflattening (#115105)
Summary:
This diff fixes the param unflattening when using FSDP together with TP. Currently we hardcode the `reshape_size` to be multiplied by 2, when it should instead be multiplied by the size of the process group.

Before the fix, example exception: `shape '[257, 514]' is invalid for input of size 264196`, where the process group size is 4 instead of 2.
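
A hypothetical sketch of the sizing logic (names are illustrative, not the actual `_optim_utils` code):

```python
import torch.distributed as dist

def unflattened_shape(local_shape, tp_process_group):
    # Scale the sharded dim by the TP process-group size instead of a hardcoded 2.
    world_size = dist.get_world_size(tp_process_group)  # e.g. 4 in the failing case
    return (local_shape[0] * world_size, *local_shape[1:])
```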

Test Plan:
**CI**:
CI test

**Unit test**:
`buck2 test mode/dev-nosan //caffe2/test/distributed/tensor/parallel:fsdp_2d_parallel`
- Passed

**Test model with WHEN**:
- Verified that checkpoint can be saved and resumed successfully;
- Verified the accuracy with window_ne, which is on-par with baseline.
https://pxl.cl/3Wp8w

Differential Revision: D51826120

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115105
Approved by: https://github.com/fegin
2023-12-05 21:19:56 +00:00
Chien-Chin Huang
4ba649e207 [FSDP][state_dict] Avoid assigning the root _device_mesh to the children _device_mesh (#114384)
Assigning the root _device_mesh to the children _device_mesh is not correct as each FSDP state can have a different DeviceMesh. We are also replacing fully_shard with a new implementation. So there is no need to worry about the fully_shard behavior.

Differential Revision: [D51507959](https://our.internmc.facebook.com/intern/diff/D51507959/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114384
Approved by: https://github.com/wz337
2023-11-30 02:08:31 +00:00
Jez Ng
5cfa0647a7 Update mypy to 1.7.0 (#114160)
It appears that `mypy` is now checking a few more previously-unchecked files; these files
are being found via import-following. Not sure exactly why they weren't being checked before.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114160
Approved by: https://github.com/eellison
ghstack dependencies: #114162
2023-11-28 06:45:55 +00:00
Gary Zheng
d1ae5efa94 [torch][fsdp] More informative assertion error when rank mismatch (#113765)
Summary: I had a job fail due to rank mismatch but didn't find enough information in the assertion message. This change makes the message more informative.

Test Plan:
CI tests and I ran a test job which failed as expected:

```
Rank 1 has different values for step: 8016.0. Other ranks: 7870.0
```
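
A hedged sketch of the kind of check that produces such a message (illustrative names, assuming an initialized process group):

```python
import torch.distributed as dist

def assert_ranks_agree(name, value, group=None):
    rank = dist.get_rank(group)
    values = [None] * dist.get_world_size(group)
    dist.all_gather_object(values, value, group=group)
    others = [v for i, v in enumerate(values) if i != rank]
    if any(v != value for v in others):
        raise AssertionError(
            f"Rank {rank} has different values for {name}: {value}. "
            f"Other ranks: {', '.join(str(v) for v in others)}"
        )
```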

Differential Revision: D51322046

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113765
Approved by: https://github.com/wz337, https://github.com/fegin
2023-11-20 17:44:41 +00:00
wz337
ca9e654353 [FSDP] Fix FSDP submodule with DeviceMesh does not return DTensor state_dict error (#113593)
For scenarios where FSDP is not the root module, the `_use_dtensor` flag would not be switched on. This PR fixes it by checking whether the submodule has a `device_mesh` and turning the `_use_dtensor` flag on accordingly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113593
Approved by: https://github.com/fegin
2023-11-15 19:00:19 +00:00
Chien-Chin Huang
2bcff4d8e3 [state_dict][11/N] Implement cpu_offload and full_state_dict for get_state_dict (#112837)
As title
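
A hedged usage sketch of the two options this PR implements (names follow the current `torch.distributed.checkpoint.state_dict` API; details may have differed at the time of this commit):

```python
import torch
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict

def full_cpu_state_dicts(model: torch.nn.Module, optim: torch.optim.Optimizer):
    # Gather a full (unsharded) state_dict and offload the tensors to CPU.
    return get_state_dict(
        model,
        optim,
        options=StateDictOptions(full_state_dict=True, cpu_offload=True),
    )
```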

Differential Revision: [D50962991](https://our.internmc.facebook.com/intern/diff/D50962991/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112837
Approved by: https://github.com/LucasLLC, https://github.com/wz337
ghstack dependencies: #112836, #112885
2023-11-13 10:03:06 +00:00
wz337
31ded95cd5 [2D] Bind _fsdp_extension to FSDP instances (#113237)
Currently, when we have 2D composition, a global variable _extensions controls the 2D deviation we need to take in state_dict calls (see https://github.com/pytorch/pytorch/blob/release/2.1/torch/distributed/fsdp/_fsdp_extensions.py#L66-L68). This is problematic when we have both a 2D model and a plain FSDP model in the same dist environment, as _extensions will be mistakenly turned on for the plain FSDP model, resulting in a state_dict error (RuntimeError: No parent device_mesh is found for FSDP device_mesh.).

This PR binds _fsdp_extension to the FSDP instances to make sure that state_dict calls do not interfere with each other when mixing 2D and 1D parallelism.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113237
Approved by: https://github.com/fduwjj, https://github.com/fegin
2023-11-09 03:31:03 +00:00
Chien-Chin Huang
a66f2a1b99 [state_dict] Move _gather_state_dict to dcp module (#112835)
This API is used by more than just FSDP. This PR moves it to the DCP module.

Differential Revision: [D50962966](https://our.internmc.facebook.com/intern/diff/D50962966/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112835
Approved by: https://github.com/wz337
2023-11-08 19:42:56 +00:00
Chien-Chin Huang
a810126cf7 [FSDP][optim_state_dict] Skip the parameter if the parameter does not belong to the current FSDP instance (#112804)
Skip the FSDP-managed parameter if the parameter is not managed by the current FSDP instance. This can happen if not all FSDP instances have all the parameters, which can occur with FSDP + some MPMD-style parallelism.

Differential Revision: [D50562170](https://our.internmc.facebook.com/intern/diff/D50562170/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112804
Approved by: https://github.com/wz337
2023-11-06 18:23:36 +00:00
Chien-Chin Huang
2a86bcbac2 [FSDP][state_dict] Cleanup the usage of _get_pg_default_device (#112168)
_get_pg_default_device is not suitable for the FSDP use case. We should always use the compute_device when communicating.

Differential Revision: [D50698730](https://our.internmc.facebook.com/intern/diff/D50698730/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112168
Approved by: https://github.com/wz337
2023-10-27 08:09:08 +00:00
Iris Zhang
c84dbd2c03 [2D] Enable 2D optimizer set_state_dict() (#111778)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111778
Approved by: https://github.com/fegin, https://github.com/fduwjj
ghstack dependencies: #111774
2023-10-27 04:33:00 +00:00
wz337
8dc4887e84 [2D] Enable 2D optimizer get_state_dict() (#111774)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111774
Approved by: https://github.com/fegin
2023-10-25 03:44:14 +00:00
Daniel Dale
90e2117a99 Allow optimizer state conversion to accommodate optimizers that have no tensor state (e.g. SGD) (#111501)
Fixes #111499

This PR slightly alters the new fused `all_gather` `optim_state_dict` implementation to support optimizers without tensor state (e.g. SGD) in a `use_orig_params=True` context.

The principal change is to short-circuit `_allgather_orig_param_states` if an empty `state_buffers` dict is returned after completing `_convert_all_state_info` here:
93e5065ba0/torch/distributed/fsdp/_optim_utils.py (L1481-L1484)

To allow `_convert_all_state_info` to accommodate optimizers with no tensor state, I also change the scope of `dtype` and make the return type `Optional`.
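
A hypothetical sketch of the short-circuit (illustrative names, not the real `_optim_utils` signature):

```python
from typing import Dict, Optional

import torch

def gather_tensor_state(state_buffers: Dict[str, torch.Tensor]) -> Optional[Dict[str, torch.Tensor]]:
    if not state_buffers:
        # e.g. SGD without momentum: nothing to all-gather; scalar and
        # non-tensor state is handled separately.
        return None
    # ... the fused all-gather of the buffers would happen here ...
    return state_buffers
```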

As discussed in the issue this PR fixes, I'm [extending](93e5065ba0/test/distributed/fsdp/test_fsdp_optim_state.py (L1915I)) `test_state_dict_with_none_tensor_state` to test with both Adam and SGD optimizers to validate scalar and non-tensor states continue to be restored for both optimizer types.

Thanks to the distributed team as always for their adroit design and exceptionally valuable contributions to the open source ML community. Hope you all feel appreciated commensurate with the compounding progress your work enables.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111501
Approved by: https://github.com/fegin
2023-10-19 06:47:04 +00:00
Chien-Chin Huang
7b25c2b90e [FSDP][optim_state_dict] Move local optimizer state to FSDP compute_device (#110929)
This will ensure all the tensors are on FSDP compute_device.

Differential Revision: [D50059492](https://our.internmc.facebook.com/intern/diff/D50059492/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110929
Approved by: https://github.com/wz337
2023-10-10 10:34:31 +00:00
Chien-Chin Huang
90bf6e3938 [FSDP][optim_state_dict] Enable cpu_offload config for optimizer state_dict (#108434)
We had the cpu_offload option but never used it, as optimizer state_dict offloads the tensors to CPU by default. This is usually what most users want, since the tensors have to be moved to CPU eventually. However, we may want to disable offloading to CPU in some cases, especially for debugging purposes. This PR lets optimizer state_dict read the flag.
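
A hedged usage sketch of disabling the offload via the existing config objects (assumes `model` is already FSDP-wrapped):

```python
from torch.distributed.fsdp import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)

def keep_optim_state_on_device(model: FSDP) -> None:
    # offload_to_cpu=False keeps the gathered optimizer state on the compute
    # device, e.g. while debugging.
    FSDP.set_state_dict_type(
        model,
        StateDictType.FULL_STATE_DICT,
        state_dict_config=FullStateDictConfig(offload_to_cpu=False),
        optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=False),
    )
```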

Differential Revision: [D48913340](https://our.internmc.facebook.com/intern/diff/D48913340/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108434
Approved by: https://github.com/wz337
2023-10-07 01:14:49 +00:00
Chien-Chin Huang
1a729618ef [FSDP][optim_state_dict] Make the new optimizer allgather fusion work with fine-tuning models (#110540)
With use_orig_params=True, it is possible that some parameters within the same FlatParameter are in the optimizer while other parameters are frozen. This PR makes the allgather fusion logic support this case.

Differential Revision: [D49922028](https://our.internmc.facebook.com/intern/diff/D49922028/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110540
Approved by: https://github.com/awgu, https://github.com/rohan-varma
2023-10-05 15:17:10 +00:00
Chien-Chin Huang
cdde899a73 [FSDP][optim_state_dict] Fuse allgather for optim_state_dict when use_orig_params is True (#108298)
The original implementation of `_gather_orig_param_state` is naive. It performs one allgather_object and two allgather calls (if the optimizer is Adam) per FQN. This can be slow and can make `_optim_state_dict` a bottleneck.

This PR rewrites the implementation and fuses all the `allgather_object` calls into one. As for `allgather`, it is fused based on the FlatParameter information, so there will be 2N `allgather` calls, where N is the number of FlatParameters and 2 comes from Adam having 2 states per FQN.

One experiment on 8 A100 GPUs shows that the gathering time improves from 3 seconds to 0.3 seconds.
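
An illustrative sketch of the fusion idea (not the actual FSDP code): batch all per-FQN metadata into a single `all_gather_object` call instead of one call per FQN.

```python
import torch.distributed as dist

def gather_all_state_info(local_state_info, group=None):
    # One collective for the whole mapping, instead of len(local_state_info) calls.
    gathered = [None] * dist.get_world_size(group)
    dist.all_gather_object(gathered, local_state_info, group=group)
    return gathered
```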

Differential Revision: [D48835138](https://our.internmc.facebook.com/intern/diff/D48835138/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108298
Approved by: https://github.com/awgu
2023-10-02 20:57:08 +00:00
Matthew Hoffman
68b0db1274 Define the public API for torch.distributed.fsdp (#109922)
Related: https://github.com/pytorch/pytorch/wiki/Public-API-definition-and-documentation
Related: https://github.com/microsoft/pylance-release/issues/2953

This fixes pylance issues for these classes:

```
"FullyShardedDataParallel" is not exported from module "torch.distributed.fsdp"
```

These classes all have public docs:

* [`BackwardPrefetch`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.BackwardPrefetch)
* [`CPUOffload`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.CPUOffload)
* [`FullyShardedDataParallel`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel)
* [`MixedPrecision`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision)
* [`ShardingStrategy`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.ShardingStrategy)

And it seems like all the newly added classes will have docs once they are released.
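
An illustrative sketch of the export pattern (module paths are assumptions, not a verbatim copy of `torch/distributed/fsdp/__init__.py`):

```python
from .api import BackwardPrefetch, CPUOffload, MixedPrecision, ShardingStrategy
from .fully_sharded_data_parallel import FullyShardedDataParallel

# Listing the names in __all__ marks them as re-exported, so Pyright/Pylance
# no longer reports "X is not exported from module torch.distributed.fsdp".
__all__ = [
    "BackwardPrefetch",
    "CPUOffload",
    "FullyShardedDataParallel",
    "MixedPrecision",
    "ShardingStrategy",
]
```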

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109922
Approved by: https://github.com/wanchaol
2023-09-28 02:15:58 +00:00
Chien-Chin Huang
1b3e5b53f3 [FSDP][optim_state_dict] Add device to _shard_utils.py to explicitly use the device from fsdp_state (#109631)
_get_pg_default_device does not always get the device we want. This PR lets the user explicitly tell us the correct device.

Differential Revision: [D49425743](https://our.internmc.facebook.com/intern/diff/D49425743/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109631
Approved by: https://github.com/awgu, https://github.com/fduwjj, https://github.com/wz337
2023-09-20 01:59:38 +00:00
wz337
66af4f6ec7 [HSDP] Add device_mesh to FSDP kwarg and add dtensor state_dict support for HSDP (#107533)
This PR:
1) Add a device_mesh kwarg to FSDP. Remove init_device_mesh() from _runtime_utils.py, as device_mesh is now passed in by the user as a kwarg.
2) Change the use_dtensor flag for state_dict_config and optim_state_dict_config to be private. If device_mesh is used with a sharded model/optim state dict, the _use_dtensor flag is set to True and the model/optim state dict returns a DTensor state_dict. Otherwise, the _use_dtensor flag is set to False and the model/optim state dict returns a ShardedTensor state_dict.
3) Update _optim_utils.py, _shard_utils.py, and _state_dict_utils.py to add support for HSDP to return a 2D DTensor state_dict.
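
A hedged usage sketch of the new kwarg (API names follow current PyTorch and may have differed at the time of this commit; assumes a process group spanning 8 ranks is already initialized):

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# 2D mesh: the outer dim replicates (HSDP), the inner dim shards.
mesh_2d = init_device_mesh("cuda", (2, 4))
model = FSDP(
    nn.Linear(16, 16).cuda(),
    device_mesh=mesh_2d,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)
```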

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107533
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wanchaol
2023-09-05 21:21:21 +00:00
Chien-Chin Huang
591cb776af [FSDP][state_dict][optim_state_dict] Log slow optim and model state_dict paths (#108290)
This PR adds SimpleProfiler for FSDP state_dict/load_state_dict logging purposes. SimpleProfiler uses class variables to record profiling results and does everything in Python, which can be slow. So it is only suitable for logging slow actions such as initialization and state_dict/load_state_dict.

This PR uses SimpleProfiler to log some critical/slow paths of the model and optimizer state_dict/load_state_dict.

Differential Revision: [D48774406](https://our.internmc.facebook.com/intern/diff/D48774406/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108290
Approved by: https://github.com/wz337
2023-09-01 06:57:59 +00:00
PyTorch MergeBot
ab5b4c4419 Revert "[HSDP] Add device_mesh to FSDP and add dtensor state_dict support for HSDP (#107533)"
This reverts commit cc220e45a8.

Reverted https://github.com/pytorch/pytorch/pull/107533 on behalf of https://github.com/huydhn due to Sorry for reverting this, but it is failing in trunk with the same failure on test_dynamo_distributed cc220e45a8 ([comment](https://github.com/pytorch/pytorch/pull/107533#issuecomment-1701983247))
2023-09-01 01:26:30 +00:00
wz337
cc220e45a8 [HSDP] Add device_mesh to FSDP and add dtensor state_dict support for HSDP (#107533)
This PR:
1) Add a device_mesh kwarg to FSDP. Remove init_device_mesh() from _runtime_utils.py, as device_mesh is now passed in by the user as a kwarg.
2) Change the use_dtensor flag for state_dict_config and optim_state_dict_config to be private. If device_mesh is used with a sharded model/optim state dict, the _use_dtensor flag is set to True and the model/optim state dict returns a DTensor state_dict. Otherwise, the _use_dtensor flag is set to False and the model/optim state dict returns a ShardedTensor state_dict.
3) Update _optim_utils.py, _shard_utils.py, and _state_dict_utils.py to add support for HSDP to return a 2D DTensor state_dict.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107533
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wanchaol
2023-09-01 00:15:00 +00:00
Chien-Chin Huang
f6a9c15421 [FSDP][state_dict] Make optim_state_dict_to_load work with use_orig_param=False + NO_SHARD (#107185)
Summary: As title

Test Plan: CI

Reviewed By: wz337

Differential Revision: D48329724

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107185
Approved by: https://github.com/fegin
2023-08-15 21:42:41 +00:00
Jane Xu
7e47343d64 [BE] document more of FSDP checkpointing logic with a sprinkle of cleaning (#106069)
This PR should not make any functional difference. It:
- adds clearer documentation
- clarifies a type
- revises minor typos
- swaps a .keys for a .items call on a dictionary

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106069
Approved by: https://github.com/awgu
2023-08-02 17:19:04 +00:00
Aaron Gokaslan
6d43c89f37 [BE]: Update Ruff to 0.0.280 (#105724)
Removes unused loop values in Python dictionary iteration. Automated fix from Ruff master.
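
An illustrative example of the kind of fix applied (not taken from the PR itself):

```python
d = {"a": 1, "b": 2}

# before: the value is never used
for k, v in d.items():
    print(k)

# after: iterate over the keys (or the dict itself)
for k in d:
    print(k)
```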

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105724
Approved by: https://github.com/ezyang, https://github.com/janeyx99
2023-07-22 23:03:34 +00:00
Michael Voznesensky
a832967627 Migrate tuple(handle) -> handle (#104488)
We strengthen the invariant that one FSDP-managed module has one FlatParameter, and remove unused code that would have supported a 1:many module-to-FlatParameter mapping.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104488
Approved by: https://github.com/awgu
2023-07-19 22:33:35 +00:00
Nikita Shulga
5837e95d30 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
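
An illustrative example of the PEP 484 point (not code from the PR):

```python
from typing import Optional

# Violation: a default of None requires an Optional (or "X | None") annotation.
def pad(value: int = None) -> int:
    return value or 0

# Fixed
def pad_fixed(value: Optional[int] = None) -> int:
    return value or 0
```
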
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`

Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add hack to squash older libstdc++ from conda environment in favor one from OS to `.ci/docker/install_conda.sh`
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds lose the ability to catch exceptions (probably because they link with cupti statically, but I could not find where it is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-15 20:30:20 +00:00
PyTorch MergeBot
15fd1ea118 Revert "[Reland] Update mypy to 1.4.1 (#105227)"
This reverts commit c9c4f8efc3.

Reverted https://github.com/pytorch/pytorch/pull/105227 on behalf of https://github.com/atalman due to trying to mitigate ci sev #105248 ([comment](https://github.com/pytorch/pytorch/pull/105227#issuecomment-1636510935))
2023-07-14 22:28:35 +00:00
Nikita Shulga
c9c4f8efc3 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-14 20:45:12 +00:00
Chien-Chin Huang
46154c4c35 [FSDP][optim_state_dict] The correct way to initialize optimizer states if the corresponding param is empty (#104765)
When using KeyedOptimizer.init_state(), some optimizers initialize the states even if the param is empty (size() == 0) while others avoid initializing the states. There is no way FSDP can tell which is the case. Instead, FSDP should look up `optim.state`. Fortunately, `optim.state` does not rely on FQNs, which some internal users change.

Differential Revision: [D47285562](https://our.internmc.facebook.com/intern/diff/D47285562/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104765
Approved by: https://github.com/fduwjj
2023-07-10 08:00:55 +00:00
Andrew Gu
9db8ad7f1d [FSDP] Support unfreezing params for reshard-only hook (#104186)
This fixes https://github.com/pytorch/pytorch/issues/104148 (unfreezing parameters after `n` steps).

- This fixes a bug where we did not delete the post-backward hook state properly for the `requires_grad=False` case.
- This makes the `already_resharded` correct for `SHARD_GRAD_OP`.
- This generalizes `_clear_grads_if_needed()` to `_reset_flat_param_grad_info_if_needed()` to additionally include propagating the original parameters' `requires_grad` to the flat parameter.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104186
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2023-06-28 11:04:57 +00:00
Chien-Chin Huang
0ae4c4d417 [FSDP][optim_state_dict] Avoid calling optim.state_dict() to get the initial empty states (#103609)

Users may prefix the keys of the optim state_dict. Using `optim.state_dict()` to get the initial states is brittle. This PR removes the call to `optim.state_dict()` and directly infers the empty states from the input states.

Differential Revision: [D46729119](https://our.internmc.facebook.com/intern/diff/D46729119/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103609
Approved by: https://github.com/awgu
2023-06-20 22:11:58 +00:00
Iris
7dd0f525b5 [FSDP][4/n]Update use_dtensor option for _optim_utils.py (#103599)
Same as https://github.com/pytorch/pytorch/pull/103069 (this branch is corrupted so have to re-submit).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103599
Approved by: https://github.com/fegin
2023-06-14 20:18:33 +00:00
Rohan Varma
dfa64fddeb [FSDP] Fix for optim state dict (#102901)
Fix for HSDP + use_orig_params where we need to pass in the PG that
might not be the default.

Differential Revision: [D46417327](https://our.internmc.facebook.com/intern/diff/D46417327/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102901
Approved by: https://github.com/wz337
2023-06-06 20:21:23 +00:00
medivh-xp
8b7bd81902 determined collective device by _get_pg_default_device rather than explicit cuda (#101533)
There are many communication operations on ShardedTensor in the FSDP state dict. They use the externally passed-in pg (or the default pg), which currently supports CUDA devices. Before communication, the data is implicitly moved to CUDA (it is essentially being moved to the memory type required by the pg, not to the compute device type). Similarly, when users run FSDP on a custom backend, they pass in a custom pg that does not support CUDA devices, which may cause FSDP to not work properly in some cases. This PR obtains the device supported by the pg through _get_pg_default_device during communication and moves the data to it when needed.
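
A hedged sketch of the approach (the private helper is the one named in the commit; the wrapper function is illustrative):

```python
import torch
from torch.distributed.distributed_c10d import _get_pg_default_device

def move_for_collective(tensor: torch.Tensor, pg) -> torch.Tensor:
    # Move to whatever device the process group communicates on, instead of
    # hard-coding .cuda().
    pg_device = _get_pg_default_device(pg)
    return tensor if tensor.device == pg_device else tensor.to(pg_device)
```
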
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101533
Approved by: https://github.com/awgu
2023-05-24 13:48:43 +00:00
Edward Z. Yang
f65732552e Support FakeTensor with FlatParameter (#101987)
In this PR we turn FlatParameter into a virtual tensor subclass
which doesn't actually ever get instantiated: __new__ will create
a Parameter instead (or a FakeTensor, if necessary).
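
An illustrative sketch of the "virtual subclass" trick described above (simplified; the real FlatParameter does considerably more):

```python
import torch
import torch.nn as nn

class VirtualFlatParam(nn.Parameter):
    def __new__(cls, data: torch.Tensor, requires_grad: bool = True):
        # Hand back a plain Parameter, so an instance of this subclass is
        # never actually materialized.
        return nn.Parameter(data, requires_grad)

p = VirtualFlatParam(torch.zeros(4))
assert type(p) is nn.Parameter  # not VirtualFlatParam
```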

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101987
Approved by: https://github.com/awgu, https://github.com/eellison
2023-05-23 23:12:08 +00:00
Yanli Zhao
ca1cf434e7 Not flatten states when use_orig_param is True and sharding is NO_SHARD (#100189)
When use_orig_param is True and sharding is NO_SHARD, parameters and states are not flattened, so optimizer states should not be flattened either. The unit test fails without the fix.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100189
Approved by: https://github.com/awgu
2023-04-27 23:47:01 +00:00
medivh-xp
859e82a7a9 Making fsdp device-agnostic for custom-backend which implement cuda-semantics (#99024)
Consider a custom backend implemented via privateuse1 with semantics identical to CUDA (CUDA is that popular), named for example 'my_device' and registered under the module name torch.my_device.

This PR aims to satisfy the constraints of such a backend so that it can be directly integrated into the current FSDP implementation.

The main issues addressed are:

#### 1. Device decision for FSDP wrapping of Modules without Parameters

Users typically organize FSDP code as follows:
```python
m = Module().to('my_device:0')
fsdp_m = FSDP(m)
```
or like this:
```python
m = Module()
fsdp_m = FSDP(m, device_id=torch.device('my_device', 0))
```
If the model has Parameters, everything works fine because FSDP prioritizes the device where the Parameters are located. However, for Modules without Parameters, the to() call has no side effects, and FSDP assumes the current CUDA device, which prevents the use of any device other than the current CUDA device for such Modules. Therefore, when FSDP is called with a device_id argument, that configuration takes top priority.

#### 2. Abstraction of a cuda-like device

Now, in addition to compute_device, _FSDPState includes a device_handler member. This device_handler is simply a reference to either torch.cuda or torch.my_device. From now on, code that works with _FSDPState should use state.device_handler to create, wait on, or synchronize streams, just as it previously used torch.cuda.
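
A hedged sketch of the device-handler idea (the helper name is illustrative):

```python
import torch

def get_device_handle(compute_device: torch.device):
    # Resolve a torch.cuda-like module (torch.cuda, torch.my_device, ...) from
    # the compute device's type, so stream create/wait/sync calls stay generic.
    return getattr(torch, compute_device.type)
```
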
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99024
Approved by: https://github.com/awgu
2023-04-27 04:13:28 +00:00
Chien-Chin Huang
3de7fd461a [FSDP][Reland] Include duplicate parameters and modules when calling named_parameters and named_modules (#99448)
The default behavior of `named_parameters` and `named_modules` is to remove duplicated parameters and modules. However, in FSDP, we need to know which parameters are shared. As a result, setting `remove_duplicate` to False is required in FSDP. Without setting `remove_duplicate` to False, FSDP won't be able to discover shared weights in some cases (e.g., when the shared weights are in the same module or there are shared modules).

The previous PR is reverted due to some modules overwriting the signature of `named_parameters()`. This new PR adds a workaround for the case.
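
A minimal illustration of the `remove_duplicate` behavior described above:

```python
import torch.nn as nn

lin = nn.Linear(4, 4)
m = nn.Sequential(lin, lin)  # the same module (and weights) used twice

print(len(list(m.named_parameters())))                        # 2: duplicates removed
print(len(list(m.named_parameters(remove_duplicate=False))))  # 4: aliases surfaced
```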

Differential Revision: [D45065973](https://our.internmc.facebook.com/intern/diff/D45065973/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99448
Approved by: https://github.com/zhaojuanmao
2023-04-25 00:27:07 +00:00
Chien-Chin Huang
7876c503b7 [FSDP][optim_state_dict] Consolidate rank0_only load logic (#99647)
Following up on https://github.com/pytorch/pytorch/pull/99624, this PR consolidates the `use_orig_params=False` logic with `use_orig_params=True` so the same logic is used to load the optimizer checkpoint when rank0_only is True.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99647
Approved by: https://github.com/wz337
2023-04-21 20:29:54 +00:00
Chien-Chin Huang
dd07dab1c7 [FSDP][optim_state_dict] Support rank0_only when use_orig_params is on (#99624)
This PR makes the `use_orig_params=True` case support rank0_only loading for the optim state_dict. The implementation is different from `use_orig_params=False`: the `use_orig_params=False` implementation first flattens the parameters on rank0 and then broadcasts the states, while this implementation broadcasts the states while doing the flattening. This implementation is slower, as it broadcasts the original parameters instead of the flattened ones, but it is simpler. Since loading usually happens once per training run, the performance difference can be ignored. In the next PR, we will consolidate the implementations in favor of the simpler one.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99624
Approved by: https://github.com/wz337
2023-04-21 20:09:19 +00:00
Iris
a2a4144256 [FSDP]Make param_groups optional for FSDP optim state dict (#99117)
Make param_groups optional for FSDP optim state dict and add corresponding test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99117
Approved by: https://github.com/fegin, https://github.com/zhaojuanmao
2023-04-20 06:34:40 +00:00