This PR adds a couple of missing words in the Checkpointing documentation, it doesn't have a specific issue number related to it.
Changes are:
- "backward." -> "backward propagation."
- "to be advanced than" -> "to be more advanced than"
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120196
Approved by: https://github.com/soulitzer
People access activation checkpoint through many layers of config and it is not always guaranteed that all the layers of wrapping around checkpoint properly propagate all the kwargs, e.g. debug mode. This context manager offers an alternative way to enable debug mode that bypasses the need for all layers to propagate kwargs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110728
Approved by: https://github.com/albanD
ghstack dependencies: #110673, #110674, #110675, #110676
Fixes #ISSUE_NUMBER
1、add checkpoint support for custom device
2、add a device argument, I want to add a device="cuda" parameter to the func `forward` of `CheckpointFunction`, and I can specify the device type when using it, but the func `apply` of `torch.autograd.Function` does not support `kwargs`, so I added a variable named `_device`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99626
Approved by: https://github.com/soulitzer
Summary:
This issue was noticed, and fix proposed, by raulpuric.
Checkpointing is implemented by rerunning a forward-pass segment for each checkpointed segment during backward. This can result in the RNG state advancing more than it would without checkpointing, which can cause checkpoints that include dropout invocations to lose end-to-end bitwise accuracy as compared to non-checkpointed passes.
The present PR contains optional logic to juggle the RNG states such that checkpointed passes containing dropout achieve bitwise accuracy with non-checkpointed equivalents.** The user requests this behavior by supplying `preserve_rng_state=True` to `torch.utils.checkpoint` or `torch.utils.checkpoint_sequential`.
Currently, `preserve_rng_state=True` may incur a moderate performance hit because restoring MTGP states can be expensive. However, restoring Philox states is dirt cheap, so syed-ahmed's [RNG refactor](https://github.com/pytorch/pytorch/pull/13070#discussion_r235179882), once merged, will make this option more or less free.
I'm a little wary of the [def checkpoint(function, *args, preserve_rng_state=False):](https://github.com/pytorch/pytorch/pull/14253/files#diff-58da227fc9b1d56752b7dfad90428fe0R75) argument-passing method (specifically, putting a kwarg after a variable argument list). Python 3 seems happy with it.
Edit: It appears Python 2.7 is NOT happy with a [kwarg after *args](https://travis-ci.org/pytorch/pytorch/builds/457706518?utm_source=github_status&utm_medium=notification). `preserve_rng_state` also needs to be communicated in a way that doesn't break any existing usage. I'm open to suggestions (a global flag perhaps)?
**Batchnorm may still be an issue, but that's a battle for another day.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14253
Differential Revision: D13166665
Pulled By: soumith
fbshipit-source-id: 240cddab57ceaccba038b0276151342344eeecd7
* Autograd container for trading compute for memory
* add a unit test for checkpoint
* address comments
* address review comments
* adding some docs for the checkpoint api
* more comments
* more comments
* repro bug
* Fix a subtle bug/apply some review comments
* Update checkpoint.py
* Run everything in grad mode
* fix flake and chunk=1
* use imperative backward as per discussion
* remove Variable and also add models and test for models
* Add a simple thread local variable to check for autograd grad mode
* remove models and models test after debugging
* address review comments
* address more comments
* address more comments