Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72084
Make the fsdp folder public.
ghstack-source-id: 148173447
Test Plan: unit tests
Reviewed By: mrshenli
Differential Revision: D33903417
fbshipit-source-id: 7852a2adc4af09af48a5ffa52ebf210489f834d5
(cherry picked from commit bd06513cfe)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70235
Address comments in https://github.com/pytorch/pytorch/pull/69282:
Fixed a few corner cases for prefetching full parameters in the post-backward hook.
After benchmarking, prefetching full parameters in the pre-backward hook gives the best and most stable performance, but at the cost of increased memory; prefetching in the post-backward hook did not deliver the expected speedup and failed in a few corner cases (now fixed), although it does not increase memory. The main issue is that the post-backward hook firing order is not guaranteed to be the reverse of the forward computation order, so an incorrectly prefetched all-gather can delay the all-gather that is actually needed in the single NCCL stream and stall that layer's computation.
These two algorithms are therefore exposed as two configurable, experimental options for now.
Prefetch full parameters in the pre-backward hook:
Past traces show that all-gather ops are not triggered until the current layer's backward computation starts; for some models the previous layer's reduce-scatter is also scheduled before the next layer's all-gather. Since all-gather and reduce-scatter share the same NCCL stream, the backward pass can end up with no overlap between communication and computation.
To explicitly schedule the next layer's all-gather while the previous layer's backward computation is running, we prefetch the next layer's full parameters. This ensures 1) both all-gather and reduce-scatter are deterministically overlapped with computation, and 2) only one layer's full parameters are prefetched at a time, to avoid increasing memory too much.
The implementation borrows the idea from facebookresearch/fairscale#865, where the forward graph order is recorded during the forward pass.
In the backward pass, this PR prefetches the all-gather of full parameters in the current layer's pre-backward hook, instead of in the current layer's post-backward hook as in facebookresearch/fairscale#865, and makes sure the all-gather streams are synced properly.
Experiments showed a 10% memory increase and a 20% latency speedup for a 1GB RoBERTa model in a slow network environment.
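A minimal sketch of the recorded-forward-order / pre-backward prefetch idea; the class and member names below (PrefetchingFSDPSketch, _fsdp_graph_order, _rebuild_full_params) are illustrative stand-ins rather than the actual FSDP internals, and all stream handling is elided:

```python
import torch
import torch.nn as nn


class PrefetchingFSDPSketch(nn.Module):
    """Illustrative-only sketch: record forward order, prefetch in the pre-backward hook."""

    _fsdp_graph_order = []  # shared forward execution order across wrapped modules

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module
        self._my_index = None

    def forward(self, *args, **kwargs):
        # Record the order in which wrapped modules run their forward pass.
        if self not in PrefetchingFSDPSketch._fsdp_graph_order:
            PrefetchingFSDPSketch._fsdp_graph_order.append(self)
            self._my_index = len(PrefetchingFSDPSketch._fsdp_graph_order) - 1
        out = self.module(*args, **kwargs)
        # Fire the prefetch from a pre-backward hook on the output tensor.
        if isinstance(out, torch.Tensor) and out.requires_grad:
            out.register_hook(self._pre_backward_hook)
        return out

    def _pre_backward_hook(self, grad):
        # Backward runs roughly in reverse forward order, so the module that
        # needs its full parameters next is the one just before us in forward
        # order: kick off its all-gather now so communication overlaps with
        # this layer's backward computation.
        self._rebuild_full_params()
        if self._my_index:  # index 0 has no predecessor to prefetch
            prev = PrefetchingFSDPSketch._fsdp_graph_order[self._my_index - 1]
            prev._rebuild_full_params()
        return grad

    def _rebuild_full_params(self):
        # Stand-in for the sharded-parameter all-gather on the all-gather stream.
        pass
```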
Test Plan: unit tests
Reviewed By: rohan-varma
Differential Revision: D33252795
fbshipit-source-id: 4e2f47225ba223e7429b0dcaa89df3634bb70050
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69955
Implements a checkpoint_wrapper function, which wraps an nn.Module with activation checkpointing so users won't have to call checkpoint() every time they want to checkpoint the module.
Currently, only support for reentrant-based checkpointing is added, and it is only tested with FSDP to unblock a use case.
Future work is to add support for the new checkpointing API, add more tests, and upstream this to torch.utils.checkpoint.
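A hedged usage sketch; the import path and CheckpointImpl argument shown are where the wrapper later landed, and the exact location at the time of this PR may differ:

```python
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper,
)

# Wrapped layers recompute activations during backward; no manual checkpoint() calls needed.
model = nn.Sequential(
    checkpoint_wrapper(nn.Linear(1024, 1024), checkpoint_impl=CheckpointImpl.REENTRANT),
    nn.ReLU(),
    checkpoint_wrapper(nn.Linear(1024, 1024), checkpoint_impl=CheckpointImpl.REENTRANT),
)

# Reentrant checkpointing needs at least one input that requires grad.
out = model(torch.randn(8, 1024, requires_grad=True))
out.sum().backward()
```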
ghstack-source-id: 145811242
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D33107276
fbshipit-source-id: c4a1c68d71d65713a929994940a8750f73fbdbdb
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68308
Export CPUOffload from the _fsdp package, as the cpu_offload config in the FSDP API needs to import this class.
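A brief sketch of how the exported class is consumed by the FSDP constructor; the public import path shown is where the API later landed (this PR exports it from the then-private _fsdp package):

```python
import torch.nn as nn
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP

# Assumes a default process group (e.g. NCCL) has already been initialized.
model = FSDP(nn.Linear(1024, 1024), cpu_offload=CPUOffload(offload_params=True))
```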
ghstack-source-id: 143560608
Test Plan: unit tests
Reviewed By: rohan-varma
Differential Revision: D32408719
fbshipit-source-id: ee5c40ec91a423fbd58872fbdeb5f2dda8a3d89e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67135
Add the ability to select the backend via an environment variable for quicker testing (and gloo2 in the future).
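A hedged sketch of the idea; the variable name "BACKEND" is illustrative and not necessarily the one the test harness uses:

```python
import os
import torch.distributed as dist

# Assumes the usual MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE env vars are set.
# Selecting the backend from the environment lets the same test run quickly
# against gloo locally and against nccl on GPU CI.
backend = os.environ.get("BACKEND", "nccl")
dist.init_process_group(backend=backend, init_method="env://")
```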
ghstack-source-id: 142274304
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D31878285
fbshipit-source-id: 80ae7107cd631a1a15ebc23262b27d8192cfe4b6
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67249
Implements CPU offload for model parameters in FSDP.
- A CPUOffload class with only an offload_params attribute is created.
- If this is specified in the FSDP ctor, model parameters are moved back to CPU after sharding in __init__.
- In the forward pass, during lazy init, p._local_shard is set to p.data, so it is on CPU. We pin_memory here.
- In the forward pass, in _rebuild_full_params, we move p.data back to self.compute_device if necessary. Note that we don't use the device of p._full_param_padded because we don't always have this attr, but when we do, it is always the same as compute_device.
- The same logic applies at the beginning of the backward pass.
- At the end of the forward and backward passes, `_use_param_local_shard` ensures the parameters are offloaded to CPU again by pointing p.data to p._local_shard, which is always on CPU (see the sketch after this list).
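A minimal sketch of the offload flow described above, with simplified stand-ins for the internal helpers; the all-gather into p._full_param_padded and the stream synchronization are elided:

```python
import torch
import torch.nn as nn


def offload_shard_to_cpu(p: nn.Parameter) -> None:
    # After sharding in __init__: keep the local shard on CPU, pinned so that
    # later host-to-device copies can be asynchronous.
    p.data = p.data.to("cpu")
    p._local_shard = p.data.pin_memory()  # mirrors p._local_shard in the description
    p.data = p._local_shard


def rebuild_full_param(p: nn.Parameter, compute_device: torch.device) -> None:
    # Before forward/backward compute: move the shard onto the compute device.
    if p.data.device != compute_device:
        p.data = p.data.to(compute_device, non_blocking=True)


def use_param_local_shard(p: nn.Parameter) -> None:
    # At the end of forward/backward: point p.data back at the CPU shard so the
    # parameter is offloaded again.
    p.data = p._local_shard
```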
Regarding tests:
- We test 3 different types of init: 1) move the model to CUDA before wrapping with FSDP, 2) move the model to CUDA after wrapping with FSDP, 3) never move the model to CUDA.
- Case 1 is always supported. Case 2 is not supported with CPU offload and throws an error during the forward pass. Case 3 is only supported with CPU offload at the moment.
- Verifies all params are offloaded to CPU after init.
- Verifies all params are offloaded to CPU after forward and backward (see the assertion sketch after this list).
- Note that there is an issue with verifying exact parity when CPU offloading, but it appears to be related to transferring the model back and forth between CPU and CUDA. More details in https://github.com/pytorch/pytorch/pull/66961.
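A hedged sketch of the kind of check those tests perform; the helper name is illustrative:

```python
import torch


def assert_params_offloaded_to_cpu(fsdp_model: torch.nn.Module) -> None:
    # With offload_params=True, every parameter should live on CPU outside of
    # the forward/backward compute regions.
    for name, p in fsdp_model.named_parameters():
        assert p.device == torch.device("cpu"), f"{name} is on {p.device}, expected cpu"
```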
ghstack-source-id: 141851903
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D31911085
fbshipit-source-id: 3ddf73c070b55ce383e62251868d609004fc30e7
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66904
Add more FSDP unit tests to cover core logic, freezing weights, and the flatten parameter wrapper; these unit tests are refactored to align with commonly used PyTorch test classes.
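A hedged sketch of a test aligned with the shared PyTorch distributed test utilities; the module locations reflect where these utilities currently live, and the exact layout at the time of this PR may differ:

```python
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests


class TestFSDPFreezingWeights(FSDPTest):
    @skip_if_lt_x_gpu(2)
    def test_freezing_weights(self):
        # Wrap the model with FSDP, freeze a submodule, run a few training
        # steps, and check that the frozen weights did not change (details elided).
        ...


if __name__ == "__main__":
    run_tests()
```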
ghstack-source-id: 141335614
Test Plan: unit tests
Reviewed By: mrshenli
Differential Revision: D31779565
fbshipit-source-id: c727110d1d7570c0ec49e42cadfc9e9a5e440073