Commit Graph

18 Commits

Rohan Varma
a8074a1a0b [Checkpoint] rename apply_ac_wrapper (#85449)
Per title

Differential Revision: [D39714855](https://our.internmc.facebook.com/intern/diff/D39714855/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85449
Approved by: https://github.com/awgu
2022-09-23 21:15:08 +00:00
Rohan Varma
cc64f64670 [Docs] Minor fix to apply_ac doc (#85448)
Per title

Differential Revision: [D39714530](https://our.internmc.facebook.com/intern/diff/D39714530/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85448
Approved by: https://github.com/awgu
2022-09-23 21:15:08 +00:00
Rohan Varma
8cb7826889 [CheckpointWrapper] Reentrant kwarg support (#84908)
A temporary patch to support keyword args when the reentrant checkpoint wrapper is used. This is needed to unblock some crucial workloads; the ideal fix would be to check this directly into torch.utils.checkpoint.
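
A minimal sketch (not the actual PyTorch code) of how keyword arguments can be routed through reentrant checkpointing, which only forwards positional arguments: pack the kwarg values onto the positional argument list and rebuild the dict inside the checkpointed callable. The wrapper class name below is hypothetical.

```
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class KwargCheckpointWrapper(nn.Module):  # hypothetical name, for illustration
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        kwarg_keys = tuple(kwargs.keys())

        def run(*flat_args):
            # Split the flat positional tuple back into args and kwargs.
            split = len(flat_args) - len(kwarg_keys)
            pos, kw_vals = flat_args[:split], flat_args[split:]
            return self.module(*pos, **dict(zip(kwarg_keys, kw_vals)))

        # Reentrant checkpoint only accepts positional args, so append kwargs.
        return checkpoint(run, *args, *kwargs.values(), use_reentrant=True)
```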

Differential Revision: [D39453453](https://our.internmc.facebook.com/intern/diff/D39453453/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84908
Approved by: https://github.com/awgu
2022-09-15 00:30:23 +00:00
Rohan Varma
55ca6901a7 [CheckpointWrapper] Decouple CPU offload (#84907)
This fixes activation offload for checkpoint_wrapper, which was previously broken. It was broken because offload was tightly coupled with activation checkpointing, i.e. we did:

```
with save_on_cpu():
    out = checkpoint(module.forward, *args)
```

which would not offload any activation tensors to CPU, since the checkpoint implementation takes priority and those activations are not saved by autograd in the first place.

Now, if `offload_to_cpu` is specified, we only do `save_on_cpu` and no checkpoint, so all intermediate tensors are offloaded to CPU instead of checkpointed.

These wrappers can be composed, i.e. if we have

`(Linear, Linear) -> (Linear, Linear) -> (Linear, Linear)`

we can do

`Offload( checkpoint(Linear, Linear) -> checkpoint(Linear, Linear) -> checkpoint(Linear, Linear))`

and the inner activations are checkpointed while the outer ones are offloaded.
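
A minimal sketch of that composition, assuming the `checkpoint_wrapper` API with an `offload_to_cpu` flag; the layer sizes and block structure are illustrative only.

```
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

def block():
    return nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))

# Checkpoint each (Linear, Linear) block, then offload the activations that
# cross block boundaries by wrapping the whole stack with offload_to_cpu.
inner = nn.Sequential(
    checkpoint_wrapper(block()),
    checkpoint_wrapper(block()),
    checkpoint_wrapper(block()),
)
model = checkpoint_wrapper(inner, offload_to_cpu=True)
```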

Differential Revision: [D39448882](https://our.internmc.facebook.com/intern/diff/D39448882/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84907
Approved by: https://github.com/awgu
2022-09-15 00:30:23 +00:00
Rohan Varma
5b2c03823d Generalize CheckpointWrapper (#83035)
Allow checkpoint_wrapper to take in the checkpointing function implementation. This decouples it from torch.utils.checkpoint and allows other checkpoint implementations to be used.
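
A minimal sketch of what this enables; the parameter name `checkpoint_fn` is assumed here, as is the trivial stand-in implementation.

```
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

def my_checkpoint(fn, *args, **kwargs):
    # Trivial stand-in for a custom checkpoint implementation; a real one
    # would discard activations in forward and recompute them in backward.
    return fn(*args, **kwargs)

# The wrapper calls `my_checkpoint` instead of torch.utils.checkpoint.
wrapped = checkpoint_wrapper(nn.Linear(16, 16), checkpoint_fn=my_checkpoint)
```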
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83035
Approved by: https://github.com/awgu
2022-08-09 23:35:39 +00:00
Rohan Varma
0c5fdfd95f Revert "Revert "[FSDP Optim State] Remove checkpoint prefix (#80480)"" (#80936)
This reverts commit fe361dede4.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80936
Approved by: https://github.com/awgu
2022-07-06 22:21:07 +00:00
PyTorch MergeBot
fe361dede4 Revert "[FSDP Optim State] Remove checkpoint prefix (#80480)"
This reverts commit 04c50fec1c.

Reverted https://github.com/pytorch/pytorch/pull/80480 on behalf of https://github.com/suo because it broke master (04c50fec1c); the test failures were related to this change.
2022-07-06 02:43:27 +00:00
Rohan Varma
04c50fec1c [FSDP Optim State] Remove checkpoint prefix (#80480)
Remove `_checkpoint_wrapped_module` prefixes when creating keys for optimizer state_dict.

Keeping these prefixes does not actually break optimizer state_dict save / load, but we strip them for downstream code that consumes these APIs and typically expects checkpointing prefixes to be absent (checkpointing should be a transparent operation that does not change module / parameter names).
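
A minimal sketch of the key cleanup being described; the prefix string matches the `_checkpoint_wrapped_module` name mentioned above, while the helper itself is illustrative.

```
PREFIX = "_checkpoint_wrapped_module."

def strip_checkpoint_prefix(keys):
    # "layer1._checkpoint_wrapped_module.weight" -> "layer1.weight"
    return [k.replace(PREFIX, "") for k in keys]

print(strip_checkpoint_prefix(["layer1._checkpoint_wrapped_module.weight"]))
# ['layer1.weight']
```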
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80480
Approved by: https://github.com/awgu, https://github.com/fegin
2022-07-06 01:17:58 +00:00
Chien-Chin Huang
e0eeb06ec6 Consolidate the naming of named_parameter and state_dict for CheckpointWrapper (#80089)
named_parameters() should return the same parameter names as state_dict(), but the current CheckpointWrapper does not enforce this naming rule. This PR resolves the issue.
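
A minimal sketch of the invariant being enforced, using standard nn.Module APIs; the helper function is illustrative.

```
import torch.nn as nn

def names_consistent(module: nn.Module) -> bool:
    # Parameter names from named_parameters() should appear verbatim as
    # keys in state_dict() (state_dict may additionally contain buffers).
    param_names = {name for name, _ in module.named_parameters()}
    return param_names <= set(module.state_dict().keys())

print(names_consistent(nn.Linear(4, 4)))  # True
```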

Differential Revision: [D37344200](https://our.internmc.facebook.com/intern/diff/D37344200/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80089
Approved by: https://github.com/rohan-varma
2022-07-05 22:11:59 +00:00
Rohan Varma
2ede28724d [CheckpointWrapper] Replace generic mod prefix (#79830)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79830
Approved by: https://github.com/awgu, https://github.com/zhaojuanmao
2022-06-21 16:01:59 +00:00
Rohan Varma
543919cfc8 Forward attributes to wrapped module
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78854
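
A minimal sketch of the attribute forwarding idea: lookups that miss on the wrapper fall through to the wrapped module via `__getattr__`. The class name is illustrative, not the actual CheckpointWrapper.

```
import torch.nn as nn

class ForwardingWrapper(nn.Module):  # illustrative name
    def __init__(self, module: nn.Module):
        super().__init__()
        self._wrapped = module

    def __getattr__(self, name):
        try:
            # nn.Module's own lookup (parameters, buffers, submodules) first.
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self._wrapped, name)

w = ForwardingWrapper(nn.Linear(4, 8))
print(w.out_features)  # 8, forwarded from the wrapped Linear
```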

Approved by: https://github.com/albanD
2022-06-14 01:13:33 +00:00
Rohan Varma
44fe851feb [WIP] Fix non-reentrant hooks based checkpointing
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78752

Approved by: https://github.com/albanD
2022-06-14 01:13:33 +00:00
Rohan Varma
ec86070922 Checkpoint util
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78704

Approved by: https://github.com/zhaojuanmao
2022-06-10 18:37:36 +00:00
Rohan Varma
f9f8127414 CheckpointWrapper state_dict fix (#77224)
- Uses state_dict / load_state_dict hooks to ensure that state from modules wrapped with `CheckpointWrapper` can be loaded into modules that are not checkpoint-wrapped.

This is needed because a training run can use activation checkpointing and then save a `state_dict`, while a future run may not want to wrap modules with activation checkpointing, or may change the wrapping structure. To support this, we add hooks that remove / add the relevant prefix as needed.

Tests are added to ensure that a CheckpointWrapper-wrapped module's state_dict can be loaded into both a CheckpointWrapper module and a plain local module. state_dict with FSDP is also verified.
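
A minimal sketch of the hook-based approach, using nn.Module's private hook registration methods; the hook bodies and the wrapper instance here are illustrative.

```
import torch.nn as nn

PREFIX = "_checkpoint_wrapped_module."

def _strip_prefix_post_hook(module, state_dict, prefix, local_metadata):
    # Runs after state_dict(): drop the wrapper prefix from all keys.
    for key in list(state_dict.keys()):
        if PREFIX in key:
            state_dict[key.replace(PREFIX, "")] = state_dict.pop(key)

def _add_prefix_pre_hook(state_dict, prefix, *args):
    # Runs before load_state_dict(): re-insert the prefix the wrapper expects.
    for key in list(state_dict.keys()):
        if key.startswith(prefix) and PREFIX not in key:
            state_dict[prefix + PREFIX + key[len(prefix):]] = state_dict.pop(key)

wrapper = nn.Linear(4, 4)  # stand-in for a CheckpointWrapper instance
wrapper._register_state_dict_hook(_strip_prefix_post_hook)
wrapper._register_load_state_dict_pre_hook(_add_prefix_pre_hook)
```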
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77224
Approved by: https://github.com/zhaojuanmao
2022-05-17 03:39:31 +00:00
Rohan Varma
aeacf910b5 [Checkpoint] Rename file (#72748)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72748

Removes the underscore from the file/class name, as the directory is already private
ghstack-source-id: 149109295

Test Plan: CI

Reviewed By: samdow

Differential Revision: D34179308

fbshipit-source-id: 8e956f3c83f21159c5e0fcdce09624ecb8a73ac0
(cherry picked from commit adfd8bc357)
2022-02-16 00:08:23 +00:00
Rohan Varma
a197f3fe52 [FSDP/Checkpoint] Activation offload support in checkpoint_wrapper (#70165)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70165

Implements activation offload support in the checkpoint_wrapper API via
save_on_cpu hooks. We avoid modifying the torch.utils.checkpoint implementation
and instead compose offload + checkpoint, using the save_on_cpu hook for the
former.
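
A minimal sketch of the save_on_cpu mechanism the offload path relies on: tensors saved for backward are packed to CPU during forward and moved back during backward.

```
import torch
from torch.autograd.graph import save_on_cpu

lin = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)

with save_on_cpu(pin_memory=False):
    out = lin(x)          # tensors saved for backward are kept on CPU
out.sum().backward()      # and moved back to the original device here
```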
ghstack-source-id: 146078900

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D33228820

fbshipit-source-id: 98b4da0828462c41c381689ee07360ad014e808a
2021-12-21 10:08:18 -08:00
Rohan Varma
79a40b22aa [Checkpoint] Make checkpoint_wrapper an nn.Module (#70164)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70164

Implements Alban's suggestion to make checkpoint_wrapper an nn.Module
instead of patching the forward pass, which was too hacky.
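
A minimal sketch of the nn.Module-based design, under assumed names: the wrapper holds the original module and routes its forward through torch.utils.checkpoint, instead of monkey-patching the module's forward method.

```
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointWrapperSketch(nn.Module):  # illustrative name
    def __init__(self, module: nn.Module):
        super().__init__()
        self._checkpoint_wrapped_module = module

    def forward(self, *args):
        # Recompute the wrapped module's forward during backward.
        return checkpoint(self._checkpoint_wrapped_module, *args, use_reentrant=True)

wrapped = CheckpointWrapperSketch(nn.Linear(8, 8))
```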
ghstack-source-id: 146011215

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D33214696

fbshipit-source-id: dc4b3e928d66fbde828ab60d90b314a8048ff7a2
2021-12-20 13:22:28 -08:00
Rohan Varma
c4281cc92d Prototype checkpoint_wrapper (#69955)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69955

Implements a checkpoint_wrapper function, which wraps an nn.Module with checkpointing so users won't have to call checkpoint() every time they want to checkpoint the module.

Currently, only support for reentrant-based checkpointing is added, and it is only tested with FSDP to unblock a use case.

Future work is to add support for the new checkpointing API, add more tests, and upstream this to torch.utils.checkpoint.
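
A minimal usage sketch, assuming the checkpoint_wrapper entry point that lives under torch.distributed.algorithms._checkpoint; the model structure is illustrative.

```
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

# Each wrapped layer is checkpointed transparently; callers never invoke
# torch.utils.checkpoint.checkpoint() themselves.
model = nn.Sequential(
    checkpoint_wrapper(nn.Linear(32, 32)),
    nn.ReLU(),
    nn.Linear(32, 10),
)
```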
ghstack-source-id: 145811242

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D33107276

fbshipit-source-id: c4a1c68d71d65713a929994940a8750f73fbdbdb
2021-12-16 09:59:19 -08:00