Commit Graph

91 Commits

Author SHA1 Message Date
Chien-Chin Huang
591cb776af [FSDP][state_dict][optim_state_dict] Log slow optim and model state_dict paths (#108290)
This PR adds SimpleProfiler for FSDP state_dict/load_state_dict logging purposes. SimpleProfiler uses class variables to record profiling results and does everything in Python, which can be slow, so it is only suitable for logging slow actions such as initialization and state_dict/load_state_dict.

This PR uses SimpleProfiler to log some critical/slow paths of the model and optimizer state_dict/load_state_dict.
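
For reference, a minimal sketch of a class-variable-based profiler in this spirit (names and structure are illustrative, not the actual SimpleProfiler API):

```python
import time
from collections import defaultdict
from contextlib import contextmanager

class _SketchProfiler:
    # Class variable: results accumulate across all call sites (hypothetical layout).
    results = defaultdict(float)

    @classmethod
    @contextmanager
    def profile(cls, key):
        start = time.monotonic()
        try:
            yield
        finally:
            cls.results[key] += time.monotonic() - start

    @classmethod
    def dump_and_reset(cls, prefix=""):
        # Pure-Python bookkeeping, so only suitable for slow paths such as
        # initialization and state_dict/load_state_dict.
        for key, seconds in cls.results.items():
            print(f"{prefix}{key}: {seconds:.3f}s")
        cls.results.clear()
```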

Differential Revision: [D48774406](https://our.internmc.facebook.com/intern/diff/D48774406/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108290
Approved by: https://github.com/wz337
2023-09-01 06:57:59 +00:00
PyTorch MergeBot
ab5b4c4419 Revert "[HSDP] Add device_mesh to FSDP and add dtensor state_dict support for HSDP (#107533)"
This reverts commit cc220e45a8.

Reverted https://github.com/pytorch/pytorch/pull/107533 on behalf of https://github.com/huydhn due to Sorry for reverting this, but it is failing in trunk with the same failure on test_dynamo_distributed cc220e45a8 ([comment](https://github.com/pytorch/pytorch/pull/107533#issuecomment-1701983247))
2023-09-01 01:26:30 +00:00
wz337
cc220e45a8 [HSDP] Add device_mesh to FSDP and add dtensor state_dict support for HSDP (#107533)
This PR:
1) Add a device_mesh kwarg to FSDP and remove init_device_mesh() from _runtime_utils.py, as device_mesh is now passed in by the user as a kwarg (see the sketch below).
2) Change the use_dtensor flag for state_dict_config and optim_state_dict_config to be private. If device_mesh is used with a sharded model/optim state dict, the _use_dtensor flag is set to True and the model/optim state dict returns a DTensor state_dict. Otherwise, the _use_dtensor flag is set to False and the model/optim state dict returns a sharded_tensor state_dict.
3) Update _optim_utils.py, _shard_utils.py, and _state_dict_utils.py to add support for HSDP to return a 2D DTensor state_dict.
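
A minimal usage sketch of the new kwarg, assuming a process group is already initialized (import paths, mesh shape, and the toy model are illustrative and may differ from the version at this commit):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardedStateDictConfig, ShardingStrategy, StateDictType

# Assumes the default process group has 8 ranks.
# 2D mesh for HSDP: replicate across the first dim, shard across the second.
mesh_2d = init_device_mesh("cuda", (2, 4))

model = torch.nn.Linear(8, 8).cuda()
fsdp_model = FSDP(
    model,
    device_mesh=mesh_2d,                                # kwarg added by this PR
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)

# With a sharded state_dict config, _use_dtensor is flipped on internally and the
# returned state_dict contains 2D DTensors instead of ShardedTensors.
FSDP.set_state_dict_type(
    fsdp_model, StateDictType.SHARDED_STATE_DICT, ShardedStateDictConfig()
)
sharded_sd = fsdp_model.state_dict()
```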

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107533
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wanchaol
2023-09-01 00:15:00 +00:00
Chien-Chin Huang
f6a9c15421 [FSDP][state_dict] Make optim_state_dict_to_load work with use_orig_param=False + NO_SHARD (#107185)
Summary: As title

Test Plan: CI

Reviewed By: wz337

Differential Revision: D48329724

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107185
Approved by: https://github.com/fegin
2023-08-15 21:42:41 +00:00
Jane Xu
7e47343d64 [BE] document more of FSDP checkpointing logic with a sprinkle of cleaning (#106069)
This PR should not make any functional difference. It:
- adds clearer documentation
- clarifies a type
- revises minor typos
- swaps a .keys for a .items call on a dictionary

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106069
Approved by: https://github.com/awgu
2023-08-02 17:19:04 +00:00
Aaron Gokaslan
6d43c89f37 [BE]: Update Ruff to 0.0.280 (#105724)
Removes unused loop values in Python dictionary iteration. Automated fix from Ruff master.
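
Roughly the kind of rewrite applied (illustrative):

```python
d = {"a": 1, "b": 2}

# Before: the value is bound on every iteration but never used.
for key, value in d.items():
    print(key)

# After: iterate over the keys directly.
for key in d:
    print(key)
```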

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105724
Approved by: https://github.com/ezyang, https://github.com/janeyx99
2023-07-22 23:03:34 +00:00
Michael Voznesensky
a832967627 Migrate tuple(handle) -> handle (#104488)
We strengthen the invariant that one FSDP-managed module has one flat parameter, and remove unused code that would have supported a 1:many module-to-flat-parameter mapping.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104488
Approved by: https://github.com/awgu
2023-07-19 22:33:35 +00:00
Nikita Shulga
5837e95d30 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violations (i.e., when a default arg is set to None but the type is not annotated as Optional); see the example below.
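
An illustrative example of the kind of PEP-484 fix applied throughout:

```python
from typing import Optional

# Before: violates PEP 484 -- the default is None but the annotation is not Optional.
def resize(shape: tuple, fill: float = None) -> None: ...

# After: the implicit-Optional default is made explicit.
def resize(shape: tuple, fill: Optional[float] = None) -> None: ...
```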
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add an assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`

Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add a hack to squash the older libstdc++ from the conda environment in favor of the one from the OS to `.ci/docker/install_conda.sh`
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds lose the ability to catch exceptions (probably because they link with cupti statically, but I could not find where that is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-15 20:30:20 +00:00
PyTorch MergeBot
15fd1ea118 Revert "[Reland] Update mypy to 1.4.1 (#105227)"
This reverts commit c9c4f8efc3.

Reverted https://github.com/pytorch/pytorch/pull/105227 on behalf of https://github.com/atalman due to trying to mitigate ci sev #105248 ([comment](https://github.com/pytorch/pytorch/pull/105227#issuecomment-1636510935))
2023-07-14 22:28:35 +00:00
Nikita Shulga
c9c4f8efc3 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violations (i.e., when a default arg is set to None but the type is not annotated as Optional)
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add an assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-14 20:45:12 +00:00
Chien-Chin Huang
46154c4c35 [FSDP][optim_state_dict] The correct way to initialize optimizer states if the corresponding param is empty (#104765)
When using KeyedOptimizer.init_state(), some optimizers initialize the states even if the param is empty (size() == 0) while others avoid initializing them, and there is no way for FSDP to tell which case applies. Instead, FSDP should look up `optim.state`. Fortunately, `optim.state` does not rely on FQNs, which some internal users change.
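
A minimal sketch of the lookup idea (the check itself is illustrative): the optimizer's state is keyed by the parameter object rather than by FQN, so it can be inspected directly instead of guessing what the optimizer did.

```python
import torch

param = torch.nn.Parameter(torch.empty(0))       # empty param (numel() == 0)
optim = torch.optim.Adam([param], lr=1e-3)

# optim.state is keyed by the parameter object itself, not by FQN, so renamed
# FQNs do not matter here. An entry only exists if the optimizer initialized it.
has_state = param in optim.state and len(optim.state[param]) > 0
print(has_state)  # False for a lazily initialized optimizer before any step()
```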

Differential Revision: [D47285562](https://our.internmc.facebook.com/intern/diff/D47285562/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104765
Approved by: https://github.com/fduwjj
2023-07-10 08:00:55 +00:00
Andrew Gu
9db8ad7f1d [FSDP] Support unfreezing params for reshard-only hook (#104186)
This fixes https://github.com/pytorch/pytorch/issues/104148 (unfreezing parameters after `n` steps).

- This fixes a bug where we did not delete the post-backward hook state properly for the `requires_grad=False` case.
- This makes `already_resharded` correct for `SHARD_GRAD_OP`.
- This generalizes `_clear_grads_if_needed()` to `_reset_flat_param_grad_info_if_needed()` to additionally include propagating the original parameters' `requires_grad` to the flat parameter.
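
A hedged usage sketch of the now-supported pattern (the model, step count, and optimizer are placeholders; assumes the process group is initialized):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).cuda()
for p in model[0].parameters():
    p.requires_grad = False                      # start with a frozen submodule

fsdp_model = FSDP(model, use_orig_params=True)   # original params stay accessible
optim = torch.optim.SGD(fsdp_model.parameters(), lr=1e-2)

for step in range(20):
    loss = fsdp_model(torch.randn(4, 8, device="cuda")).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()
    if step == 9:                                # unfreeze after n steps
        for p in fsdp_model.module[0].parameters():
            p.requires_grad = True
```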
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104186
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2023-06-28 11:04:57 +00:00
Chien-Chin Huang
0ae4c4d417 [FSDP][optim_state_dict] Avoid calling optim.state_dict() to get the initial empty states (#103609)

Users may prefix the keys of the optim state_dict, so using `optim.state_dict()` to get the initial states is brittle. This PR removes the call to `optim.state_dict()` and directly infers the empty states from the input states.

Differential Revision: [D46729119](https://our.internmc.facebook.com/intern/diff/D46729119/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103609
Approved by: https://github.com/awgu
2023-06-20 22:11:58 +00:00
Iris
7dd0f525b5 [FSDP][4/n]Update use_dtensor option for _optim_utils.py (#103599)
Same as https://github.com/pytorch/pytorch/pull/103069 (this branch is corrupted so have to re-submit).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103599
Approved by: https://github.com/fegin
2023-06-14 20:18:33 +00:00
Rohan Varma
dfa64fddeb [FSDP] Fix for optim state dict (#102901)
Fix for HSDP + use_orig_params where we need to pass in the PG that
might not be the default.

Differential Revision: [D46417327](https://our.internmc.facebook.com/intern/diff/D46417327/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102901
Approved by: https://github.com/wz337
2023-06-06 20:21:23 +00:00
medivh-xp
8b7bd81902 determined collective device by _get_pg_default_device rather than explicit cuda (#101533)
FSDP's state dict performs many communication operations on ShardedTensor. These use the externally passed-in pg (or the default pg), which currently assumes CUDA devices: before communication, the data is implicitly moved to CUDA (it is essentially being moved to the device type required by the pg, not the compute device type). When users run FSDP on a custom backend, they pass in a custom pg that does not support CUDA devices, which may cause FSDP to not work properly in some cases. This PR obtains the device type supported by the pg through _get_pg_default_device during communication and moves the data to it when needed.
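
A hedged sketch of the idea (the helper is private, and its exact location and signature may differ across versions):

```python
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_pg_default_device

def _move_for_comm(tensor: torch.Tensor, pg: dist.ProcessGroup) -> torch.Tensor:
    # Move to whatever device the process group communicates on (e.g. "cuda" for
    # NCCL, or a custom backend's device), instead of hard-coding .cuda().
    pg_device = _get_pg_default_device(pg)
    return tensor if tensor.device == pg_device else tensor.to(pg_device)
```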
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101533
Approved by: https://github.com/awgu
2023-05-24 13:48:43 +00:00
Edward Z. Yang
f65732552e Support FakeTensor with FlatParameter (#101987)
In this PR we turn FlatParameter into a virtual tensor subclass
which doesn't actually ever get instantiated: __new__ will create
a Parameter instead (or a FakeTensor, if necessary).
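
A minimal, self-contained sketch of the pattern (a generic illustration, not the actual FlatParameter code):

```python
import torch
from torch.nn import Parameter

class VirtualParam(Parameter):
    """Never actually instantiated: __new__ hands back a plain Parameter,
    so no VirtualParam instance ever exists at runtime."""

    def __new__(cls, data=None, requires_grad=True):
        data = torch.empty(0) if data is None else data
        return Parameter(data, requires_grad=requires_grad)

p = VirtualParam(torch.randn(3))
print(type(p))  # <class 'torch.nn.parameter.Parameter'>
```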

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101987
Approved by: https://github.com/awgu, https://github.com/eellison
2023-05-23 23:12:08 +00:00
Yanli Zhao
ca1cf434e7 Not flatten states when use_orig_param is True and sharding is NO_SHARD (#100189)
When use_orig_param is True and sharding is NO_SHARD, parameters and states are not flattened, so optimizer states should not be flattened either. The unit test fails without the fix.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100189
Approved by: https://github.com/awgu
2023-04-27 23:47:01 +00:00
medivh-xp
859e82a7a9 Making fsdp device-agnostic for custom-backend which implement cuda-semantics (#99024)
Consider a custom backend implemented on privateuse1 with semantics identical to CUDA (CUDA being the most common case), named for example 'my_device' and registered under the module name torch.my_device.

This PR aims to satisfy the constraints of such a backend so that it can be integrated directly into the current FSDP implementation.

The main issues addressed are:

#### 1. Device decision for FSDP wrapping of Modules without Parameters

Users typically organize FSDP code as follows:
```python
m = Module().to('my_device:0')
fsdp_m = FSDP(m)
```
or like this:
```python
m = Module()
fsdp_m = FSDP(m, device_id=torch.device('my_device', 0))
```
If the model has Parameters, everything works fine because FSDP prioritizes the device where the Parameters are located. However, for Modules without Parameters, the to() call has no side effects, and FSDP assumes the current CUDA device, which prevents using any device other than the current CUDA device for such Modules. Therefore, when FSDP is called with a device_id argument, this configuration takes top priority.

#### 2. Abstraction of a cuda-like device

Now, in addition to compute_device, _FSDPState includes a device_handler member. This device_handler is just a reference to either torch.cuda or torch.my_device. From now on, code that operates on _FSDPState should use state.device_handler to create, wait on, or synchronize streams, just as it previously used torch.cuda.
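
A hedged sketch of the device-handler idea (class and attribute names are illustrative):

```python
import torch

class _FSDPStateSketch:
    def __init__(self, compute_device: torch.device):
        self.compute_device = compute_device
        # torch.cuda and a cuda-like backend module (e.g. torch.my_device) expose
        # the same stream API, so either module can serve as the handler.
        self.device_handler = getattr(torch, compute_device.type)

state = _FSDPStateSketch(torch.device("cuda", 0))
stream = state.device_handler.Stream()
with state.device_handler.stream(stream):
    pass
state.device_handler.synchronize()
```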
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99024
Approved by: https://github.com/awgu
2023-04-27 04:13:28 +00:00
Chien-Chin Huang
3de7fd461a [FSDP][Reland] Include duplicate parameters and modules when calling named_parameters and named_modules (#99448)
The default option of `named_parameters` and `named_modules` is to remove the duplicated parameters and modules. However, in FSDP, we need to know what parameters are shared. As a result, setting `remove_duplicate` to False is required in FSDP. Without setting `remove_duplicate` to False, FSDP won't be able to discover shared weights in some cases (e.g., the shared weights are in the same module or there are shared modules).

The previous PR was reverted because some modules overwrite the signature of `named_parameters()`. This new PR adds a workaround for that case.
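
For example, with a shared submodule (illustrative):

```python
import torch.nn as nn

lin = nn.Linear(4, 4)
model = nn.Sequential(lin, nn.ReLU(), lin)   # the same module appears twice

# Default: duplicates are removed, so the sharing is invisible.
print(len(list(model.named_parameters())))                          # 2
# remove_duplicate=False: both paths to the shared weight are reported.
print(len(list(model.named_parameters(remove_duplicate=False))))    # 4
```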

Differential Revision: [D45065973](https://our.internmc.facebook.com/intern/diff/D45065973/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99448
Approved by: https://github.com/zhaojuanmao
2023-04-25 00:27:07 +00:00
Chien-Chin Huang
7876c503b7 [FSDP][optim_state_dict] Consolidate rank0_only load logic (#99647)
Following up https://github.com/pytorch/pytorch/pull/99624, this PR consolidates the `use_orig_params=False` logic with the `use_orig_params=True` logic so that both use the same code path to load the optimizer checkpoint when rank0_only is True.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99647
Approved by: https://github.com/wz337
2023-04-21 20:29:54 +00:00
Chien-Chin Huang
dd07dab1c7 [FSDP][optim_state_dict] Support rank0_only when use_orig_params is on (#99624)
This PR makes the `use_orig_params=True` case support rank0_only loading for the optim state_dict. The implementation is different from `use_orig_params=False`: the `use_orig_params=False` implementation first flattens the parameters on rank0 and then broadcasts the states, while this implementation broadcasts the states while doing the flattening. This implementation is slower because it broadcasts the original parameters instead of the flattened ones, but it is simpler. As loading usually happens once per training run, the performance difference can be ignored. In the next PR, we will consolidate the implementations in favor of the simpler one.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99624
Approved by: https://github.com/wz337
2023-04-21 20:09:19 +00:00
Iris
a2a4144256 [FSDP]Make param_groups optional for FSDP optim state dict (#99117)
Make param_groups optional for FSDP optim state dict and add corresponding test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99117
Approved by: https://github.com/fegin, https://github.com/zhaojuanmao
2023-04-20 06:34:40 +00:00
Yanli Zhao
6ca991cacf [Composable API] Add fully_shard debug function to print sharded tree structure, module names and managed param fqns (#99133)
Add a fully_shard debug function to print the sharded tree structure in the format shown in the screenshot below, and return the module names and their managed parameter FQNs as well.

![Screenshot 2023-04-18 at 5 14 54 PM](https://user-images.githubusercontent.com/48731194/232931628-169a63a9-b4d5-4902-9cfd-f40113f3ec98.png)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99133
Approved by: https://github.com/rohan-varma
2023-04-19 19:27:43 +00:00
Chien-Chin Huang
bdaf32261f [FSDP] Ensure that customized non tensor optimizer state can be saved (#99214)
The current logic does not actually handle all different non-tensor optimizer states correctly. This PR fixes the issue and adds a test.

This PR will solve https://github.com/pytorch/pytorch/issues/99079

Differential Revision: [D45021331](https://our.internmc.facebook.com/intern/diff/D45021331/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99214
Approved by: https://github.com/awgu, https://github.com/awaelchli
2023-04-17 21:54:16 +00:00
Nikita Shulga
ccc5d1daec Revert D44897935: Multisect successfully blamed D44897935 for test or build failures (#99353)
Summary:
This diff is reverting D44897935
D44897935: [FSDP] Include duplicate parameters and modules when calling named_parameters and named_modules (#98912) by fegin has been identified to be causing the following test or build failures:

Tests affected:
- [caffe2/torch/fb/module_factory/sync_sgd/tests:test_pyper_data_parallel_wrapper - caffe2.torch.fb.module_factory.sync_sgd.tests.test_pyper_data_parallel_wrapper.PyPerDataParallelWrapperTest: test_fsdp_submodules_pyper](https://www.internalfb.com/intern/test/562950025957458/)

Here's the Multisect link:
https://www.internalfb.com/multisect/1893714
Here are the tasks that are relevant to this breakage:

We're generating a revert to back out the changes in this diff, please note the backout may land if someone accepts it.

If you believe this diff has been generated in error you may Commandeer and Abandon it.

Test Plan: NA

Reviewed By: fegin

Differential Revision: D45027286

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99353
Approved by: https://github.com/izaitsevfb, https://github.com/fegin
2023-04-17 20:53:10 +00:00
Chien-Chin Huang
8e328762ff [FSDP] Include duplicate parameters and modules when calling named_parameters and named_modules (#98912)
The default option of `named_parameters` and `named_modules` is to remove the duplicated parameters and modules. However, in FSDP, we need to know what parameters are shared. As a result, setting `remove_duplicate` to False is required in FSDP. Without setting `remove_duplicate` to False, FSDP won't be able to discover shared weights in some cases (e.g., the shared weights are in the same module or there are shared modules).

Differential Revision: [D44897935](https://our.internmc.facebook.com/intern/diff/D44897935/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98912
Approved by: https://github.com/awgu
2023-04-13 20:37:11 +00:00
Andrew Gu
662a8cf74d [FSDP][8/N] Simplify addr padding internals (#97796)
This is a follow-up to the last PR to greatly simplify the approach. This should be much cleaner.

**Details**
Let `N` denote the number of original parameters flattened into a given flat parameter with `M` extra padding tensors.
- `_numels_with_padding`: length `N + M`
- `_is_padding_mask`: length `N + M`
- `_numels`, `_param_infos`, `_shapes`, `_fqns`, `_param_extensions`: length `N`

`_shard_param_indices` and `_shard_param_offsets` were used to determine (1) if a given original parameter is in the local shard and if so, then (2) what is its offset in the _sharded_ flat parameter, and (3) how many numel are in the _sharded_ flat parameter.

This PR reworks how to achieve (1), (2), and (3) to allow for simplifying the previously mentioned data structures. In particular, it saves one extra tuple `_shard_param_infos: Tuple[_ShardParamInfo, ...]` of length `N` where each `_ShardParamInfo` entry gives exactly the needed info. For example, the offset into the sharded flat parameter is now pre-computed, so we do not need to do `offset = 0; offset += numel_in_shard` over a `for` loop each time now.

For optimizer state dict, `FSDPParamInfo.param_indices` now maps to the indexes with respect to the length `N` data structures, not the length `N + M` ones. The only purpose of `param_indices` is to be able to index into `flat_param._shard_param_infos[i]` to get the contained info to flatten the unsharded original parameter optimizer state and extract the part in the local shard.
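
A hedged sketch of the per-parameter shard info described above (field names are approximate, not the exact `_ShardParamInfo` definition):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ShardParamInfoSketch:
    in_shard: bool                         # (1) does this original param overlap the local shard?
    offset_in_shard: Optional[int] = None  # (2) precomputed offset into the *sharded* flat param
    numel_in_shard: Optional[int] = None   # (3) how many elements land in the local shard

# One entry per original parameter (length N), indexed the same way as
# _param_infos/_shapes/_fqns, so no running offset accumulation is needed.
```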

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97796
Approved by: https://github.com/rohan-varma
2023-03-28 22:19:44 +00:00
Andrew Gu
1c15cd48e2 [FSDP][7/N] Add alignment padding for use_orig_params=True (#97667)
This PR adds intra-`FlatParameter` 16-byte alignment padding to the `use_orig_params=True` code path to avoid clones in TorchInductor.

**Approach**
The `FlatParameter` maintains several data structures about its original parameters. Notably, the data structures `_param_infos`, `_shapes`, `_numels`, and `_fqns` have the same length and index in the same way.

This PR treats alignment padding _like_ an original parameter in that the padding gets flattened into the `FlatParameter`. Therefore, it must be reflected in the aforementioned data structures. However, given the way in which the data structures are used, we choose to do the following if the `i`th tensor flattened into the `FlatParameter` is padding:
- `_numels[i]` is the numel of padding
- `_param_infos[i] == _shapes[i] == _fqns[i] == None`

This choice is because (1) we must record the padding numel to account for it (e.g. for views) and (2) we prefer to preserve the invariant that the data structures index in the same way over avoiding `None` entries.
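
A toy illustration of this indexing convention (all values invented; strings stand in for the real ParamInfo objects):

```python
# Suppose the 2nd flattened tensor is alignment padding: 1 fp32 element (4 bytes)
# brings the preceding 127-element param up to a 16-byte boundary.
_numels      = (127, 1, 256)                       # padding numel recorded so views line up
_param_infos = ("info:l0.weight", None, "info:l1.weight")  # None marks the padding entry
_shapes      = ((127,), None, (16, 16))
_fqns        = ("l0.weight", None, "l1.weight")
```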

To ease the burden of other FSDP developers, we separate the parameter flattening logic:
- `_init_flat_param_and_metadata()`: This should be called only once in the `FlatParamHandle` constructor. The `FlatParameter` metadata is assumed to be static thereafter.
- `flatten_tensors()` / `flatten_tensors_into_flat_param()`: These can be used for optimizer and model state dict and can be called after construction time.

This separation allows `_init_flat_param_and_metadata()` to contain the much heavier metadata logic, while keeping the latter methods much lighter. The only constraint is that the alignment padding logic must be kept consistent between the two, but this should be worth the simpler interface.

**Testing**
- This PR directly modifies the `use_orig_params=True` code path, so all existing tests passing gives good signal.
    - Some existing unit tests had to be adjusted to account for the alignment padding.
- This PR adds two tests in `test_fsdp_flatten_params.py` to explicitly test the sharding metadata with alignment for both parameter full precision and mixed precision since the latter requires possibly more padding elements due to the decreased per-element size.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97667
Approved by: https://github.com/rohan-varma
2023-03-28 01:46:43 +00:00
Andrew Gu
b9049a7f11 [FSDP][6/N] Rename param/module name helpers for clarity (#97666)
This is an easy PR. It has some remaining local changes that I had that I felt clarified naming.
- `_param_fqns` -> `_param_name_infos` since it returns a tuple of `fqn, param_name, module_name`, not only `fqn`. (similarly for `_shared_param_fqns` -> `_shared_param_name_infos`)
- nit: `parameter_module_names` -> `param_module_names` for consistency since we almost never fully spell out `parameter`. (similarly for `shared_parameter_module_names` -> `shared_param_module_names`)
- nit: `full_fqn` -> `fqn_from_global_root`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97666
Approved by: https://github.com/rohan-varma
2023-03-28 01:46:43 +00:00
Andrew Gu
30a6ed34a0 [FSDP][5/N] Lift FSDPParamInfo to use FlatParamHandle (#97665)
This PR changes `FSDPParamInfo` in `_optim_utils.py` to save the `FlatParamHandle`, not directly the `FlatParameter`. This is in preparation for subsequent PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97665
Approved by: https://github.com/rohan-varma
2023-03-28 01:46:43 +00:00
Andrew Gu
c622559968 [FSDP][3/N] Minor fixes (rename, assert message) (#97663)
This is an easy PR.
- It renames `_shard_indices` to `_shard_param_indices` for consistency.
- It fixes an old mention of `comm_module` in an assert message.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97663
Approved by: https://github.com/rohan-varma
2023-03-28 01:46:43 +00:00
Andrew Gu
a27882ecd1 [FSDP][2/N] Rename "flattened parameter" -> "flat parameter" (pt. 2) (#97662)
From our recent experience, we refer to FSDP's `FlatParameter` as "flat parameter", not "flattened parameter". This PR renames that in `_optim_utils.py`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97662
Approved by: https://github.com/rohan-varma
2023-03-28 01:46:43 +00:00
Kazuaki Ishizaki
35fd5c548e Fix typos under torch/distributed directory (#95638)
This PR fixes typos in comments and messages of `.py` files under torch/distributed directory

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95638
Approved by: https://github.com/usamah1, https://github.com/H-Huang, https://github.com/kit1980
2023-03-27 21:13:44 +00:00
Chien-Chin Huang
793cb3f424 [FSDP][optim_state_dict] Print out more useful error message for optim_state_dict (#96860)
Summary: Print out more useful error message for optim_state_dict

Test Plan: CI

Reviewed By: wz337

Differential Revision: D43556073

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96860
Approved by: https://github.com/rohan-varma, https://github.com/wz337
2023-03-21 01:04:24 +00:00
Aaron Gokaslan
5471621497 [BE] Remove unnecessary dict comprehensions (#97116)
Removes unnecessary dict comprehensions, optimizing the creation of dicts from iterables.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97116
Approved by: https://github.com/kit1980
2023-03-20 00:56:57 +00:00
Chien-Chin Huang
15e58c19ec [FSDP][optim_state_dict] Copy step tensor so that each parameter has its own step (#96313)
Summary: When parameters are flattened, multiple parameters share the same step tensor. When unflattening the parameters, the current implementation still makes these parameters share the same step. While this is not wrong, some training infra gets confused by the shared tensor storage. This PR fixes the issue.

Test Plan: CI

Reviewed By: awgu

Differential Revision: D43893592

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96313
Approved by: https://github.com/zhaojuanmao
2023-03-10 04:51:30 +00:00
Chien-Chin Huang
92edac72aa [FSDP][optim_state_dict] Fix a memory leakage in optim_state_dict (#96263)
Summary: The original code uses a class variable to store the flat_parameter result, which could cause a memory leak.

Test Plan: CI and a E2E run

Reviewed By: awgu

Differential Revision: D43893577

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96263
Approved by: https://github.com/zhaojuanmao
2023-03-08 08:43:42 +00:00
Colin Taylor
16a4579335 [FSDP] [composable] [BE] warning should read TorchRec, not DMP (#95010)
Summary: as title

Test Plan: N/A

Differential Revision: D43375189

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95010
Approved by: https://github.com/awgu, https://github.com/fegin
2023-02-17 03:31:30 +00:00
Aaron Gokaslan
67d9790985 [BE] Apply almost all remaining flake8-comprehension checks (#94676)
Applies the remaining flake8-comprehensions fixes and checks. This change replaces all remaining unnecessary generator expressions with list/dict/set comprehensions, which are more succinct, more performant, and better supported by our torch.jit compiler. It also removes useless generators such as `set(a for a in b)`, resolving them into just the set call.
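
For instance (illustrative):

```python
b = range(5)

# Before: a useless generator expression inside the constructor call.
s1 = set(a for a in b)
d1 = dict((k, k * k) for k in b)

# After: the direct call or a comprehension is clearer and faster.
s2 = set(b)
d2 = {k: k * k for k in b}
assert s1 == s2 and d1 == d2
```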

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
2023-02-12 01:01:25 +00:00
Aaron Gokaslan
3d82d8d0ed [BE] Enable more flake8-comprehensions checks (#94601)
I applied some flake8 fixes and enabled checking for them in the linter. I also enabled some checks for my previous comprehensions PR.

This is a follow up to #94323 where I enable the flake8 checkers for the fixes I made and fix a few more of them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94601
Approved by: https://github.com/ezyang
2023-02-10 23:40:29 +00:00
Chien-Chin Huang
2180a0dc0c [FSDP][optim_state_dict] Remove the dead code (#94448)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94448
Approved by: https://github.com/awgu
2023-02-09 06:32:40 +00:00
Aaron Gokaslan
3ce1ebb6fb Apply some safe comprehension optimizations (#94323)
Optimizes away unnecessary collection casts, unnecessary calls to list, tuple, and dict, and simplifies calls to the sorted builtin. This should strictly improve speed and readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94323
Approved by: https://github.com/albanD
2023-02-07 23:53:46 +00:00
Chien-Chin Huang
ab4fe01e72 [FSDP][optim_state_dict] Returns the initial states of the empty parameters for KeyedOptimizer/NamedOptimizer (#94130)
KeyedOptimizer and NamedOptimizer expect the states to exist in the state_dict when `load_state_dict` is called, even if the corresponding parameters are empty (size == 0). This PR adds support to make KeyedOptimizer work with `use_orig_params=True`.

Differential Revision: [D43019458](https://our.internmc.facebook.com/intern/diff/D43019458/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94130
Approved by: https://github.com/rohan-varma
2023-02-07 23:36:56 +00:00
Chien-Chin Huang
bc6d54f6d8 [FSDP][optim_state_dict] Let optim_state_dict ignore the non-FSDP managed parameters that do not reside on the rank (#94129)
When FSDP is used with other parallelism (e.g., TorchRec), some parameters that are not managed by FSDP may not reside on all the ranks (TorchRec is model parallelism). When `use_orig_params=True`, FSDP synchronizes the FQNs among ranks. As a result, a rank may get FQNs that it does not actually own. If such an FQN belongs to a TorchRec-managed parameter, FSDP has to ignore the parameter state; otherwise FSDP does not know how to store the state.

This PR adds the logic to ignore the parameters that are not managed by FSDP and are not on the rank.

Differential Revision: [D42982778](https://our.internmc.facebook.com/intern/diff/D42982778/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94129
Approved by: https://github.com/rohan-varma
2023-02-07 06:29:28 +00:00
Chien-Chin Huang
0f5b6caa16 [FSDP][optim_state_dict] Ignore the state check on rank that does not own the corresponding parameter (#93318)
When a rank does not own a parameter (parameter.numel() == 0), its optim state is not valid and should not be checked against the current saved one.

Differential Revision: [D42865237](https://our.internmc.facebook.com/intern/diff/D42865237/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93318
Approved by: https://github.com/rohan-varma
2023-02-03 00:50:04 +00:00
Chien-Chin Huang
e32d99ae19 [FSDP][optim_state_dict] Make FSDP.optim_state_dict compatible with DMP (#93285)
`torchrec.DistributedModelParallel` overwrites `named_parameters` and is not compatible with `FullyShardedDataParallel`'s optim_state_dict. This PR adds a workaround in `FullyShardedDataParallel` to make both work together.

Differential Revision: [D42764611](https://our.internmc.facebook.com/intern/diff/D42764611/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93285
Approved by: https://github.com/rohan-varma
2023-02-02 23:42:54 +00:00
Andrew Gu
10990734ce [FSDP][2/N] _summon_full_params -> _unshard_params (#92297)
**Overview**
This PR stack will add support for unsharding FSDP's sharded parameters for `fully_shard`. This PR takes the first step by doing some internal refactoring.
- The existing API for wrapper FSDP is the static method `summon_full_params()`, which calls into the helper `_summon_full_params()`.
- This PR refactors:
    - `summon_full_params()` core logic to `_unshard_params()`
    - `_summon_full_params()` to `_unshard_params_recurse()`, which has a `recurse: bool` argument
    - Previous `_unshard_params()` to `_unshard_fsdp_state_params()`, which applies to a single FSDP state

**Details**
- This PR introduces `_get_fsdp_states_with_modules()` and `_get_root_fsdp_states_with_modules()`, which additionally return the modules along with the FSDP states. The modules are needed for handling `FlatParameter` registration.
    - We may be able to remove this if we clean up the `use_orig_params=True` vs. `False` code paths because for `True`, the `FlatParameter` is not registered, meaning that it does not need to be de-registered.
    - Since `fully_shard` requires `use_orig_params=True`, we may not need `_get_fsdp_states_with_modules()` and `_get_root_fsdp_root_modules()`; however, I prefer to make the separation of FSDP state and module explicit for now for clarity.

**Follow-Ups**
- `writeback=True` and `rank0_only=True` raises an error. The previous explanation was:
> is not supported, as model parameter shapes will be different across ranks, and writing to them can lead to inconsistencies across ranks when the context is exited.

I am not exactly sure what the different model parameter shapes refer to. However, I believe that we can support `writeback=True` and `rank0_only=True` by broadcasting the `FlatParameter` from rank 0 in the `finally`, writing back, and freeing. This should not increase the peak memory since rank 0 already holds the unsharded `FlatParameter` in GPU memory before writing back, and nonzero ranks do not have any other unsharded `FlatParameter`s in GPU memory.
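
For context, a hedged usage sketch of the existing wrapper-FSDP API being refactored here (assumes the process group is initialized):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(torch.nn.Linear(8, 8).cuda())

# Unshard the flat parameter, expose the original parameters, then reshard on exit.
with FSDP.summon_full_params(model, recurse=True, writeback=True, rank0_only=False):
    full_weight = model.weight.detach().clone()

# writeback=True together with rank0_only=True currently raises, as discussed above.
```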
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92297
Approved by: https://github.com/rohan-varma
2023-02-02 15:10:14 +00:00
Chien-Chin Huang
888771dc5d [FSDP][optim_state_dict] Fix _is_named_optimizer when the state is empty (#93303)
Optimizer state is not eagerly initialized -- only NamedOptimizer and KeyedOptimizer initialize it eagerly. This PR makes `_is_named_optimizer` work with regular optimizers.
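
For reference, regular torch.optim optimizers only materialize state on the first step (illustrative):

```python
import torch

param = torch.nn.Parameter(torch.randn(4))
optim = torch.optim.Adam([param], lr=1e-3)

print(len(optim.state))   # 0 -- state is lazily initialized, unlike NamedOptimizer/KeyedOptimizer
param.grad = torch.randn(4)
optim.step()
print(len(optim.state))   # 1 -- 'step', 'exp_avg', 'exp_avg_sq' now exist for the param
```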

Differential Revision: [D42858589](https://our.internmc.facebook.com/intern/diff/D42858589/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93303
Approved by: https://github.com/fduwjj
2023-01-31 03:49:26 +00:00
Chien-Chin Huang
a4238976a8 [FSDP][optim_state_dict] Ensure correct devices for tensors when doing all_gather (#92992)
When doing `_all_gather_optim_state`, we need to ensure that `step` tensors are on CPU and other tensors are on GPUs. This PR adds the logic to ensure this placement.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92992
Approved by: https://github.com/fduwjj
2023-01-27 06:50:36 +00:00