Commit Graph

11 Commits

Author SHA1 Message Date
Yanli Zhao
2336571cb7 make fsdp folder to be public (#72084)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72084

Make the fsdp folder public.
ghstack-source-id: 148173447

Test Plan: unit tests

Reviewed By: mrshenli

Differential Revision: D33903417

fbshipit-source-id: 7852a2adc4af09af48a5ffa52ebf210489f834d5
(cherry picked from commit bd06513cfe)
2022-02-02 15:50:14 +00:00
Rohan Varma
d0ff1f0013 [FSDP] Backward prefetch in recursive call (#71804)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71804

Add the backward prefetch arg when using auto_wrap_policy. Unit tests are
updated accordingly.
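
For context, a minimal usage sketch is below, assuming an already initialized process group and a CUDA device. The exact keyword and policy names have shifted across PyTorch releases (e.g. fsdp_auto_wrap_policy vs. auto_wrap_policy, default_auto_wrap_policy vs. size_based_auto_wrap_policy), so treat them as assumptions rather than the API added in this PR.

```python
# Sketch only: wrap with an auto-wrap policy and request backward prefetching;
# the backward_prefetch arg is forwarded to the recursively constructed
# child FSDP instances.
import functools
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, BackwardPrefetch
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)).cuda()
fsdp_model = FSDP(
    model,
    auto_wrap_policy=functools.partial(size_based_auto_wrap_policy, min_num_params=int(1e6)),
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
)
```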
ghstack-source-id: 147753214

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D33782346

fbshipit-source-id: c0176b48db29c3756a8873e809610ed53480102b
(cherry picked from commit 764acb3f1c)
2022-01-28 00:34:08 +00:00
Yanli Zhao
b15212c62b enable backward pass computation and communication overlap by prefetching all gather (#70235)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70235

Addresses comments in https://github.com/pytorch/pytorch/pull/69282:
fixed a few corner cases for prefetching full parameters in the post-backward hook.

After benchmarking, prefetching full parameters in the pre-backward hook gave the best and most stable performance, but at the cost of increased memory. Prefetching full parameters in the post-backward hook did not deliver the expected performance and also failed in a few corner cases (now fixed), although it adds no memory overhead. The main issue is that the firing order of post-backward hooks is not consistently the reverse of the forward computation order, so an incorrectly prefetched all-gather can delay the all-gather that is actually needed in the single NCCL stream and hold up that layer's computation.

For now, both approaches are exposed as two configurable, experimental algorithms.

Prefetching full parameters in the pre-backward hook:

Past traces show that all-gather ops are not triggered until the current layer's backward computation starts. For some models, the previous layer's reduce-scatter is also scheduled before the next layer's all-gather; since all-gather and reduce-scatter share the same NCCL stream, the backward pass can end up with no overlap between communication and computation.

To explicitly schedule the next layer's all-gather while the previous layer's backward computation is running, we prefetch the next layer's all-gather of full parameters. This 1) deterministically overlaps both all-gather and reduce-scatter with computation, and 2) prefetches only one layer's full parameters at a time, to avoid increasing memory too much.

The implementation borrows the idea from facebookresearch/fairscale#865, where the forward graph order is recorded during the forward pass.

In the backward pass, this PR prefetches the all-gather of full parameters in the current layer's pre-backward hook, instead of in the current layer's post-backward hook as in facebookresearch/fairscale#865. It also makes sure the all-gather streams are synchronized properly.
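
An illustrative sketch of the idea (hypothetical helper names, not the actual FSDP internals): record the forward execution order, then in each wrapped module's pre-backward hook kick off the all-gather for the module expected to need its full parameters next.

```python
# Illustrative only; rebuild_full_params_async is a hypothetical stand-in for
# issuing a module's all-gather on the all-gather stream.
class PrefetchCoordinator:
    def __init__(self):
        self.forward_order = []  # FSDP-wrapped modules in forward execution order

    def record_forward(self, module):
        # Called at the start of each wrapped module's forward pass.
        self.forward_order.append(module)

    def prefetch_in_pre_backward(self, module):
        # Backward roughly visits modules in reverse forward order, so the
        # module recorded just before this one is the next to need its full
        # parameters. Start its all-gather now so it overlaps with this
        # module's backward computation (streams must still be synced before
        # the prefetched parameters are used).
        idx = self.forward_order.index(module)
        if idx > 0:
            self.forward_order[idx - 1].rebuild_full_params_async()
```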

Experiments showed a 10% memory increase and a 20% latency speedup for a 1GB RoBERTa model in a slow network environment.

Test Plan: unit tests

Reviewed By: rohan-varma

Differential Revision: D33252795

fbshipit-source-id: 4e2f47225ba223e7429b0dcaa89df3634bb70050
2021-12-22 23:02:46 -08:00
Rohan Varma
c4281cc92d Prototype checkpoint_wrapper (#69955)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69955

Implements a checkpoint_wrapper function, which wraps an nn.Module with activation checkpointing so users won't have to call checkpoint() every time they want to checkpoint the module.

Currently, only support for reentrant-based checkpointing is added, and it is only tested with FSDP to unblock a use case.

Future work is to add support for the new checkpointing API, add more tests, and upstream to torch.utils.checkpoint.
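
A minimal usage sketch, assuming the prototype import path under torch.distributed.algorithms (the path has moved between releases, so treat it as an assumption):

```python
# Sketch only: wrap the module once; every forward then runs under
# reentrant activation checkpointing without explicit checkpoint() calls.
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

block = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))
wrapped = checkpoint_wrapper(block)

out = wrapped(torch.randn(8, 512, requires_grad=True))
out.sum().backward()
```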
ghstack-source-id: 145811242

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D33107276

fbshipit-source-id: c4a1c68d71d65713a929994940a8750f73fbdbdb
2021-12-16 09:59:19 -08:00
Rohan Varma
7fad758e02 [FSDP] AutoWrap Main API (#68155)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68155

Per title
ghstack-source-id: 144398229

Test Plan: CI

Reviewed By: pbelevich, mrshenli

Differential Revision: D32327954

fbshipit-source-id: 36bdf06c1c50932a93acbfa97017c549fa490a6c
2021-12-01 00:16:38 -08:00
Yanli Zhao
f6696c5a85 export CPUOffload in _fsdp package (#68308)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68308

Export CPUOffload in the _fsdp package, since the cpu_offload config in the FSDP API needs to import this class.
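
For reference, a minimal sketch of how the config is consumed, assuming an initialized process group. The import path shown is the later public torch.distributed.fsdp package; at the time of this commit the package was still the private _fsdp one.

```python
# Sketch only: CPUOffload must be importable alongside FSDP so users can
# build the cpu_offload config.
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload

fsdp_model = FSDP(nn.Linear(1024, 1024), cpu_offload=CPUOffload(offload_params=True))
```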
ghstack-source-id: 143560608

Test Plan: unit tests

Reviewed By: rohan-varma

Differential Revision: D32408719

fbshipit-source-id: ee5c40ec91a423fbd58872fbdeb5f2dda8a3d89e
2021-11-16 22:56:12 -08:00
Rohan Varma
ace2183195 [FSDP] Address follow up comments for CPU offload (#67813)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67813

Address Shen's comments in
https://github.com/pytorch/pytorch/pull/67249/files
ghstack-source-id: 142379312

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D32157545

fbshipit-source-id: 3cc2df6d5fa0d3b9383ed3711e7f79729dbb1dda
2021-11-05 10:34:08 -07:00
Rohan Varma
fd77fff0b1 [FSDP] customizable backend in test (#67135)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67135

Add the ability to choose the test backend via an environment variable for
quicker testing (and gloo2 in the future).
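
A hedged sketch of the pattern (the actual environment variable name used by the tests may differ):

```python
# Sketch only: choose the process-group backend from the environment so the
# same test can be rerun against nccl or gloo without code changes.
import os
import torch.distributed as dist

def init_pg_for_test(rank, world_size):
    backend = os.environ.get("BACKEND", "nccl")  # "BACKEND" is an assumed name
    # Assumes MASTER_ADDR/MASTER_PORT are provided by the test harness.
    dist.init_process_group(backend, rank=rank, world_size=world_size)
```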
ghstack-source-id: 142274304

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D31878285

fbshipit-source-id: 80ae7107cd631a1a15ebc23262b27d8192cfe4b6
2021-11-03 15:45:52 -07:00
Rohan Varma
7f3326a6d2 [FSDP] CPU offload resubmit (#67249)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67249

Implements CPU offload for model parameters in FSDP.

- A CPU offload class with only an offload_params attribute is created.
- If this is specified in the FSDP ctor, model parameters are moved back to CPU after sharding in __init__.
- In the forward pass, during lazy init, p._local_shard gets set to p.data so it is on CPU. We pin_memory here.
- In the forward pass, in _rebuild_full_params, we move p.data back to self.compute_device if necessary. Note that we don't use the device of p._full_param_padded because we don't always have this attr, but when we do it is always the same as compute_device.
- The same logic as above applies to the beginning of the backward pass.
- At the end of forward and the end of backward, `_use_param_local_shard` ensures the parameters are offloaded to CPU again by pointing p.data to p._local_shard, which is always on CPU (see the sketch below).
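
A rough sketch of the offload pattern described above, on a plain nn.Parameter (hypothetical helper names, not the actual FSDP code):

```python
# Sketch only: keep the local shard pinned on CPU, hop to the compute device
# for computation, then point p.data back at the CPU shard afterwards.
import torch
import torch.nn as nn

def offload_after_sharding(p: nn.Parameter):
    # Lazy init: pin the CPU shard for fast, non-blocking H2D copies.
    p._local_shard = p.data.cpu().pin_memory()
    p.data = p._local_shard

def move_to_compute_device(p: nn.Parameter, compute_device: torch.device):
    # Before forward/backward compute: bring the shard onto the compute device.
    p.data = p._local_shard.to(compute_device, non_blocking=True)

def use_local_shard(p: nn.Parameter):
    # After forward/backward: p.data points at the CPU shard again.
    p.data = p._local_shard
```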

Regarding tests:
- We test 3 different types of init: 1) move the model to CUDA before wrapping with FSDP, 2) move the model to CUDA after wrapping with FSDP, 3) never move the model to CUDA.
- Case 1 is always supported. Case 2 is not supported with CPU offload and throws an error during the forward pass. Case 3 is only supported with CPU offload at the moment.
- Verifies all params are offloaded to CPU after init.
- Verifies all params are offloaded to CPU after forward and backward.
- Note that there is an issue with verifying exact parity when CPU offloading, but it appears to be related to transferring the model back and forth between CPU and CUDA. More details in https://github.com/pytorch/pytorch/pull/66961
ghstack-source-id: 141851903

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D31911085

fbshipit-source-id: 3ddf73c070b55ce383e62251868d609004fc30e7
2021-11-02 23:27:34 -07:00
Sisil Mehta
5ad169b7cc Adding in Wrap functions for FSDP from Fairscale (#67292)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67292

as title

Test Plan: buck test mode/dev-nosan //caffe2/test/distributed/fsdp:wrap --keep-going

Reviewed By: rohan-varma

Differential Revision: D31936404

fbshipit-source-id: b7ebead9a649766aec83e5630c2ce1386ad33e11
2021-11-02 13:30:41 -07:00
Yanli Zhao
df3f82a1ef Add more FSDP unit tests to cover core logic, freezing weights and flatten parameter wrapper (#66904)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66904

Add more FSDP unit tests to cover core logic, freezing weights, and the flatten parameter wrapper. These unit tests are refactored to align with PyTorch's commonly used test classes.
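
A rough sketch of the test structure, assuming the common multi-process test utilities (class and test names here are placeholders; the real tests live under test/distributed/fsdp/):

```python
# Sketch only: align FSDP tests with PyTorch's common multi-process test base.
from torch.testing._internal.common_distributed import MultiProcessTestCase, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests

class TestFSDPCoreSketch(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self):
        return 2

    @skip_if_lt_x_gpu(2)
    def test_freezing_weights(self):
        # Each spawned rank would build an FSDP-wrapped model with some
        # parameters frozen (requires_grad=False) and compare results against
        # a non-sharded reference model.
        ...

if __name__ == "__main__":
    run_tests()
```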
ghstack-source-id: 141335614

Test Plan: unit tests

Reviewed By: mrshenli

Differential Revision: D31779565

fbshipit-source-id: c727110d1d7570c0ec49e42cadfc9e9a5e440073
2021-10-22 16:50:52 -07:00