Commit Graph

19 Commits

Author SHA1 Message Date
Rohan Varma
c43e88665a [Resubmit] helpers to torch.dist.utils (#95025)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95025
Approved by: https://github.com/fegin
2023-02-17 18:24:20 +00:00
Kurt Mohler
ee28b865ee Deprecate TypedStorage, its derived classes, and all of their public methods (#85303)
Part of #85302

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85303
Approved by: https://github.com/ezyang
2022-11-08 18:11:01 +00:00
Chien-Chin Huang
4de50b2521 [FSDP] Allow to use TorchDispatch with FSDP (#88014)
Add `_no_dispatch_record_stream` to disable TorchDispatch before calling `record_stream()`.
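A minimal sketch of the idea, assuming a `no_dispatch()` context manager such as the one in `torch.utils._mode_utils` (the helper body here is illustrative, not the exact implementation):
```
import torch
from torch.utils._mode_utils import no_dispatch

def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.cuda.Stream) -> None:
    # Temporarily disable TorchDispatch so that record_stream() operates on the
    # plain tensor instead of being intercepted by an active dispatch mode.
    with no_dispatch():
        tensor.record_stream(stream)
```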
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88014
Approved by: https://github.com/awgu
2022-11-03 23:15:56 +00:00
Andrew Gu
d6b58d6924 [FSDP()][23/N] Refactor handle attr initialization (#87938)
**`_init_param_attributes()` -> `init_flat_param_attributes()`**
We move `_init_param_attributes()` to `FlatParamHandle.init_flat_param_attributes()` (as marked as a to-do during previous refactoring).

**`_reset_lazy_init()`**
We no longer delete `_local_shard` from each `FlatParameter` in `_reset_lazy_init()`.

**Analysis**
Thus, the two semantic differences are that we remove the initial `if hasattr(p, "_local_shard")` early return in `_init_param_attributes()` and the `delattr(p, "_local_shard")` in `_reset_lazy_init()`.

This is safe because
- If we never call `_reset_lazy_init()`, then `init_flat_param_attributes()` is only called once. There is no opportunity for an early return.
- If we call `_reset_lazy_init()`, then `init_flat_param_attributes()` will be called again in the next `_lazy_init()`. However, since we removed the early return, all of the attributes initialized in `init_flat_param_attributes()` are simply re-initialized, overwriting any existing values.
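A rough before/after sketch of the pattern described above (simplified, with `...` standing in for the actual attribute initialization):
```
# Before: guarded initialization plus explicit cleanup in _reset_lazy_init().
def _init_param_attributes(self, p):
    if hasattr(p, "_local_shard"):   # early return removed by this refactor
        return
    p._local_shard = ...             # plus the other per-parameter attributes

def _reset_lazy_init(self):
    for p in self.params:
        if hasattr(p, "_local_shard"):
            delattr(p, "_local_shard")   # deletion removed by this refactor

# After: init_flat_param_attributes() unconditionally (re)initializes the
# attributes, overwriting any existing values, so neither the guard nor the
# delattr() is needed.
def init_flat_param_attributes(self):
    self.flat_param._local_shard = ...
```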
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87938
Approved by: https://github.com/mrshenli
2022-11-02 11:32:56 +00:00
Andrew Gu
b1750d0440 [FSDP()][13/N] Refactor unshard/reshard/grads (#87926)
This PR moves the unshard/reshard/gradient logic out to `_runtime_utils.py` and makes those functions take `state: _State` instead of `self`.
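A minimal sketch of the method-to-free-function pattern described above (the `_State` usage shown is illustrative; the real signatures live in `_runtime_utils.py`):
```
from typing import Any, List

# Before (method on the FSDP module):
#     def _unshard(self, handles):
#         ...  # reads streams, process group, config through `self`

# After (free function in _runtime_utils.py):
def _unshard(state: Any, handles: List[Any]) -> None:
    # `state` stands in for the `_State` object; everything previously reached
    # through `self` is now read off the explicitly passed-in state.
    for handle in handles:
        ...
```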
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87926
Approved by: https://github.com/mrshenli
2022-11-01 13:37:31 +00:00
Andrew Gu
d89cf2fdc9 [FSDP()][7/N] Refactor most of ctor (#87921)
The goal of this PR is to make one pass over the FSDP constructor and refactor each helper method call so that it is no longer a `self.<...>` call. Subsequent PRs will make further passes over the constructor.

This PR touches many lines, but it is only reorganization. Methods are moved to `_init_utils.py` and `_common_utils.py`. This also marks the beginning of moving methods from `_utils.py` to `_common_utils.py` -- the two will be coalesced eventually. For now, `_common_utils.py` serves only as a staging ground for the methods affected by the refactoring.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87921
Approved by: https://github.com/mrshenli
2022-10-31 16:45:24 +00:00
Andrew Gu
e3cf81e0a7 [FSDP] ufmt /fsdp (#87811)
This applies `ufmt` to all of the FSDP files in the `torch/distributed/fsdp/` directory.

**Test Plan**
CI

**Notes**
For VSCode users,
- Install `ufmt`: https://pypi.org/project/ufmt/
- Install VSCode `ufmt` extension: https://marketplace.visualstudio.com/items?itemName=omnilib.ufmt
- Include in `settings.json`:
```
{
    "[python]": {
        "editor.defaultFormatter": "omnilib.ufmt",
        "editor.formatOnSave": true,
    },
}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87811
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2022-10-27 04:25:55 +00:00
Andrew Gu
be682befbc [FSDP] Add use_orig_params (#84911)
**Overview**
This PR adds the option to use the original parameters via `use_orig_params=True` in the FSDP constructor.
- This exposes the original parameters rather than the `FlatParameter`s from `named_parameters()`, which means that the optimizer runs on the original parameters. Hence, users may assign original parameters from the same `FlatParameter` to different parameter groups.
- This enables decoupling the original parameter variables from their storage without changing the variables themselves, which is critical for our upcoming execution-order-based non-recursive wrapping policy.
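A minimal usage sketch of what the bullets above enable (the model, parameter grouping, and hyperparameters are illustrative, and a process group is assumed to be initialized already):
```
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes torch.distributed.init_process_group(...) has already been called.
model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16), nn.Linear(16, 4))
fsdp_model = FSDP(model, use_orig_params=True)

# named_parameters() now yields the original parameters, so parameters that
# share a FlatParameter can still be placed in different optimizer groups.
decay, no_decay = [], []
for name, param in fsdp_model.named_parameters():
    (no_decay if name.endswith(".bias") else decay).append(param)

optim = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.01},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=1e-3,
)
```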

For more detailed design explanation, refer to the Quip shared internally.

**Follow-Ups**
See issue 85831 (link omitted to avoid spamming that issue whenever this PR is updated).

`test_fsdp_use_orig_params.py` adds ~4 min 46 seconds to the TTS on the AWS cluster.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84911
Approved by: https://github.com/rohan-varma
2022-10-07 18:07:17 +00:00
Andrew Gu
5652ab22f6 [FSDP] Add _set_flattened(); _is_flattened() (#85038)
For both exposing the original parameters and TP integration, we cannot rely only on `isinstance(param, FlatParameter)` to ignore already-flattened parameters in `.named_parameters()`. As a simple workaround, we can mark original parameters or `ShardedTensor`s with an attribute `_fsdp_flattened` (whose name is stored in the string variable `FSDP_FLATTENED`) to indicate that the parameter/tensor has already been flattened. This issue only arises for recursive/nested FSDP wrapping.

This PR also changes `isinstance(param, FlatParameter)` checks to `type(param) is FlatParameter` because all tensor subclasses that have `_is_param == True` will return `True` for `isinstance(param, <any subclass with _is_param == True>)`. This means that a `ShardedTensor` parameter will return `True` for `isinstance(st, FlatParameter)`, which is not what we want.
5271494ef2/torch/nn/parameter.py (L8-L10)
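A small self-contained illustration of the pitfall, using stand-in classes rather than the real `FlatParameter` / `ShardedTensor` types:
```
class _ParamMeta(type):
    # Mimics the behavior described above: any object flagged with
    # _is_param == True passes isinstance() against every such class.
    def __instancecheck__(cls, obj):
        return getattr(obj, "_is_param", False)

class FlatParamLike(metaclass=_ParamMeta):       # stand-in for FlatParameter
    _is_param = True

class ShardedTensorLike(metaclass=_ParamMeta):   # stand-in for a ShardedTensor parameter
    _is_param = True

st = ShardedTensorLike()

print(isinstance(st, FlatParamLike))  # True  -- too permissive for the flattening checks
print(type(st) is FlatParamLike)      # False -- the exact type check avoids the false positive
```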
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85038
Approved by: https://github.com/rohan-varma
2022-09-16 03:45:29 +00:00
Andrew Gu
88802719b6 [FSDP][Easy] Move utils to _utils.py (#84212)
I pulled this out into a separate PR. This just moves some utility functions to `_utils.py`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84212
Approved by: https://github.com/rohan-varma
2022-09-01 19:27:51 +00:00
Rohan Varma
9690fbf9a8 FSDP namedtuple support (#83055)
- Missing NamedTuple support is blocking MultiModal adoption. TODO: add a test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83055
Approved by: https://github.com/awgu
2022-08-10 16:44:37 +00:00
edward-io
e7ff9d44ad [fsdp] add ability to iterate through dataclasses in fsdp.utils (#82638)
### Description

Previously, FSDP was failing on a torchmultimodal model because `_apply_to_tensors` couldn't iterate over dataclasses.
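A rough sketch of how a recursive `_apply_to_tensors`-style helper might handle dataclasses (a simplified stand-in, not the actual FSDP implementation):
```
import dataclasses
from typing import Any, Callable

import torch

def apply_to_tensors(fn: Callable[[torch.Tensor], torch.Tensor], obj: Any) -> Any:
    # Simplified recursive walk: only tensors, dataclasses, dicts, and
    # plain lists/tuples are handled here.
    if isinstance(obj, torch.Tensor):
        return fn(obj)
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        # Rebuild the dataclass instance with every field mapped recursively.
        updates = {
            f.name: apply_to_tensors(fn, getattr(obj, f.name))
            for f in dataclasses.fields(obj)
        }
        return dataclasses.replace(obj, **updates)
    if isinstance(obj, dict):
        return {k: apply_to_tensors(fn, v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(apply_to_tensors(fn, v) for v in obj)
    return obj
```
For example, `apply_to_tensors(lambda t: t.to(torch.float16), batch)` would cast every tensor held inside a dataclass-based batch object.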

### Issue

None

### Testing

unit test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82638
Approved by: https://github.com/rohan-varma
2022-08-05 18:34:31 +00:00
Rohan Varma
f9f8127414 CheckpointWrapper state_dict fix (#77224)
- Uses `state_dict` / `load_state_dict` hooks to ensure that the state of modules wrapped with `CheckpointWrapper` can be loaded into a module that is not checkpoint-wrapped.

This is needed because a training run may use activation checkpointing and save a `state_dict`, while a future run may not wrap modules with activation checkpointing at all, or may change the checkpoint wrapping structure. To support this, we add hooks that remove or add the relevant prefix as needed.
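A rough sketch of the prefix-stripping idea (the prefix constant and hook signatures are illustrative, not the exact ones used by `CheckpointWrapper`):
```
from typing import Any, Dict

# Hypothetical prefix inserted by the checkpoint wrapper around the inner module.
_WRAPPED_PREFIX = "_checkpoint_wrapped_module."

def _strip_prefix_state_dict_hook(module, state_dict: Dict[str, Any], prefix: str, *args) -> None:
    # Post-hook for state_dict(): drop the wrapper prefix so the saved keys
    # look like those of a module without activation checkpointing.
    for key in list(state_dict.keys()):
        if _WRAPPED_PREFIX in key:
            state_dict[key.replace(_WRAPPED_PREFIX, "")] = state_dict.pop(key)

def _add_prefix_load_state_dict_pre_hook(module, state_dict: Dict[str, Any], prefix: str, *args) -> None:
    # Pre-hook for load_state_dict(): re-insert the wrapper prefix so keys saved
    # from a non-checkpointed run match the wrapped module's attribute names.
    for key in list(state_dict.keys()):
        if key.startswith(prefix):
            new_key = prefix + _WRAPPED_PREFIX + key[len(prefix):]
            state_dict[new_key] = state_dict.pop(key)
```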

Tests are added to ensure that a state dict saved from a `CheckpointWrapper`-wrapped module can be loaded into both a `CheckpointWrapper`-wrapped module and a local (unwrapped) module. `state_dict` with FSDP is also verified.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77224
Approved by: https://github.com/zhaojuanmao
2022-05-17 03:39:31 +00:00
Rohan Varma
9493900876 [Reland] Mixed precision batchnorm fix (#77234)
Reland of https://github.com/pytorch/pytorch/pull/77089, which was reverted due to land race.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77234
Approved by: https://github.com/zhaojuanmao
2022-05-11 15:03:34 +00:00
PyTorch MergeBot
091f8915ae Revert "Mixed Precision batchnorm fix (#77089)"
This reverts commit bf61b79503.

Reverted https://github.com/pytorch/pytorch/pull/77089 on behalf of https://github.com/suo
2022-05-11 03:00:33 +00:00
Rohan Varma
bf61b79503 Mixed Precision batchnorm fix (#77089)
Rehash of https://github.com/pytorch/pytorch/pull/76642, which could not be updated due to a GHF out-of-sync issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77089
Approved by: https://github.com/awgu
2022-05-11 02:22:01 +00:00
Andrew Gu
73b33de989 [FSDP] Include buffers in ignored_modules
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76784

Approved by: https://github.com/rohan-varma
2022-05-05 12:29:59 +00:00
lkct
9fae0762b0 fix typing in Module.state_dict and load_state_dict
Fixes #72707

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73483
Approved by: https://github.com/albanD, https://github.com/jbschlosser
2022-05-02 17:27:54 +00:00
yanlizhao
887a93e5ac support PackedSequence type for apply_for_tensors
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76265

Support the PackedSequence type in apply_for_tensors; some RNN modules' outputs are PackedSequence instances.
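A rough sketch of how such a helper might special-case `PackedSequence` (a simplified stand-in for the real utility):
```
from typing import Any, Callable

import torch
from torch.nn.utils.rnn import PackedSequence

def apply_to_tensors(fn: Callable[[torch.Tensor], torch.Tensor], obj: Any) -> Any:
    # PackedSequence holds its packed tensor in the `data` field; apply fn to it
    # and rebuild the container with the remaining bookkeeping fields unchanged.
    if isinstance(obj, PackedSequence):
        return PackedSequence(
            fn(obj.data),
            obj.batch_sizes,
            obj.sorted_indices,
            obj.unsorted_indices,
        )
    if isinstance(obj, torch.Tensor):
        return fn(obj)
    return obj
```
For example, a dtype cast could then flow through packed RNN outputs without unpacking them first.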

Differential Revision: [D35862156](https://our.internmc.facebook.com/intern/diff/D35862156/)

Approved by: https://github.com/mrshenli, https://github.com/rohan-varma
2022-04-26 22:03:25 +00:00