Commit Graph

11 Commits

Author SHA1 Message Date
Andrew Gu
5652ab22f6 [FSDP] Add _set_flattened(); _is_flattened() (#85038)
For both exposing the original parameters and TP integration, we cannot rely only on `isinstance(param, FlatParameter)` to ignore already-flattened parameters in `.named_parameters()`. As a simple workaround, we can mark original parameters or `ShardedTensor`s with an attribute `_fsdp_flattened` (saved as a string variable `FSDP_FLATTENED`) to indicate that the parameter/tensor has already been flattened. This issue arises only for recursive/nested FSDP wrapping.

This PR also changes `isinstance(param, FlatParameter)` checks to `type(param) is FlatParameter` because all tensor subclasses that have `_is_param == True` will return `True` for `isinstance(param, <any subclass with _is_param == True>)`. This means that a `ShardedTensor` parameter will return `True` for `isinstance(st, FlatParameter)`, which is not what we want.
5271494ef2/torch/nn/parameter.py (L8-L10)
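The subclass pitfall can be sketched with a simplified stand-in for the metaclass logic referenced in `torch/nn/parameter.py` (the classes below are illustrative, not the real tensor types):

```python
class _ParameterMeta(type):
    # Sketch of the __instancecheck__ behavior: any object flagged with
    # _is_param == True passes isinstance() against *every* class that
    # uses this metaclass, not just its own type.
    def __instancecheck__(cls, instance):
        return super().__instancecheck__(instance) or getattr(
            instance, "_is_param", False
        )

class Parameter(metaclass=_ParameterMeta):
    _is_param = True

class FlatParameter(Parameter):
    pass

class ShardedTensor(Parameter):
    pass

st = ShardedTensor()
print(isinstance(st, FlatParameter))  # True: the isinstance check is too loose
print(type(st) is FlatParameter)      # False: the exact type check is precise
```

This is why the PR switches the checks to `type(param) is FlatParameter`.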
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85038
Approved by: https://github.com/rohan-varma
2022-09-16 03:45:29 +00:00
Andrew Gu
88802719b6 [FSDP][Easy] Move utils to _utils.py (#84212)
I pulled this out into a separate PR. This just moves some utility functions to `_utils.py`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84212
Approved by: https://github.com/rohan-varma
2022-09-01 19:27:51 +00:00
Rohan Varma
9690fbf9a8 FSDP namedtuple support (#83055)
- NamedTuple support is blocking MultiModal adoption; this PR adds support for `NamedTuple` outputs. TODO: add test
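A minimal sketch of why `NamedTuple` needs a special case in an apply-to-tensors-style traversal (names below are illustrative, not FSDP's actual helpers): a `NamedTuple` constructor takes positional fields, so it cannot be rebuilt the same way as a plain tuple.

```python
from typing import NamedTuple

# Hypothetical model output type standing in for a real module's output.
class Output(NamedTuple):
    loss: float
    logits: list

def apply_to_values(fn, obj):
    # A NamedTuple is a tuple with a _fields attribute; rebuilding it
    # requires unpacking positional arguments, unlike a plain tuple.
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        return type(obj)(*(apply_to_values(fn, v) for v in obj))
    if isinstance(obj, (list, tuple)):
        return type(obj)(apply_to_values(fn, v) for v in obj)
    return fn(obj)

out = apply_to_values(lambda x: x * 2, Output(loss=0.5, logits=[1, 2]))
print(out)  # Output(loss=1.0, logits=[2, 4])
```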
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83055
Approved by: https://github.com/awgu
2022-08-10 16:44:37 +00:00
edward-io
e7ff9d44ad [fsdp] add ability to iterate through dataclasses in fsdp.utils (#82638)
### Description

Previously, FSDP failed on a TorchMultimodal model because `_apply_to_tensors` could not iterate over dataclasses.

### Issue

None

### Testing

unit test
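The fix can be sketched in pure Python (hypothetical helper name; in FSDP the leaf function would act only on `torch.Tensor` values): `dataclasses.fields()` exposes each field so the traversal can recurse, and `dataclasses.replace()` rebuilds the instance.

```python
import dataclasses

def apply_to_values(fn, obj):
    # Recurse into dataclass instances and rebuild them with replace().
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        changes = {
            f.name: apply_to_values(fn, getattr(obj, f.name))
            for f in dataclasses.fields(obj)
        }
        return dataclasses.replace(obj, **changes)
    if isinstance(obj, (list, tuple)):
        return type(obj)(apply_to_values(fn, v) for v in obj)
    if isinstance(obj, dict):
        return {k: apply_to_values(fn, v) for k, v in obj.items()}
    return fn(obj)

@dataclasses.dataclass
class ModelOutput:  # hypothetical multimodal model output
    logits: float
    hidden: list

out = apply_to_values(lambda x: x * 2, ModelOutput(logits=1.0, hidden=[2, 3]))
print(out)  # ModelOutput(logits=2.0, hidden=[4, 6])
```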

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82638
Approved by: https://github.com/rohan-varma
2022-08-05 18:34:31 +00:00
Rohan Varma
f9f8127414 CheckpointWrapper state_dict fix (#77224)
- Uses `state_dict`/`load_state_dict` hooks to ensure that modules wrapped with `CheckpointWrapper` can be loaded into a non-checkpoint-wrapped module.

A training run can use activation checkpointing and save a `state_dict`, while a future run may not want to wrap modules with activation checkpointing, or may change the checkpoint wrapping structure. To support this, we add hooks that remove or add the relevant prefix as needed.

Tests are added to ensure that a `CheckpointWrapper`-wrapped module's state dict can be loaded into both another `CheckpointWrapper` module and a local (unwrapped) module. `state_dict` with FSDP is also verified.
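The key fix-up the hooks perform can be sketched with plain dictionaries. The wrapper nests the original module under an extra attribute, so its state-dict keys carry a prefix; `"_checkpoint_wrapped_module."` is an assumption about the internal name, and the hook functions below are illustrative, not the actual PyTorch API:

```python
PREFIX = "_checkpoint_wrapped_module."  # assumed internal prefix

def strip_prefix_hook(state_dict):
    # post-hook for state_dict(): present keys as a plain module would
    return {k[len(PREFIX):] if k.startswith(PREFIX) else k: v
            for k, v in state_dict.items()}

def add_prefix_hook(state_dict):
    # pre-hook for load_state_dict(): route plain keys into the wrapper
    return {PREFIX + k: v for k, v in state_dict.items()}

wrapped = {PREFIX + "linear.weight": [1.0], PREFIX + "linear.bias": [0.0]}
plain = strip_prefix_hook(wrapped)
print(sorted(plain))                      # ['linear.bias', 'linear.weight']
print(add_prefix_hook(plain) == wrapped)  # True: round-trips cleanly
```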
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77224
Approved by: https://github.com/zhaojuanmao
2022-05-17 03:39:31 +00:00
Rohan Varma
9493900876 [Reland] Mixed precision batchnorm fix (#77234)
Reland of https://github.com/pytorch/pytorch/pull/77089, which was reverted due to land race.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77234
Approved by: https://github.com/zhaojuanmao
2022-05-11 15:03:34 +00:00
PyTorch MergeBot
091f8915ae Revert "Mixed Precision batchnorm fix (#77089)"
This reverts commit bf61b79503.

Reverted https://github.com/pytorch/pytorch/pull/77089 on behalf of https://github.com/suo
2022-05-11 03:00:33 +00:00
Rohan Varma
bf61b79503 Mixed Precision batchnorm fix (#77089)
Rehash of https://github.com/pytorch/pytorch/pull/76642, which could not be updated due to a GHF out-of-sync issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77089
Approved by: https://github.com/awgu
2022-05-11 02:22:01 +00:00
Andrew Gu
73b33de989 [FSDP] Include buffers in ignored_modules
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76784

Approved by: https://github.com/rohan-varma
2022-05-05 12:29:59 +00:00
lkct
9fae0762b0 fix typing in Module.state_dict and load_state_dict
Fixes #72707

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73483
Approved by: https://github.com/albanD, https://github.com/jbschlosser
2022-05-02 17:27:54 +00:00
yanlizhao
887a93e5ac support PackedSequence type for apply_for_tensors
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76265

Support the `PackedSequence` type in `apply_for_tensors`; some RNN modules' outputs are `PackedSequence` instances.
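A minimal sketch of the special case, using a plain `namedtuple` as a stand-in for `torch.nn.utils.rnn.PackedSequence` (the field names match the real class; the types and helper name here are illustrative):

```python
from collections import namedtuple

# Stand-in: the real PackedSequence is a NamedTuple whose "data" field
# holds the payload tensor.
PackedSequence = namedtuple(
    "PackedSequence",
    ["data", "batch_sizes", "sorted_indices", "unsorted_indices"],
)

def apply_to_tensors(fn, obj):
    if isinstance(obj, PackedSequence):
        # Transform only the payload; the bookkeeping fields
        # (batch_sizes, indices) pass through unchanged.
        return obj._replace(data=fn(obj.data))
    return fn(obj)

ps = PackedSequence(data=[1, 2, 3], batch_sizes=[2, 1],
                    sorted_indices=None, unsorted_indices=None)
out = apply_to_tensors(lambda t: [x * 2 for x in t], ps)
print(out.data, out.batch_sizes)  # [2, 4, 6] [2, 1]
```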

Differential Revision: [D35862156](https://our.internmc.facebook.com/intern/diff/D35862156/)

Approved by: https://github.com/mrshenli, https://github.com/rohan-varma
2022-04-26 22:03:25 +00:00