Commit Graph

33 Commits

Author SHA1 Message Date
Andrew Gu
d087b32149 [BE][FSDP] Retire _get_full_detached_param() (#80871)
The tests did not actually require that the parameters be detached, so this folds `_get_full_detached_param()` into `get_full_params()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80871
Approved by: https://github.com/rohan-varma
2022-07-08 22:28:16 +00:00
Andrew Gu
2ea215fd59 [BE][FSDP] Sort common_fsdp.py imports (#80870)
This was part of my initial attempt to make the PRs smaller, but evidently, I failed 😅.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80870
Approved by: https://github.com/rohan-varma
2022-07-08 22:27:57 +00:00
Linjian Ma
70446c25d7 [FSDP] Add forward prefetching option in FSDP API (#78841)
Fixes #78608
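
A minimal usage sketch, assuming the option is exposed as a `forward_prefetch` boolean on the FSDP constructor and that a process group is already initialized (e.g. via torchrun):
```
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes torch.distributed is already initialized (e.g. via torchrun).
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.Linear(1024, 10),
).cuda()

# forward_prefetch=True (assumed flag name) issues the next layer's all-gather
# while the current layer's forward computation is still running.
fsdp_model = FSDP(model, forward_prefetch=True)
```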

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78841
Approved by: https://github.com/zhaojuanmao
2022-06-15 20:59:08 +00:00
lezcano
ff7b6d6b5f Update linalg.*norm
This PR does a number of things:
- Moves linalg.vector_norm to structured kernels and simplifies the logic
- Fixes a number of preexisting issues with the dtype kwarg of these ops
- Heavily simplifies and corrects the logic of `linalg.matrix_norm` and `linalg.norm` to be consistent with the docs
  - Previously, the `_out` versions of these functions were incorrect
  - Their implementation is now as efficient as expected, as they avoid reimplementing these operations whenever possible.
- Deprecates `torch.frobenius_norm` and `torch.nuclear_norm`, as they were exposed in the API and they are apparently being used in mobile (??!!) even though they were not documented and their implementation was slow.
  - I'd love to get rid of these functions already, but I guess we have to go through their deprecation.
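
A small sketch of the affected ops and the `dtype` kwarg (illustrative only, not tied to this PR's internals):
```
import torch

x = torch.randn(3, 4)

# dtype controls the dtype the computation is performed in (and of the result).
torch.linalg.vector_norm(x, ord=2, dtype=torch.float64)

# matrix_norm / norm now follow the documented semantics consistently.
torch.linalg.matrix_norm(x, ord="fro")
torch.linalg.norm(x, ord="nuc")

# torch.frobenius_norm / torch.nuclear_norm are deprecated in favor of the above.
```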

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76547

Approved by: https://github.com/mruberry
2022-05-18 11:46:50 +00:00
Rohan Varma
6f954d7bbb FSDP parameter sync
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77492

Approved by: https://github.com/zhaojuanmao
2022-05-17 19:58:49 +00:00
Sisil Mehta
9d3ffed327 [FSDP] Sharded Grad Scaler (#76918)
Summary: Adding a shard-aware grad scaler for FSDP + MixedPrecision support
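
A rough usage sketch, assuming the scaler is exposed as `torch.distributed.fsdp.sharded_grad_scaler.ShardedGradScaler` and mirrors the `torch.cuda.amp.GradScaler` interface:
```
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler  # assumed import path

# Assumes torch.distributed is already initialized.
fsdp_model = FSDP(torch.nn.Linear(8, 8).cuda())
optim = torch.optim.SGD(fsdp_model.parameters(), lr=1e-3)
scaler = ShardedGradScaler()

for batch in [torch.randn(4, 8).cuda()]:
    with torch.cuda.amp.autocast():
        loss = fsdp_model(batch).sum()
    scaler.scale(loss).backward()  # scale the loss; gradients stay sharded per rank
    scaler.step(optim)             # unscale sharded grads, skip the step on inf/nan
    scaler.update()
    optim.zero_grad()
```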

Test Plan: Tests added

Differential Revision: D35988676

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76918
Approved by: https://github.com/rohan-varma
2022-05-16 15:53:21 +00:00
Rohan Varma
b30c027abf Fix FSDP CI
Sometimes we see random FSDP CI failures, such as https://github.com/pytorch/pytorch/runs/6298275361?check_suite_focus=true, that are unrelated to the diff at hand. The suspicion is that some other tests set `BACKEND`, a generic env var for distributed tests; if those tests run earlier in the same CI container, the variable does not get unset, and we end up using gloo as the FSDP backend.

But gloo is not currently supported, and the env var was mostly added for easy testing during early FSDP development, so it is removed entirely.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76878
Approved by: https://github.com/awgu
2022-05-05 13:49:24 +00:00
Andrew Gu
648823b087 [FSDP] Add ignored_modules ctor arg
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75431

Approved by: https://github.com/rohan-varma
2022-04-12 19:46:00 +00:00
Rohan Varma
143f7cca5d [FSDP] summon full params staticmethod
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75423

Makes summon_full_params a static method. We still retain the old summon_full_params as a private API, `_summon_full_params`; there are a couple of call sites to it within the FSDP file only, and we can remove these as well.
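
A usage sketch of the static-method form, assuming it is used as a context manager on a wrapped module:
```
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes torch.distributed is already initialized.
fsdp_model = FSDP(torch.nn.Linear(16, 16).cuda())

# Static method: gathers the full, unsharded parameters for the duration of the block.
with FSDP.summon_full_params(fsdp_model):
    full_shapes = [p.shape for p in fsdp_model.parameters()]
```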

Differential Revision: [D35444539](https://our.internmc.facebook.com/intern/diff/D35444539/)

Approved by: https://github.com/awgu
2022-04-12 01:25:10 +00:00
Rohan Varma
3a0b393d49 Back out "Revert D35000703: [WIP][FSDP] Mixed precision enablement" (#75024)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75024

Original commit changeset: 99295ea4ff02

Original Phabricator Diff: D35000703 (6b0b088c6c)
ghstack-source-id: 153059190

(Note: this ignores all push blocking failures!)

Test Plan: CI

Reviewed By: pbelevich

Differential Revision: D35287501

fbshipit-source-id: c6c9ada039de27cf9cc477561f92a7f888bdf5f7
(cherry picked from commit a450c7ad75507a8ac637907b51217986d0141dc0)
2022-04-05 21:15:57 +00:00
Chien-Chin Huang
fd4ad5d72c [FSDP] Register state_dict hooks for FlatParamsWrapper even if params_list is empty (#74860)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74860

These pre/post hooks must be registered even if the FlatParamsWrapper does not flatten any parameters; any submodule inside FlatParamsWrapper should be pre/post processed by the hooks.
ghstack-source-id: 152594052

Test Plan: CI

Reviewed By: rohan-varma

Differential Revision: D35194483

fbshipit-source-id: c25d7846f317c7ce78d77d335d041fed8db8f3a1
(cherry picked from commit db2cc311714e579362f5201922be715a626d48df)
2022-03-31 22:06:45 +00:00
Nikita Shulga
a98d1a5ff4 Revert D35000703: [WIP][FSDP] Mixed precision enablement
Test Plan: revert-hammer

Differential Revision:
D35000703 (6b0b088c6c)

Original commit changeset: 4bd7937ff36b

Original Phabricator Diff: D35000703 (6b0b088c6c)

fbshipit-source-id: 99295ea4ff022dea22b89d9d965ea4261cdf8826
(cherry picked from commit 05ed48197d652a911cdd040aea0fc67768ef10e5)
2022-03-31 16:28:17 +00:00
Rohan Varma
6b0b088c6c [WIP][FSDP] Mixed precision enablement (#74452)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74452

Useful clarifications while reviewing the diff:

How fairscale implements MP for buffers:
- Accepts buffer_dtype argument in ctor that is the dtype for computation for buffers. By default this is the compute_dtype.
- During _lazy_init, for root module, _cast_buffers is called which casts buffers to buffer_dtype.
- During state_dict, buffers are cast to torch.float32, then checkpoint is taken. They are restored back to buffer_dtype after that.

How PT FSDP implements MP for buffers in this diff:
- Rather than buffer_dtype in the ctor, we accept MixedPrecision.buffer_dtype, which is the compute type for buffers.
- During lazy_init, similar to fairscale, we cast the buffers to the type given by the MP config. In the case of no mixed precision, the default behavior is maintained.
- In _cast_buffers, we remember the original buffer dtype in a member variable. We then may cast the buffers to a new dtype if given by the user.
- During state_dict, we use the above remembered type (stored as self._orig_buffer_dtype) and restore this type to the buffers prior to taking the checkpoint. After state_dict, we restore the cast type, as buffers remain in this mixed precision type even after forward/backward passes (so this is done for consistency).
- The improvement here is that we remember and restore the correct buffer dtype the model originally had. However, we assume all buffers are of the same dtype, which can be relaxed depending on use cases.

Why rebuild_full_params checks for summon_full_params training state:
- summon_full_params needs to return the full module parameters in the original precision for checkpoint to work as expected (users don't want to checkpoint the fp16 params generally). Thus, _rebuild_full_params will do this check. This is exactly the same reasoning as "force_full_precision" arg in fairscale.
- Concretely, if we're in summon_full_params, we 1) don't cast shards to param_dtype, and 2) all_gather with a full precision input rather than _full_param_padded.
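
A configuration sketch of the PT FSDP side described above, assuming the eventual public API exposes a `MixedPrecision` config with `param_dtype` / `reduce_dtype` / `buffer_dtype` fields:
```
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision  # assumed public location

mp_config = MixedPrecision(
    param_dtype=torch.float16,   # dtype of all-gathered params used in fwd/bwd compute
    reduce_dtype=torch.float16,  # dtype used for gradient reduce-scatter
    buffer_dtype=torch.float16,  # buffers are cast to this dtype during lazy init
)

# Assumes torch.distributed is already initialized.
fsdp_model = FSDP(torch.nn.Linear(32, 32).cuda(), mixed_precision=mp_config)

# summon_full_params / state_dict restore full precision (cf. force_full_precision),
# so checkpoints are not taken in fp16.
```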

Test coverage:
[ ] Test1
ghstack-source-id: 152654758

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D35000703

fbshipit-source-id: 4bd7937ff36bdb3afd60eda981afc9d8731b823a
(cherry picked from commit 6ed6721aaf18f323656686200465fc78cef1d0dd)
2022-03-31 14:17:02 +00:00
Andrew Gu
522041a0fd [FSDP] Add full optim state dict (#74215)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74215

###  Overview of API
This PR introduces full optimizer state dict checkpointing.
- This allows users to save the optimizer state for a `torch.nn.Module` (not necessarily a `FullyShardedDataParallel` instance) that contains `FullyShardedDataParallel` instances and later load that optimizer state.
- This supports loading to a module with a different world size, but the `FSDP` wrapping scheme must be the same.

To **save** the optimizer state, run the following (on all ranks):
```
model: torch.nn.Module = ...
optim = torch.optim.Adam(model.parameters(), ...)
# Train for some steps...
full_osd = FSDP.full_optim_state_dict(model, optim)  # returns non-empty dict only on rank 0
if rank == 0:
    torch.save(full_osd, ...)
```
To **load** the optimizer state, run the following (on all ranks):
```
new_model: torch.nn.Module = ...  # may use different world size
full_osd = torch.load(...)
sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model)
optim = torch.optim.Adam(new_model.parameters(), ...)
optim.load_state_dict(sharded_osd)
```

To support **multiple parameter groups**, we require using an additional argument `optim_input`, which is the first argument that the user passes into the optimizer constructor.
```
optim_input = ...
optim = torch.optim.Adam(optim_input, ...)
FSDP.full_optim_state_dict(model, optim, optim_input)  # one more argument
...
new_optim_input = ...
new_optim = torch.optim.Adam(new_optim_input, ...)
FSDP.shard_full_optim_state_dict(full_osd, new_model, new_optim_input)  # one more argument
```
One caveat is that the user should be careful of generators, which are exhausted after their first use. The `optim_input` passed into the `FSDP` APIs should be a refreshed version of the generator if generators are used.

### Test Plan
**`full_optim_state_dict()`**
- [x] `full_optim_state_dict()` for a non-`FSDP` root model matches that of an equivalent local model, up to parameter IDs being rearranged, when optimizer input is `model.parameters()`.
- [x] `full_optim_state_dict()` for a non-`FSDP` root model matches that of an equivalent local model, up to parameter IDs being rearranged, when optimizer input is multiple parameter groups (changing parameter order).

**`shard_full_optim_state_dict()`**
- [x] `shard_full_optim_state_dict()` for a non-`FSDP` root model matches the local `optim.state_dict()` of the same model with halved world size, when optimizer input is `model.parameters()`.
- [x] `shard_full_optim_state_dict()` for a non-`FSDP` root model matches the local `optim.state_dict()` of the same model with halved world size, when optimizer input is multiple parameter groups (changing parameter order).
- [x] `shard_full_optim_state_dict()` raises a `ValueError` when changing the `FSDP` wrapping scheme.

On the AWS cluster, the TTS contribution for these tests is ~45 seconds.

###  Developer Notes
**Relaxing the Problem**
For optimizer state checkpointing, we have relaxed the problem to **not support changing the `FSDP` wrapping scheme** between save and load time. It is unclear how to solve without this relaxation. This was the least restrictive way to relax the problem since it does not affect most expected use cases. Rather, the expected change between save and load time is the **world size**, which this implementation **does support**.

Even with the relaxation, the `optim_input` argument is necessary to determine the `flat_param_id_to_param` mapping, which is important to know which parameter IDs in the flattened space correspond to `FlatParameter`s that hence need to be unflattened.

**Differences with Local Equivalent**
Suppose `full_osd = full_optim_state_dict()` and `local_osd = state_dict()` for a purely local equivalent. The difference between `full_osd` and `local_osd` is that the parameter IDs of unflattened parameters comprising a single flattened parameter are always consecutive in `full_osd`, while they may be non-consecutive in `local_osd`. Suppose in the following that each layer has 1 parameter `param`:
```
FSDP(model)
    layer1
    FSDP(layer2)
    layer3
```
`layer1.param` and `layer3.param` are flattened and attributed to `model`. `layer2.param` is flattened and attributed to itself.
- In `local_osd`, the parameter IDs would be `0: layer1.param`, `1: layer2.param`, and `2: layer3.param`.
- In `full_osd`, the parameter IDs would be `0: layer1.param`, `1: layer3.param`, and `2: layer2.param`. (Parameter IDs of unflattened parameters sharing a flattened parameter are consecutive.)

The idea is that as long as `full_optim_state_dict()` and `shard_full_optim_state_dict()` are internally consistent, then there is no need to match the local equivalent (assuming no change in `FSDP` wrapping).

### Follow-Ups
**API**
- If needed, we can follow up this PR by adding an argument `key_by_name: bool = False` to both methods that may be set to `True` to key parameters by `str` names instead of `int` parameter IDs. We still need to investigate whether keying by name enables changing the `FSDP` wrapping scheme.

**Refactoring**
- In this optimizer state checkpointing, all optimizer state is saved to CPU on rank 0 (set as `OPTIM_TARGET_RANK`). We should unify and refactor these assumptions with model state checkpointing.

**Testing**
- The code path for unused parameters is not tested. The testing and any needed implementation fixes can be done in a follow-up.
- The code path for non-tensor states (e.g. `Adam` `"step"` as `float` instead of as zero-dimension `FloatTensor`) is not tested. However, it is identical to that of zero-dimension tensor states, so I have some confidence. If needed, I can add tests for it in a follow-up.
    - Would I have to write my own optimizer? I do not want to introduce dependencies on third party libraries like Nvidia `apex`.
- We may want to add end-to-end checkpointing tests that include both model state dict and optimizer state dict.

Test Plan: Imported from OSS

Reviewed By: zhaojuanmao

Differential Revision: D35045121

Pulled By: awgu

fbshipit-source-id: 33c650dc960acbd7613d4f444a852b9f76ca4a9b
(cherry picked from commit 2bbc2e344296dc455cf686f3a9b097989504be81)
2022-03-30 14:15:23 +00:00
Yanli Zhao
5e39d94908 make sharding strategy configurable and support zero2 algorithm (#73819)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73819

Adds a new sharding_strategy config to the FSDP API to support different data-parallel algorithms. Also adds support for the zero2 algorithm, which shards only optimizer states and grads.
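
A configuration sketch, assuming the config is exposed as a `ShardingStrategy` enum where the zero2-style option shards only gradients and optimizer states:
```
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy  # assumed public location

# Assumes torch.distributed is already initialized.
# FULL_SHARD (default): shard params, grads, and optimizer states (zero3-style).
# SHARD_GRAD_OP: shard only grads and optimizer states (zero2-style).
fsdp_model = FSDP(
    torch.nn.Linear(64, 64).cuda(),
    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
)
```
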
ghstack-source-id: 151454460

Test Plan: unit tests

Reviewed By: rohan-varma

Differential Revision: D34662583

fbshipit-source-id: 14c6e0c0054692ecd76512c025d60deb4964ec5f
(cherry picked from commit 51382e882447b4756c4ee6d94ce0939a25955b00)
2022-03-16 17:21:41 +00:00
Junjie Wang (PyTorch)
616b36e437 [PT-D][FSDP] Implement _clip_grad_norm_ for FSDP (#73405)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73405

Implement the `_clip_grad_norm_` for FSDP, issue: https://github.com/pytorch/pytorch/issues/72548
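
A hypothetical usage sketch; the commit adds a private `_clip_grad_norm_`, so the public `clip_grad_norm_` name below is an assumption:
```
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes torch.distributed is already initialized.
fsdp_model = FSDP(torch.nn.Linear(8, 8).cuda())
optim = torch.optim.SGD(fsdp_model.parameters(), lr=1e-2)

fsdp_model(torch.randn(4, 8).cuda()).sum().backward()

# Clips the *total* gradient norm computed across all ranks' shards.
fsdp_model.clip_grad_norm_(max_norm=1.0)
optim.step()
```
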
ghstack-source-id: 151059433

Test Plan: CI

Reviewed By: rohan-varma

Differential Revision: D34230605

fbshipit-source-id: bbac7a6e49276e0f0502e2f4466c984aee2629fa
(cherry picked from commit f10d090cd11489608ab3f67f52e3e950cd9f7dea)
2022-03-11 00:41:07 +00:00
Andrew Gu
4a06b8d36c [FSDP] Add grad accumulation without no_sync() (#73535)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73535

**Overview**
- This adds FSDP gradient accumulation without `no_sync()`, which comparatively has more network bandwidth demand but less GPU memory requirement per worker.
- This fixes a bug in the `no_sync()` testing, where the CPU offloading and backward prefetch arguments were not propagating to the `FullyShardedDataParallel` constructor.
- This adds `p_assert()` (taken from Fairscale), which prints the assert error message before raising the `AssertionError`. It is meant to be used when running in the autograd backward context, since otherwise the error message is swallowed, giving an unhelpful error like:
```
<built-in method run_backward of torch._C._EngineBase object at 0x7f1fd518dc80> returned NULL without setting an error
```

NOTE: Gradient accumulation without `no_sync()` is not currently compatible with CPU offloading.
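
A sketch contrasting the two accumulation modes (illustrative; assumes torch.distributed is already initialized):
```
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

fsdp_model = FSDP(torch.nn.Linear(8, 8).cuda())
optim = torch.optim.SGD(fsdp_model.parameters(), lr=1e-2)
microbatches = [torch.randn(4, 8).cuda() for _ in range(4)]

# (a) Accumulate inside no_sync(): fewer reduce-scatters, but the unsharded
#     gradients stay in GPU memory until the final (synced) microbatch.
with fsdp_model.no_sync():
    for mb in microbatches[:-1]:
        fsdp_model(mb).sum().backward()
fsdp_model(microbatches[-1]).sum().backward()
optim.step()
optim.zero_grad()

# (b) Accumulate without no_sync(): reduce-scatter on every microbatch, so more
#     communication, but gradients remain sharded (less GPU memory per worker).
for mb in microbatches:
    fsdp_model(mb).sum().backward()
optim.step()
optim.zero_grad()
```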

**Test Plan**
I augmented the tests to test gradient accumulation interleaving iterations accumulating with and without `no_sync()`.

After this diff:
- QPS (ResNet): f328439897
- QPS (RoBERTa): f328440141
- Accuracy: f328442119

Before this diff (trunk):
- QPS (ResNet): f328432756
- QPS (RoBERTa): f328436766
- Accuracy: f328437896

Test Plan: Imported from OSS

Reviewed By: zhaojuanmao

Differential Revision: D34533546

Pulled By: awgu

fbshipit-source-id: 821d762dfad5f2b1e59adcb8e5cb7c277399040c
(cherry picked from commit 746a5ea2720dcf87c376229b405a318396fe5769)
2022-03-07 20:33:22 +00:00
Chien-Chin Huang
6396547f9e [FSDP] Make summon_full_params a public method (#73116)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73116

Users may need summon_full_params() to get the original parameters.
ghstack-source-id: 150134237

Test Plan: CI

Reviewed By: rohan-varma

Differential Revision: D34353034

fbshipit-source-id: ac69cc032da177903cd9969094f3f82dc6a61636
(cherry picked from commit 55d34fdee3778110a165a13ae987d0339e8d33c7)
2022-03-01 22:29:28 +00:00
Rohan Varma
6b424de338 [FSDP] Add state_dict() save/reload in parity test (#73366)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73366

Adds a state_dict() save/reload to the parity-with-DDP test to ensure
checkpointing doesn't cause issues with accuracy/model params.
ghstack-source-id: 150114251

Test Plan: CI

Reviewed By: fegin

Differential Revision: D34434358

fbshipit-source-id: fb0787486b383cfcbec7cc1325a486c8d9b1e2ea
(cherry picked from commit e3bcc7733cb5a497a640007044b1138dfee3a532)
2022-03-01 04:35:30 +00:00
Rohan Varma
540361fa53 [FSDP] full_state_dict impl (#73324)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73324

Implements `state_dict` and `load_state_dict` APIs for FSDP, with the following limitations:

1. Does not support `state_dict_device` (i.e. specifying which device params should be on), which fairscale currently does support
2. Does not yet support offload of the state_dict onto CPU
3. Loads the state_dict on all ranks currently. In the future, we could add support for loading it on only rank 0 to avoid redundancy across ranks, as usually only one rank is responsible for saving/loading the model. Along with (2), this would enable state_dict to be called on larger models.

As discussed in the FSDP checkpoint API proposal, `state_dict` will basically be a `full_state_dict` where full parameters are returned on all ranks. This implies that the model must actually be able to fit on a single GPU.
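
A save/load sketch consistent with the semantics described above (full parameters gathered on all ranks; illustrative only):
```
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes torch.distributed is already initialized.
fsdp_model = FSDP(torch.nn.Linear(16, 16).cuda())

# state_dict() gathers the full (unsharded) parameters on every rank,
# so the model must fit on a single GPU.
full_sd = fsdp_model.state_dict()
if dist.get_rank() == 0:
    torch.save(full_sd, "model.pt")

# load_state_dict() currently runs on all ranks.
fsdp_model.load_state_dict(full_sd)
```
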
ghstack-source-id: 150012240

Test Plan: ci

Reviewed By: zhaojuanmao

Differential Revision: D34433514

fbshipit-source-id: 3eb1d679b2236264f9f423e761d1720f9aaec73a
(cherry picked from commit a451d5a08ebfa14a229a25fea35b9ca59fe91a59)
2022-02-27 19:32:22 +00:00
Rohan Varma
199d1cb9dd [FSDP][BE] remove get_full_params() from test code (#73242)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73242

Can use summon_full_params instead.
ghstack-source-id: 149800364

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D34399789

fbshipit-source-id: 8552cdf3ed003aba1316f554f4ec457fdada5dbe
(cherry picked from commit a397e2dfd3750afe1d21cdee3aa4c2d525ed837e)
2022-02-24 19:39:32 +00:00
Rohan Varma
e10cd88648 [FSDP] summon_full_params fix (#73314)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73314

Needs to synchronize the all_gather stream. The added test fails without this fix.
ghstack-source-id: 149800363

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D34430602

fbshipit-source-id: 4ce07e2d098a4f07ac640285db1d0ff64fd42232
(cherry picked from commit 24c756e7bba69017b9358bf824589b2aeb366b5e)
2022-02-24 19:39:32 +00:00
Yanli Zhao
2336571cb7 make fsdp folder to be public (#72084)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72084

Make the fsdp folder public.
ghstack-source-id: 148173447

Test Plan: unit tests

Reviewed By: mrshenli

Differential Revision: D33903417

fbshipit-source-id: 7852a2adc4af09af48a5ffa52ebf210489f834d5
(cherry picked from commit bd06513cfe)
2022-02-02 15:50:14 +00:00
Rohan Varma
d0ff1f0013 [FSDP] Backward prefetch in recursive call (#71804)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71804

Add the backward prefetch arg when using auto_wrap_policy. Unit tests are
updated accordingly.
ghstack-source-id: 147753214

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D33782346

fbshipit-source-id: c0176b48db29c3756a8873e809610ed53480102b
(cherry picked from commit 764acb3f1c)
2022-01-28 00:34:08 +00:00
Yanli Zhao
b15212c62b enable backward pass computation and communication overlap by prefetching all gather (#70235)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70235

address comments in https://github.com/pytorch/pytorch/pull/69282:
Have fixed a few corner cases for prefetching full parameters in post backward hook.

After benchmarking, prefetching full parameters in the pre-backward hook has the best and most stable performance, but at the cost of increased memory; prefetching full parameters in the post-backward hook did not see the expected performance and also failed in a few corner cases (now fixed), although there is no memory increase. The main issue is that the post-backward hook firing order is not consistent with the reverse of the forward computation order, so an incorrectly prefetched all-gather could delay the actually needed all-gather in the single NCCL stream and delay some layer's computation.

So for now these two approaches are exposed as two configurable experimental algorithms.

prefetch full parameters at pre-backward hook:

It is observed from past traces that all-gather ops are not triggered until the current layer's backward pass starts to compute; also, for some models, previous layers' reduce-scatter is scheduled before the next layer's all-gather ops. Since all-gather and reduce-scatter are in the same NCCL stream, this can result in a backward pass with no overlap between communication and computation.

To explicitly schedule the next layer's all-gather while the previous layer's backward computation is running, we can prefetch the next layer's full parameters. This helps ensure 1) both all-gather and reduce-scatter are deterministically overlapped with computation, and 2) only one layer's full parameters are prefetched at a time, to avoid increasing memory too much.

The implementation borrowed the idea from facebookresearch/fairscale#865, where forward graph order is recorded in the forward pass.

In the backward pass, this PR prefetches the all-gather of full parameters in the current layer's pre-backward hook, instead of in the current layer's post-backward hook as in facebookresearch/fairscale#865. It also makes sure the all-gather streams are synced properly.

Experiments showed a 10% memory increase and a 20% latency speedup for a 1GB RoBERTa model in a slow network environment.
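
A configuration sketch for choosing between the two experimental algorithms, assuming they are exposed via a `BackwardPrefetch` enum on the FSDP constructor:
```
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import BackwardPrefetch  # assumed public location

# Assumes torch.distributed is already initialized.
# BACKWARD_PRE:  prefetch the next all-gather in the pre-backward hook
#                (best overlap in benchmarks, at the cost of some extra memory).
# BACKWARD_POST: prefetch in the post-backward hook (no memory increase, but the
#                hook firing order can delay the all-gather that is actually needed).
fsdp_model = FSDP(
    torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.Linear(512, 512)).cuda(),
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
)
```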

Test Plan: unit tests

Reviewed By: rohan-varma

Differential Revision: D33252795

fbshipit-source-id: 4e2f47225ba223e7429b0dcaa89df3634bb70050
2021-12-22 23:02:46 -08:00
Rohan Varma
c4281cc92d Prototype checkpoint_wrapper (#69955)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69955

Implements a checkpoint_wrapper function, which wraps an nn.Module with checkpointing so users won't have to call checkpoint() every time they want to checkpoint the module.

Currently, only reentrant-based checkpointing is supported, and it is tested only with FSDP to unblock a use case.

Future work is to add support for the new checkpointing API, add more tests, and upstream to torch.utils.checkpoint.
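
A rough sketch, assuming the helper lives at `torch.distributed.algorithms._checkpoint.checkpoint_wrapper`:
```
import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (  # assumed path
    checkpoint_wrapper,
)

block = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.ReLU())

# Wrap once; every forward of `block` is now checkpointed (reentrant-based),
# so there is no need to call torch.utils.checkpoint.checkpoint() at each use.
block = checkpoint_wrapper(block)

out = block(torch.randn(2, 128, requires_grad=True))
out.sum().backward()
```
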
ghstack-source-id: 145811242

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D33107276

fbshipit-source-id: c4a1c68d71d65713a929994940a8750f73fbdbdb
2021-12-16 09:59:19 -08:00
Rohan Varma
7fad758e02 [FSDP] AutoWrap Main API (#68155)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68155

Per title
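
A usage sketch of the auto-wrap API; the policy name has varied across versions (e.g. `default_auto_wrap_policy` vs. `size_based_auto_wrap_policy`), so treat the exact names as assumptions:
```
import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy  # assumed policy name

# Assumes torch.distributed is already initialized.
model = torch.nn.Sequential(*[torch.nn.Linear(1024, 1024) for _ in range(8)]).cuda()

# Recursively wraps submodules whose parameter count exceeds the threshold,
# so users do not have to wrap each layer by hand.
policy = functools.partial(size_based_auto_wrap_policy, min_num_params=int(1e6))
fsdp_model = FSDP(model, auto_wrap_policy=policy)
```
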
ghstack-source-id: 144398229

Test Plan: CI

Reviewed By: pbelevich, mrshenli

Differential Revision: D32327954

fbshipit-source-id: 36bdf06c1c50932a93acbfa97017c549fa490a6c
2021-12-01 00:16:38 -08:00
Yanli Zhao
f6696c5a85 export CPUOffload in _fsdp package (#68308)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68308

Export CPUOffload in the _fsdp package, as the cpu_offload config in the FSDP API needs to import this class.
ghstack-source-id: 143560608

Test Plan: unit tests

Reviewed By: rohan-varma

Differential Revision: D32408719

fbshipit-source-id: ee5c40ec91a423fbd58872fbdeb5f2dda8a3d89e
2021-11-16 22:56:12 -08:00
Rohan Varma
ace2183195 [FSDP] Address follow up comments for CPU offload (#67813)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67813

Address Shen's comments in
https://github.com/pytorch/pytorch/pull/67249/files
ghstack-source-id: 142379312

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D32157545

fbshipit-source-id: 3cc2df6d5fa0d3b9383ed3711e7f79729dbb1dda
2021-11-05 10:34:08 -07:00
Rohan Varma
fd77fff0b1 [FSDP] customizable backend in test (#67135)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67135

Add ability to use env var backend for quicker testing (and gloo2 in
the future)
ghstack-source-id: 142274304

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D31878285

fbshipit-source-id: 80ae7107cd631a1a15ebc23262b27d8192cfe4b6
2021-11-03 15:45:52 -07:00
Rohan Varma
7f3326a6d2 [FSDP] CPU offload resubmit (#67249)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67249

Implements CPU offload for model parameters in FSDP.

- CPU offload class with only offload_params attribute is created
- If this is specified in FSDP ctor, model parameters are moved back to CPU after sharding in __init__
- In the forward pass, during lazy init, p._local_shard gets set to p.data so it is on CPU. We pin_memory here.
- In the forward pass, in _rebuild_full_params, we move p.data back to self.compute_device if necessary. Note that we don't use the device of p._full_param_padded because we don't always have this attr, but when we do it is always the same as compute_device.
- The same logic as above applies to the beginning of the backward pass.
- At the end of fwd and the end of bwd, `_use_param_local_shard` takes care to ensure the parameters are offloaded to CPU again, by pointing them to p._local_shard, which is always on CPU.
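
A configuration sketch of the offload path described above, assuming the config object is `CPUOffload(offload_params=True)`:
```
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import CPUOffload  # assumed public location

# Assumes torch.distributed is already initialized.
# Case 3 below: never move the model to CUDA yourself; FSDP keeps the sharded
# params on CPU (pinned) and moves them to the compute device only around
# forward/backward.
model = torch.nn.Linear(32, 32)  # stays on CPU
fsdp_model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))
```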

Regarding tests:
- We test 3 different types of init: 1) CUDA the model before wrapping with FSDP, 2) CUDA the model after wrapping with FSDP, 3) never CUDA the model.
- Case 1 is always supported. Case 2 is not supported with CPU offload and throws an error during fwd pass. Case 3 is only supported with CPU offload at the moment.
- Verifies all params are offloaded to CPU after init.
- Verifies all params are offloaded to CPU after forward and backward.
- Note that there is an issue with verifying exact parity when CPU offloading, but it appears to be related to transferring the model back and forth between CPU and CUDA. More details in https://github.com/pytorch/pytorch/pull/66961
ghstack-source-id: 141851903

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D31911085

fbshipit-source-id: 3ddf73c070b55ce383e62251868d609004fc30e7
2021-11-02 23:27:34 -07:00
Sisil Mehta
5ad169b7cc Adding in Wrap functions for FSDP from Fairscale (#67292)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67292

as title
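
A sketch of the manual wrap helpers, assuming they are exposed as `enable_wrap` / `wrap` under `torch.distributed.fsdp.wrap`:
```
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import enable_wrap, wrap  # assumed import path

# Assumes torch.distributed is already initialized.
# Inside enable_wrap(), wrap() turns a module into an FSDP instance;
# outside the context it is a no-op, so the same model code also runs unwrapped.
with enable_wrap(wrapper_cls=FSDP):
    layer = wrap(torch.nn.Linear(64, 64).cuda())
```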

Test Plan: buck test mode/dev-nosan //caffe2/test/distributed/fsdp:wrap --keep-going

Reviewed By: rohan-varma

Differential Revision: D31936404

fbshipit-source-id: b7ebead9a649766aec83e5630c2ce1386ad33e11
2021-11-02 13:30:41 -07:00
Yanli Zhao
df3f82a1ef Add more FSDP unit tests to cover core logic, freezing weights and flatten parameter wrapper (#66904)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66904

Add more FSDP unit tests to cover core logic, freezing weights, and the flatten parameter wrapper; these unit tests are refactored to align with PyTorch's commonly used test classes
ghstack-source-id: 141335614

Test Plan: unit tests

Reviewed By: mrshenli

Differential Revision: D31779565

fbshipit-source-id: c727110d1d7570c0ec49e42cadfc9e9a5e440073
2021-10-22 16:50:52 -07:00