pytorch/torch/distributed/algorithms/_checkpoint
Rohan Varma bdefa260b2 [RFC] Separate CPU offload activation to its own wrapper (#85459)
Passing `offload_to_cpu=True` to `checkpoint_wrapper` is confusing: it causes the activation checkpointing arguments to be ignored, and CPU offloading is performed instead. This isn't ideal from an API design perspective, so this PR proposes making `offload_wrapper` its own concept.

Offloading to CPU and checkpointing can now be composed, for example:

```
# apply AC to transformer layers
apply_ac_wrapper(model, checkpoint_wrapper, check_fn=lambda mod: isinstance(mod, TransformerLayer))
# offload the rest of activations to CPU
model = offload_wrapper(model)
```
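
To illustrate the composition above without depending on PyTorch, here is a toy sketch of the "wrap matching submodules, then wrap the root" pattern. Everything in it (`Module`, `Wrapper`, `apply_wrapper`, `describe`, and the toy `checkpoint_wrapper` / `offload_wrapper`) is a hypothetical stand-in, not the real implementation; it only shows the module structure that results.

```python
# Toy, torch-free sketch. All classes and helpers are hypothetical stand-ins
# for the real wrappers; they only record which wrapper was applied where.
class Module:
    def __init__(self, **children):
        self.children = dict(children)

class TransformerLayer(Module):
    pass

class Linear(Module):
    pass

class Wrapper(Module):
    def __init__(self, inner, kind):
        super().__init__(inner=inner)
        self.kind = kind

def checkpoint_wrapper(mod):
    return Wrapper(mod, "checkpoint")

def offload_wrapper(mod):
    return Wrapper(mod, "offload")

def apply_wrapper(root, wrapper_fn, check_fn):
    """Replace, in place, every descendant of root for which check_fn is true."""
    for name, child in list(root.children.items()):
        apply_wrapper(child, wrapper_fn, check_fn)  # handle grandchildren first
        if check_fn(child):
            root.children[name] = wrapper_fn(child)

def describe(mod):
    """Render the module tree as a nested string."""
    if isinstance(mod, Wrapper):
        return f"{mod.kind}({describe(mod.children['inner'])})"
    inner = ", ".join(f"{k}={describe(v)}" for k, v in mod.children.items())
    return f"{type(mod).__name__}({inner})"

model = Module(layer0=TransformerLayer(), head=Linear())
# Checkpoint only the transformer layers ...
apply_wrapper(model, checkpoint_wrapper, lambda m: isinstance(m, TransformerLayer))
# ... then wrap the whole model for CPU offload.
model = offload_wrapper(model)
structure = describe(model)
```

In this sketch the two wrappers are independent: the checkpoint wrapper touches only the modules selected by `check_fn`, and the offload wrapper is applied once at the root.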

Will polish / add tests if this proposal sounds good.

Differential Revision: [D39719854](https://our.internmc.facebook.com/intern/diff/D39719854/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85459
Approved by: https://github.com/awgu
2022-10-15 05:19:28 +00:00
__init__.py
checkpoint_wrapper.py [RFC] Separate CPU offload activation to its own wrapper (#85459) 2022-10-15 05:19:28 +00:00