Summary:
Support for variadic inputs to `checkpoint_sequential` was deprecated in https://github.com/pytorch/pytorch/issues/21006. That usage was to be warned about with a `DeprecationWarning` in PyTorch 1.2 and to simply fail with a `TypeError` starting with PyTorch 1.3. This patch removes the `DeprecationWarning` that was added for PyTorch 1.2.
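For illustration (a hedged sketch, not code from this patch; the model and tensor shapes are made up), the supported single-input form versus the removed variadic form:
```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
x = torch.randn(2, 10, requires_grad=True)

# Supported: a single input tensor, matching nn.Sequential's signature.
out = checkpoint_sequential(model, 2, x)

# Previously this variadic form only emitted a DeprecationWarning; with the
# warning removed, extra positional inputs fail with a TypeError instead.
# checkpoint_sequential(model, 2, x, x)
```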
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25985
Differential Revision: D18809875
Pulled By: albanD
fbshipit-source-id: e84dd8629c04979c4b2dc63e8ada94292e8cedd0
Summary:
I've reported an inconsistency between `checkpoint_sequential` and `nn.Sequential` at https://github.com/pytorch/pytorch/issues/19260. Both should provide the same input signature, but they don't. I think this consistency is important, and I agree with apaszke that `nn.Sequential`'s semantics should be kept rather than `checkpoint_sequential`'s.
I hope `checkpoint_sequential` will raise a `TypeError` on variadic arguments starting with PyTorch 1.2.0, but for now it's okay to just warn with a `DeprecationWarning`. I've discussed this approach with soumith.
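A rough sketch of the transitional behavior (illustrative only, not the exact diff in this PR):
```python
import warnings

def checkpoint_sequential(functions, segments, *inputs, **kwargs):
    # Transitional behavior: still accept variadic inputs for backward
    # compatibility, but warn that only a single input (matching
    # nn.Sequential's signature) will be accepted in a future release.
    if len(inputs) > 1:
        warnings.warn(
            "Giving multiple inputs to checkpoint_sequential is deprecated; "
            "pass a single input as nn.Sequential does",
            DeprecationWarning,
        )
    # ... run the segments with torch.utils.checkpoint as before ...
```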
Please review this pull request. Any comments are welcome.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21006
Differential Revision: D15530801
Pulled By: soumith
fbshipit-source-id: 0ceb2cc6a17dcc547d0d00ebaf9df8603be53183
Summary:
Currently, we cannot run a checkpointed function with a `None` argument.
```python
out = torch.utils.checkpoint.checkpoint(run_fn, input_var, None)
```
```
File "/home/tunz/anaconda3/envs/torchdev/lib/python3.7/site-packages/torch/utils/checkpoint.py", line 14, in detach_variable
x = inp.detach()
AttributeError: 'NoneType' object has no attribute 'detach'
```
This PR makes the checkpoint function handle `None` arguments safely.
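A minimal sketch of the idea (assuming the fix is to skip non-tensor values in `detach_variable` rather than calling `.detach()` on them unconditionally):
```python
import torch

def detach_variable(inputs):
    # Detach only tensor inputs; pass None (and other non-tensor values)
    # through unchanged so checkpoint(run_fn, input_var, None) works.
    out = []
    for inp in inputs:
        if not isinstance(inp, torch.Tensor):
            out.append(inp)
            continue
        x = inp.detach()
        x.requires_grad = inp.requires_grad
        out.append(x)
    return tuple(out)
```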
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17969
Differential Revision: D14475148
Pulled By: ezyang
fbshipit-source-id: 9afe9e9aac511a6df1e1620e9ac341536890d451
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14278
In this commit, we make checkpoint_sequential work for models with multiple tensor inputs. Previously, it only processed the first tensor and ignored the rest.
We introduce a new test in test/test_utils.py that replicates the issue referenced in this [GitHub issue](https://github.com/pytorch/pytorch/issues/11093), and we make sure that the test passes by changing the behavior of checkpoint_sequential to process all input tensors.
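Roughly, the segment runner now forwards every input to the first module of the segment and then threads the intermediate result through the rest (a sketch under that assumption; the real closure may differ):
```python
def run_function(start, end, functions):
    # Forward all inputs to the first module of the segment, then thread
    # the (single) intermediate output through the remaining modules.
    def forward(*inputs):
        output = inputs
        for j in range(start, end + 1):
            if isinstance(output, tuple):
                output = functions[j](*output)
            else:
                output = functions[j](output)
        return output
    return forward
```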
Reviewed By: ezyang
Differential Revision: D13144672
fbshipit-source-id: 24f58233a65a0f5b80b89c8d8cbced6f814004f7
Summary:
This issue was noticed, and a fix proposed, by raulpuric.
Checkpointing is implemented by rerunning a forward-pass segment for each checkpointed segment during backward. This can result in the RNG state advancing more than it would without checkpointing, which can cause checkpoints that include dropout invocations to lose end-to-end bitwise accuracy as compared to non-checkpointed passes.
The present PR contains optional logic to juggle the RNG states such that checkpointed passes containing dropout achieve bitwise accuracy with non-checkpointed equivalents.** The user requests this behavior by supplying `preserve_rng_state=True` to `torch.utils.checkpoint` or `torch.utils.checkpoint_sequential`.
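Usage looks roughly like this (a hedged sketch; the dropout module and tensor shapes are made up):
```python
import torch
from torch.utils.checkpoint import checkpoint

drop = torch.nn.Dropout(p=0.5)
x = torch.randn(4, 8, requires_grad=True)

# Stash the RNG state before running the segment and restore it for the
# recomputation during backward, so dropout draws the same mask both times
# and the result matches a non-checkpointed run bitwise.
out = checkpoint(drop, x, preserve_rng_state=True)
out.sum().backward()
```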
Currently, `preserve_rng_state=True` may incur a moderate performance hit because restoring MTGP states can be expensive. However, restoring Philox states is dirt cheap, so syed-ahmed's [RNG refactor](https://github.com/pytorch/pytorch/pull/13070#discussion_r235179882), once merged, will make this option more or less free.
I'm a little wary of the [def checkpoint(function, *args, preserve_rng_state=False):](https://github.com/pytorch/pytorch/pull/14253/files#diff-58da227fc9b1d56752b7dfad90428fe0R75) argument-passing method (specifically, putting a kwarg after a variable argument list). Python 3 seems happy with it.
Edit: It appears Python 2.7 is NOT happy with a [kwarg after *args](https://travis-ci.org/pytorch/pytorch/builds/457706518?utm_source=github_status&utm_medium=notification). `preserve_rng_state` also needs to be communicated in a way that doesn't break any existing usage. I'm open to suggestions (a global flag perhaps)?
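One Python 2-compatible workaround (a sketch of the kwargs-popping pattern, not necessarily what the final patch does):
```python
def checkpoint(function, *args, **kwargs):
    # Python 2 rejects `def checkpoint(function, *args, preserve_rng_state=False)`,
    # so pull the flag out of **kwargs manually and reject anything unexpected.
    preserve_rng_state = kwargs.pop('preserve_rng_state', False)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ", ".join(kwargs))
    # ... stash the RNG state here when preserve_rng_state is True, then run ...
    return function(*args)
```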
**Batchnorm may still be an issue, but that's a battle for another day.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14253
Differential Revision: D13166665
Pulled By: soumith
fbshipit-source-id: 240cddab57ceaccba038b0276151342344eeecd7
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13141
This is an example diff to show what lint rules are being applied.
Reviewed By: mingzhe09088
Differential Revision: D10858478
fbshipit-source-id: cbeb013f10f755b0095478adf79366e7cf7836ff
* Autograd container for trading compute for memory
* add a unit test for checkpoint
* address comments
* address review comments
* adding some docs for the checkpoint api
* more comments
* more comments
* repro bug
* Fix a subtle bug/apply some review comments
* Update checkpoint.py
* Run everything in grad mode
* fix flake and chunk=1
* use imperative backward as per discussion
* remove Variable and also add models and test for models
* Add a simple thread local variable to check for autograd grad mode
* remove models and models test after debugging
* address review comments
* address more comments
* address more comments