Motivation: Generalize unit tests so that can be executed for cuda and non cuda devices.
Depedency : #133209 Merged now.
There was a #135242 for these changes and closed due to in correct commits. I have incoroprated the changes as suggested in comments.
@kwen2501 @zeshengzong Please review the changes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139184
Approved by: https://github.com/kwen2501
Co-authored-by: Yu, Guangye <guangye.yu@intel.com>
Passing in `offload_to_cpu=True` to checkpoint_wrapper is a bit confusing, because this causes the activation checkpoint args to be ignored and we do CPU offloading. This isn't ideal from API design perspective, so proposing to make `offload_wrapper` its own concept.
Now, offload to CPU + checkpoint can be composed together, such as
```
# apply AC to transformer layers
apply_ac_wrapper(model, checkpoint_wrapper, check_fn=lambda mod: isinstance(mod, TransformerLayer))
# offload the rest of activations to CPU
model = offload_wrapper(model)
```
Will polish / add tests if this proposal sounds good.
Differential Revision: [D39719854](https://our.internmc.facebook.com/intern/diff/D39719854/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85459
Approved by: https://github.com/awgu
This fixes the activation offload for checkpoint wrapper, which was previously broken. It was broken because it was tightly coupled with activation checkpoint, i.e. we did:
```
with save_on_cpu:
checkpoint(module_forward())
```
which would not offload any activation tensors to CPU, as those activations would already be not saved by autograd due to the checkpoint implementation taking priority.
Now, if `offload_to_cpu` is specified, we only do `save_on_cpu` and no checkpoint, so all intermediate tensors are offloaded to CPU instead of checkpointed.
These wrappers can be composed, i.e. if we have
`(Linear, Linear) -> (Linear, Linear) -> (Linear, Linear)`
we can do
`Offload( checkpoint(Linear, Linear) -> checkpoint(Linear, Linear) -> checkpoint(Linear, Linear))`
and inner tensors would be checkpointed while outers will be offloaded.
Differential Revision: [D39448882](https://our.internmc.facebook.com/intern/diff/D39448882/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84907
Approved by: https://github.com/awgu
Allow checkpoint_wrapper to take in the checkpoint functional impl. This decouples it from torch.utils.checkpoint and allows other checkpoint implementations to be used.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83035
Approved by: https://github.com/awgu