Summary:
Currently, we cannot run a checkpointed function with None argument.
```python
out = torch.utils.checkpoint.checkpoint(run_fn, input_var, None)
```
```
File "/home/tunz/anaconda3/envs/torchdev/lib/python3.7/site-packages/torch/utils/checkpoint.py", line 14, in detach_variable
x = inp.detach()
AttributeError: 'NoneType' object has no attribute 'detach'
```
This PR makes checkpoint function to safely handle None argument.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17969
Differential Revision: D14475148
Pulled By: ezyang
fbshipit-source-id: 9afe9e9aac511a6df1e1620e9ac341536890d451
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14278
In this commit, we make checkpoint_sequential work for models with multiple tensor inputs. Previously, it only processed the first tensor and ignored the rest.
We introduce a new test in test/test_utils.py that replicates the issue referenced in this [GitHub issue](https://github.com/pytorch/pytorch/issues/11093), and we make sure that the test passes by changing the behavior of checkpoint_sequential to process all input tensors.
Reviewed By: ezyang
Differential Revision: D13144672
fbshipit-source-id: 24f58233a65a0f5b80b89c8d8cbced6f814004f7
Summary:
This issue was noticed, and fix proposed, by raulpuric.
Checkpointing is implemented by rerunning a forward-pass segment for each checkpointed segment during backward. This can result in the RNG state advancing more than it would without checkpointing, which can cause checkpoints that include dropout invocations to lose end-to-end bitwise accuracy as compared to non-checkpointed passes.
The present PR contains optional logic to juggle the RNG states such that checkpointed passes containing dropout achieve bitwise accuracy with non-checkpointed equivalents.** The user requests this behavior by supplying `preserve_rng_state=True` to `torch.utils.checkpoint` or `torch.utils.checkpoint_sequential`.
Currently, `preserve_rng_state=True` may incur a moderate performance hit because restoring MTGP states can be expensive. However, restoring Philox states is dirt cheap, so syed-ahmed's [RNG refactor](https://github.com/pytorch/pytorch/pull/13070#discussion_r235179882), once merged, will make this option more or less free.
I'm a little wary of the [def checkpoint(function, *args, preserve_rng_state=False):](https://github.com/pytorch/pytorch/pull/14253/files#diff-58da227fc9b1d56752b7dfad90428fe0R75) argument-passing method (specifically, putting a kwarg after a variable argument list). Python 3 seems happy with it.
Edit: It appears Python 2.7 is NOT happy with a [kwarg after *args](https://travis-ci.org/pytorch/pytorch/builds/457706518?utm_source=github_status&utm_medium=notification). `preserve_rng_state` also needs to be communicated in a way that doesn't break any existing usage. I'm open to suggestions (a global flag perhaps)?
**Batchnorm may still be an issue, but that's a battle for another day.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14253
Differential Revision: D13166665
Pulled By: soumith
fbshipit-source-id: 240cddab57ceaccba038b0276151342344eeecd7
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13141
This is an example diff to show what lint rules are being applied.
Reviewed By: mingzhe09088
Differential Revision: D10858478
fbshipit-source-id: cbeb013f10f755b0095478adf79366e7cf7836ff
* Autograd container for trading compute for memory
* add a unit test for checkpoint
* address comments
* address review comments
* adding some docs for the checkpoint api
* more comments
* more comments
* repro bug
* Fix a subtle bug/apply some review comments
* Update checkpoint.py
* Run everything in grad mode
* fix flake and chunk=1
* use imperative backward as per discussion
* remove Variable and also add models and test for models
* Add a simple thread local variable to check for autograd grad mode
* remove models and models test after debugging
* address review comments
* address more comments
* address more comments