Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70165
Implements activation offload support in the checkpoint_wrapper API via
save_on_cpu hooks. We avoid modifying the torch.utils.checkpoint implementation
and instead compose offload + checkpoint, using the save_on_cpu hook for the
former.
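A minimal sketch of this composition, assuming a recent PyTorch with `torch.autograd.graph.save_on_cpu` and the `use_reentrant` flag on `checkpoint`; the wrapper class name is hypothetical, not the actual implementation:

```python
# Hypothetical sketch: compose activation offload (save_on_cpu) with
# activation checkpointing, without touching torch.utils.checkpoint itself.
import torch
import torch.nn as nn
from torch.autograd.graph import save_on_cpu
from torch.utils.checkpoint import checkpoint

class OffloadCheckpointWrapper(nn.Module):  # illustrative name only
    def __init__(self, mod: nn.Module):
        super().__init__()
        self.mod = mod

    def forward(self, x):
        # Tensors saved for backward inside this context are packed to CPU
        # and unpacked back to the original device during backward.
        with save_on_cpu(pin_memory=False):
            return checkpoint(self.mod, x, use_reentrant=True)

layer = OffloadCheckpointWrapper(nn.Linear(4, 4))
x = torch.randn(2, 4, requires_grad=True)
out = layer(x)
out.sum().backward()
```

Because the offload lives entirely in a saved-tensors hook context around the checkpointed call, neither feature needs to know about the other.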
ghstack-source-id: 146078900
Test Plan: CI
Reviewed By: zhaojuanmao
Differential Revision: D33228820
fbshipit-source-id: 98b4da0828462c41c381689ee07360ad014e808a
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70164
Implements Alban's suggestion to make checkpoint_wrapper an nn.Module
instead of patching the forward pass, which was too hacky.
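As a sketch of the design (class and attribute names hypothetical): wrapping in an nn.Module registers the wrapped module as a proper child, so parameters, state_dict, and hooks behave normally, which monkey-patching forward does not guarantee.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointWrapper(nn.Module):
    """Illustrative wrapper: delegates forward through checkpoint()."""
    def __init__(self, mod: nn.Module):
        super().__init__()
        self._wrapped_module = mod  # registered as a child module

    def forward(self, *args):
        return checkpoint(self._wrapped_module, *args, use_reentrant=True)

wrapper = CheckpointWrapper(nn.Linear(4, 4))
# Parameters of the wrapped module are visible through the wrapper.
assert len(list(wrapper.parameters())) == 2  # weight + bias
x = torch.randn(2, 4, requires_grad=True)
wrapper(x).sum().backward()
```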
ghstack-source-id: 146011215
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D33214696
fbshipit-source-id: dc4b3e928d66fbde828ab60d90b314a8048ff7a2
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69955
Implements a checkpoint_wrapper function, which wraps an nn.Module with checkpointing so users won't have to call checkpoint() every time they want to checkpoint the module.
Currently, only reentrant-based checkpointing is supported, and it has only been tested with FSDP to unblock a use case.
Future work is to add support for the new checkpointing API, add more tests, and upstream to torch.utils.checkpoint.
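In usage terms, the idea is to wrap each submodule once at construction time and then train as usual, with no per-call checkpoint(). The helper below mirrors the described API but is a simplified stand-in, not the actual implementation:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def checkpoint_wrapper(module: nn.Module) -> nn.Module:
    """Simplified stand-in: returns a module whose forward runs under
    reentrant checkpointing."""
    class _Checkpointed(nn.Module):
        def __init__(self, mod):
            super().__init__()
            self.mod = mod
        def forward(self, *args):
            return checkpoint(self.mod, *args, use_reentrant=True)
    return _Checkpointed(module)

# Wrap blocks once; the model is then used like any other nn.Module.
model = nn.Sequential(*(checkpoint_wrapper(nn.Linear(8, 8)) for _ in range(3)))
x = torch.randn(2, 8, requires_grad=True)
loss = model(x).sum()
loss.backward()
```

Note that reentrant checkpointing requires at least one input that requires grad for backward to flow, which is why `x` is created with `requires_grad=True` here.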
ghstack-source-id: 145811242
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D33107276
fbshipit-source-id: c4a1c68d71d65713a929994940a8750f73fbdbdb