Fixes#111499
This PR slightly alters the new fused `all_gather` `optim_state_dict` implementation to support optimizers without tensor state (e.g. SGD) in a `use_orig_params=True` context.
The principle change is to short-circuit `_allgather_orig_param_states` if an empty `state_buffers` dict is returned after completing `_convert_all_state_info` here:
93e5065ba0/torch/distributed/fsdp/_optim_utils.py (L1481-L1484)
To allow `_convert_all_state_info` to accommodate optimizers with no tensor state, I also change the scope of `dtype` and make the return type `Optional`.
As discussed in the issue this PR fixes, I'm [extending](93e5065ba0/test/distributed/fsdp/test_fsdp_optim_state.py (L1915I)) `test_state_dict_with_none_tensor_state` to test with both Adam and SGD optimizers to validate scalar and non-tensor states continue to be restored for both optimizer types.
Thanks to the distributed team as always for their adroit design and exceptionally valuable contributions to the open source ML community. Hope you all feel appreciated commensurate with the compounding progress your work enables.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111501
Approved by: https://github.com/fegin
We had the option but never used cpu_offload as optimizer state_dict offloads the tensors to CPU by default. And this is usually most users want as the tensors are required to be moved to CPU eventually. However, we may want to disable offloading to CPU in some cases, epsecially for the debugging purpose. This PR lets optimizer state_dict read the flag.
Differential Revision: [D48913340](https://our.internmc.facebook.com/intern/diff/D48913340/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108434
Approved by: https://github.com/wz337
The original implementation of `_gather_orig_param_state` is naive. It performs one allgather_object and two allgather (if the optimizer is Adam) per FQN. This can be slow and make `_optim_state_dict` become bottleneck.
This PR rewrite the implementation and fuse all the `allgather_object`s into one. As for `allgather`, it is fused based on the information of FlatParameters. So there will be 2N `allgather` where N is the number of FlatParameter and 2 is due to Adam having 2 states per FQN.
One experiment on 8GPU A100 shows that the execution of the gathering is improved to 0.3 seconds from 3 seconds.
Differential Revision: [D48835138](https://our.internmc.facebook.com/intern/diff/D48835138/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108298
Approved by: https://github.com/awgu
This PR:
1) Add device_mesh kwarg to FSDP. Remove init_device_mesh() from _runtime_utils.py, as device_mesh would be passed in by user as an kwarg.
2) change use_dtensor flag for state_dict_config and optim_state_dict_config to be private. If device_mesh is used with sharded model/optim state dict, _use_dtensor flag would be set to True and model/optim state dict would return dtensor state_dict. Otherwise, _use_dtensor flag would be set to False and model/optim state dict would return sharded_tensor state_dict.
3) Update _optim_utils.py, _shard_utils.py, and _state_dict_utils.py to add support for HSDP to return 2D DTensor state_dict.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107533
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wanchaol
This PR adds SimpleProfiler for FSDP state_dict/load_state_dict logging purpose. SimpleProfiler use class variables to record profiling results and it does everything in the Python which can be slow. So it is only suitable for logging slow actions such as initialization and state_dict/load_state_dict.
This PR uses SimpleProfiler to log some critical/slow paths of the model and optimizer state_dict/load_state_dict.
Differential Revision: [D48774406](https://our.internmc.facebook.com/intern/diff/D48774406/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108290
Approved by: https://github.com/wz337
This PR:
1) Add device_mesh kwarg to FSDP. Remove init_device_mesh() from _runtime_utils.py, as device_mesh would be passed in by user as an kwarg.
2) change use_dtensor flag for state_dict_config and optim_state_dict_config to be private. If device_mesh is used with sharded model/optim state dict, _use_dtensor flag would be set to True and model/optim state dict would return dtensor state_dict. Otherwise, _use_dtensor flag would be set to False and model/optim state dict would return sharded_tensor state_dict.
3) Update _optim_utils.py, _shard_utils.py, and _state_dict_utils.py to add support for HSDP to return 2D DTensor state_dict.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107533
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wanchaol
This PR should not make any functional difference. It:
- adds clearer documentation
- clarifies a type
- revises minor typos
- swaps a .keys for a .items call on a dictionary
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106069
Approved by: https://github.com/awgu
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)
That were reverted due to the conflict with internal source repo.
Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
- Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
- Add missing return statement to `torch._export. deserialize_graph`
- Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
- Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
- Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add hack to squash older libstdc++ from conda environment in favor one from OS to `.ci/docker/install_conda.sh`
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds loose the ability to catch exceptions (probably because they link with cupti statically, but I could not found where it is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)
That were reverted due to the conflict with internal source repo.
Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
- Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
- Add missing return statement to `torch._export. deserialize_graph`
- Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
- Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
- Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
When using KeyedOptimizer.init_state(), some optimizers initializes the states even if the param is empty (size() == 0) while some optimizer avoid initializing the states. There is no way FSDP can tell. Instead, FSDP should look up `optim.state`. Fortunatelly, `optim.state` does not rely on FQNs which some internal users change the FQNs.
Differential Revision: [D47285562](https://our.internmc.facebook.com/intern/diff/D47285562/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104765
Approved by: https://github.com/fduwjj
This fixes https://github.com/pytorch/pytorch/issues/104148 (unfreezing parameters after `n` steps).
- This fixes a bug where we did not delete the post-backward hook state properly for the `requires_grad=False` case.
- This makes the `already_resharded` correct for `SHARD_GRAD_OP`.
- This generalizes `_clear_grads_if_needed()` to `_reset_flat_param_grad_info_if_needed()` to additionally include propagating the original parameters' `requires_grad` to the flat parameter.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104186
Approved by: https://github.com/rohan-varma, https://github.com/fegin
There are many communication operations for shardedTensor in the state dict of fsdp. They use the external passed-in pg (or the default pg), which currently supports cuda devices. Before communication, the memory will be moved to cuda, which is implicit (because it is essentially moving data to the memory type required by pg, not the computing device type). Similarly, when users use fsdp on a custom backend, they will pass in a custom pg (which does not support cuda devices), which may cause fsdp to not work properly in some cases. This PR obtains the memory type supported by the pg through _get_pg_default_device during communication, and moves the data to it when needed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101533
Approved by: https://github.com/awgu
When use_orig_param is True and sharding is NO_SHARD, parameters and states are not flattened, so optimizer states should not be flattened as well. The unit test will fail without the fix.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100189
Approved by: https://github.com/awgu
Custom backend implementation based on privateuse1 with semantics identical to CUDA (CUDA is so popular), named for example 'my_device', and registered as the same module name torch.my_device.
This PR aims to satisfy the constraints of such a backend, which can be directly integrated into the current FSDP implementation.
The main issues addressed are:
#### 1. Device decision for FSDP wrapping of Modules without Parameters
Users typically organize FSDP code as follows:
```python
m = Module().to('my_device:0')
fsdp_m = FSDP(m)
```
or like this:
```python
m = Module()
fsdp_m = FSDP(m, device_id=torch.device('my_device', 0))
```
If the model has Parameters, everything works fine because FSDP will prioritize the device where the Parameters are located. However, for Modules without Parameters, the to() call has no side effects, and FSDP will assume the current CUDA device, which prevents the use of devices other than the current CUDA device for Modules without Parameters. Therefore, when FSDP is called with a device_id argument, this configuration takes top priority.
#### 2. Abstraction of a cuda-like device
Now, in addition to compute_device, _FSDPState includes a device_handler member. In fact, this device_handler is now just a reference to either torch.cuda or torch.my_device. From now on, code that works based on _FSDPState should use state.device_handler to operate streams create, wait or sync, just like using torch.cuda previously.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99024
Approved by: https://github.com/awgu
The default option of `named_parameters` and `named_modules` is to remove the duplicated parameters and modules. However, in FSDP, we need to know what parameters are shared. As a result, setting `remove_duplicate` to False is required in FSDP. Without setting `remove_duplicate` to False, FSDP won't be able to discover shared weights in some cases (e.g., the shared weights are in the same module or there are shared modules).
The previous PR is reverted due to some modules overwriting the signature of `named_parameters()`. This new PR adds a workaround for the case.
Differential Revision: [D45065973](https://our.internmc.facebook.com/intern/diff/D45065973/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99448
Approved by: https://github.com/zhaojuanmao
This PR makes `use_orig_params=True` case support rank0_only loading for optim state_dict. The implementation is different from `use_orig_params=False`. The `use_orig_params=False` implementation first flatten the parameters on rank0 and then broadcast the states while this implementation broadcast the state when doing the flattening. The implementation is slower as it broadcast the original parameters instead of the flattened ones. However, the implementation introduced by this PR is simpler. As loading is usually happen once per training life, the performance difference can be ignored. In next PR, we will consolidate the implementations in favor of the simpleness.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99624
Approved by: https://github.com/wz337
Summary:
This diff is reverting D44897935
D44897935: [FSDP] Include duplicate parameters and modules when calling named_parameters and named_modules (#98912) by fegin has been identified to be causing the following test or build failures:
Tests affected:
- [caffe2/torch/fb/module_factory/sync_sgd/tests:test_pyper_data_parallel_wrapper - caffe2.torch.fb.module_factory.sync_sgd.tests.test_pyper_data_parallel_wrapper.PyPerDataParallelWrapperTest: test_fsdp_submodules_pyper](https://www.internalfb.com/intern/test/562950025957458/)
Here's the Multisect link:
https://www.internalfb.com/multisect/1893714
Here are the tasks that are relevant to this breakage:
We're generating a revert to back out the changes in this diff, please note the backout may land if someone accepts it.
If you believe this diff has been generated in error you may Commandeer and Abandon it.
Test Plan: NA
Reviewed By: fegin
Differential Revision: D45027286
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99353
Approved by: https://github.com/izaitsevfb, https://github.com/fegin
The default option of `named_parameters` and `named_modules` is to remove the duplicated parameters and modules. However, in FSDP, we need to know what parameters are shared. As a result, setting `remove_duplicate` to False is required in FSDP. Without setting `remove_duplicate` to False, FSDP won't be able to discover shared weights in some cases (e.g., the shared weights are in the same module or there are shared modules).
Differential Revision: [D44897935](https://our.internmc.facebook.com/intern/diff/D44897935/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98912
Approved by: https://github.com/awgu
This is a follow-up to the last PR to greatly simplify the approach. This should be much cleaner.
**Details**
Let `N` denote the number of original parameters flattened into a given flat parameter with `M` extra padding tensors.
- `_numels_with_padding`: length `N + M`
- `_is_padding_mask`: length `N + M`
- `_numels`, `_param_infos`, `_shapes`, `_fqns`, `_param_extensions`: length `N`
`_shard_param_indices` and `_shard_param_offsets` were used to determine (1) if a given original parameter is in the local shard and if so, then (2) what is its offset in the _sharded_ flat parameter, and (3) how many numel are in the _sharded_ flat parameter.
This PR reworks how to achieve (1), (2), and (3) to allow for simplifying the previously mentioned data structures. In particular, it saves one extra tuple `_shard_param_infos: Tuple[_ShardParamInfo, ...]` of length `N` where each `_ShardParamInfo` entry gives exactly the needed info. For example, the offset into the sharded flat parameter is now pre-computed, so we do not need to do `offset = 0; offset += numel_in_shard` over a `for` loop each time now.
For optimizer state dict, `FSDPParamInfo.param_indices` now maps to the indexes with respect to the length `N` data structures, not the length `N + M` ones. The only purpose of `param_indices` is to be able to index into `flat_param._shard_param_infos[i]` to get the contained info to flatten the unsharded original parameter optimizer state and extract the part in the local shard.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97796
Approved by: https://github.com/rohan-varma
This PR adds intra-`FlatParameter` 16-byte alignment padding to the `use_orig_params=True` code path to avoid clones in TorchInductor.
**Approach**
The `FlatParameter` maintains several data structures about its original parameters. Notably, the data structures `_param_infos`, `_shapes`, `_numels`, and `_fqns` have the same length and index in the same way.
This PR treats alignment padding _like_ an original parameter in that the padding gets flattened into the `FlatParameter`. Therefore, it must be reflected in the aforementioned data structures. However, given the way in which the data structures are used, we choose to do the following if the `i`th tensor flattened into the `FlatParameter` is padding:
- `_numels[i]` is the numel of padding
- `_param_infos[i] == _shapes[i] == _fqns[i] == None`
This choice is because (1) we must record the padding numel to account for it (e.g. for views) and (2) we prefer to preserve the invariant that the data structures index in the same way over avoiding `None` entries.
To ease the burden of other FSDP developers, we separate the parameter flattening logic:
- `_init_flat_param_and_metadata()`: This should be called only once in the `FlatParamHandle` constructor. The `FlatParameter` metadata is assumed to be static thereafter.
- `flatten_tensors()` / `flatten_tensors_into_flat_param()`: These can be used for optimizer and model state dict and can be called after construction time.
This separation allows `_init_flat_param_and_metadata()` to contain the much heavier metadata logic, while keeping the latter methods to be much lighter. The only constraint is that the alignment padding logic must be kept consistent between the two, but this should be worth the simper interface.
**Testing**
- This PR directly modifies the `use_orig_params=True` code path, so all existing tests passing gives good signal.
- Some existing unit tests had to be adjusted to account for the alignment padding.
- This PR adds two tests in `test_fsdp_flatten_params.py` to explicitly test the sharding metadata with alignment for both parameter full precision and mixed precision since the latter requires possibly more padding elements due to the decreased per-element size.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97667
Approved by: https://github.com/rohan-varma
This is an easy PR. It has some remaining local changes that I had that I felt clarified naming.
- `_param_fqns` -> `_param_name_infos` since it returns a tuple of `fqn, param_name, module_name`, not only `fqn`. (similarly for `_shared_param_fqns` -> `_shared_param_name_infos`)
- nit: `parameter_module_names` -> `param_module_names` for consistency since we almost never fully spell out `parameter`. (similarly for `shared_parameter_module_names` -> `shared_param_module_names`)
- nit: `full_fqn` -> `fqn_from_global_root`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97666
Approved by: https://github.com/rohan-varma
Summary: When parameters are flattening, multiple parameters share the same step. When unflattening the parameters, current implementation still make these parameters share the same step. When this is not wrong, some training infra get confused by sharing tensor storages. This PR fixes the issue.
Test Plan: CI
Reviewed By: awgu
Differential Revision: D43893592
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96313
Approved by: https://github.com/zhaojuanmao