This PR makes some improvements for debuggability of checkpointing:
- improved error messages that are more understandable
- errors are now `CheckpointError` which subclasses `RuntimeError` (only `CheckpointError` triggers debug message, see below)
- stricter error checking by default:
- shapes, dtypes, and device are compared
- we also now error when more tensors are being saved for backward during recompute
- NOTE: checks are relaxed if it is detected that you are doing backward within forward
- shapes, dtype, and device checking can be disabled by passing `determinism_check="none"`
- new debug flag: more helpful error message when `debug=True`
Note:
- cpp stack trace is only included for x86 linux machines
- the error message if cpp stack trace is included can be quite long. For a function checkpointed with 8 operators, the log was around 1300 lines! (should this be hidden behind a flag?)
[Error message when debug='True' (python stack trace only)](https://gist.github.com/soulitzer/3d5e19c7cceae8e22f9bdd625ec39dd4)
[Error message when debug='True' (with python and cpp stacktrace)](https://gist.github.com/soulitzer/ff8fd8c3ccbb2c90dfe3df6d7713b167)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103859
Approved by: https://github.com/albanD
Fixes #ISSUE_NUMBER
1、add checkpoint support for custom device
2、add a device argument, I want to add a device="cuda" parameter to the func `forward` of `CheckpointFunction`, and I can specify the device type when using it, but the func `apply` of `torch.autograd.Function` does not support `kwargs`, so I added a variable named `_device`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99626
Approved by: https://github.com/soulitzer
Now that we have updated all internal callsites, per https://fb.workplace.com/groups/pytorch.oss.dev/permalink/1635183750239493/ we should raise a warning when use_reentrant is not explicitly passed for 2.1
Deprecation note:
- Not passing in use_reentrant explicitly is now deprecated and will raise a warning. In the future the default value of use-reentrant will be False. To preserve the existing behavior you can pass in use_reentrant=True. It is recommended that you use use_reentrant=False.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100551
Approved by: https://github.com/Skylion007
Why did I choose context manager instead of per-call? Early stopping is not part of the model definition, and depending on how a particular model is used, e.g., with PT2 or not we may or may not want to disable early stopping.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96866
Approved by: https://github.com/albanD
Updates:
- ~recommend user to use non-reentrant, mention that reentrant will be deprecated in the future~
- merges all the warnings into a single list of non-reentrant improvements over reentrant
- adds an additional entry to the list about allowing backward inside checkpointed region
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96862
Approved by: https://github.com/albanD
Changes:
- bc-breaking change: The main difference between this and the old non-reentrant impl that it replaces is that we clear recomputed tensors on backward immediately upon unpack, even if retain_graph=True. This has the following additional implications:
- Accessing _saved_tensors multiple times will silently recompute forward multiple times.
- Accessing ctx.saved_tensor twice in the same backward will now raise an error.
- To avoid dealing with the potential consequences, early stopping has been hidden behind a global flag that is by default False, and can be enabled via a context manager. We can remove this in a follow up. Some features of nesting as a result do not work by default.
Before land:
- import to check for more bc-breakingness
- implement any workarounds for the bc-breaking-ness, if we decide on any
- update docs to reflect new lifetime of recomputed variables
- update docs to mention the early stop feature
Follow ups:
- enable early-stopping by default
- update docs/tutorial to feature nested use cases
Related docs:
- code comment: https://github.com/pytorch/pytorch/pull/90105/files#diff-9dcd955620b52ce128e18e3567be88edbb238810460d1288a86fabc20e483b30R448
- design doc: https://docs.google.com/document/d/1UDLhTNv6_kvuDTRlsjfj9WdqtNaQNr8ahrvdBIB6914/edit#
- retains_grad <> checkpiont https://docs.google.com/document/d/1maiGmuFUxysQL0AdYUU88kngAaXh_L0XpDcLDh_5Ors/edit
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90105
Approved by: https://github.com/albanD
Applies the remaining flake8-comprehension fixes and checks. This changes replace all remaining unnecessary generator expressions with list/dict/set comprehensions which are more succinct, performant, and better supported by our torch.jit compiler. It also removes useless generators such as 'set(a for a in b)`, resolving it into just the set call.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
This is a new version of #15648 based on the latest master branch.
Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.
In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)
Fixes https://github.com/pytorch/pytorch/issues/71105
@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
Towards fixing https://github.com/pytorch/pytorch/issues/82482
This PR fixes two things:
## 1) memory leak
The .detach() call prevents a true memory leak in some cases where the user function is using multiple ops in a row that save their inputs. The following chain of objects keep each other alive
- the `storage` object
- a recomputed Tensor y
- y's grad_fn FooBackward (in c++)
- FooBackward's SavedVariables (in c++)
- SavedVariable Hook
- the `inner_pack` function
- captures `storage`
Since part of this cycle is in c++, the python gc is not able to break it.
Should THPCppFunction_traverse actually visit it's SavedVariables which in turn should visit their hooks? I think the answer is yes but I haven't dived into which python object is traversing what as if there is non-unique ownership of the c++ object, it makes the traversal a lot trickier. @ezyang do you think we should dive into this more?
In this case, this can be easily solved anyways by storing `y.detach()` in the `storage` object as we don't care about the temporary backward graph that gets created during the second forward call.
## 2) Lifetime of the recomputed buffers
The new storage system is now such that the lifetime of the recomputed buffer is directly linked to the SavedVariable c++ object. Meaning that this buffer will get deleted IIF the SavedVariable is cleared.
This means that we now get the exact same behavior as the version without the saved variable hook where Tensors are saved directly on the SavedVariable object.
This is great as this solves all the cases where the non-checkpoint version used to work but the checkpoint version does not (even double access or retain_graph=True).
The one drawback of this approach though is that the buffer do NOT get cleared when the user passes in `retain_graph=True`! The next backward won't even re-run the forward as it already has all the buffers available. Is this a problem that you think we would need to find a solution for @rohan-varma or it is niche enough that we don't care for now?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82776
Approved by: https://github.com/ezyang, https://github.com/rohan-varma
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69508
Original Phabricator Diff: D32704467 (e032dae329)
Reland, fix is to not test traditional checkpoint when input does not require grad as that is unsupported as documented.
Original PR body:
Resubmission of https://github.com/pytorch/pytorch/pull/62964 with the
suggestions and tests discussed in
https://github.com/pytorch/pytorch/issues/65537.
Adds a `use_reentrant=False` flag to `checkpoint` function. When
`use_reentrant=True` is specified, a checkpointing implementation that uses
SavedVariableHooks instead of re-entrant autograd is used. This makes it more
composable with things such as `autograd.grad` as well as DDP (still need to
add thorough distributed testing).
As discussed in https://github.com/pytorch/pytorch/issues/65537, the tests that we need to add are:
- [x] Gradient hooks are called once
- [x] works when input does require grads but Tensor that require grads are captures (like first layer in a nn)
- [x] works for functions with arbitrary input/output objects
- [x] distributed tests (next PR)
Note that this is only for `torch.utils.checkpoint`, if this approach overall looks good, we will do something similar for `checkpoint_sequential`.
ghstack-source-id: 144948501
Test Plan: CI
Reviewed By: zhaojuanmao
Differential Revision: D32902634
fbshipit-source-id: 2ee87006e5045e5471ff80c36a07fbecc2bea3fe
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69027
Resubmission of https://github.com/pytorch/pytorch/pull/62964 withe
suggestions and tests discussed in
https://github.com/pytorch/pytorch/issues/65537.
Adds a `use_reentrant=False` flag to `checkpoint` function. When
`use_reentrant=True` is specified, a checkpointing implementation that uses
SavedVariableHooks instead of re-entrant autograd is used. This makes it more
composable with things such as `autograd.grad` as well as DDP (still need to
add thorough distributed testing).
As discussed in https://github.com/pytorch/pytorch/issues/65537, we have added
the following tests:
-[ ] Gradient hooks are called once
ghstack-source-id: 144644859
Test Plan: CI
Reviewed By: pbelevich
Differential Revision: D32704467
fbshipit-source-id: 6eea1cce6b935ef5a0f90b769e395120900e4412
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52422
As mentioned in https://github.com/pytorch/pytorch/issues/52415,
`torch.utils.checkpoint` doesn't support checkpointing for functions which have
non-tensor inputs and outputs.
This PR resolves this issue by ensuring the autograd machinery ignores the
non-tensor inputs and outputs and processes the tensors accordingly.
ghstack-source-id: 124406867
Test Plan:
1) unit test
2) waitforbuildbot
Reviewed By: albanD
Differential Revision: D26507228
fbshipit-source-id: 0a5a1591570814176185362e83ad18dabd9c84b0
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45934https://pytorch.org/docs/stable/checkpoint.html pytorch checkpoint requires all input to the function being checkpointed to requires_grad, but this assumption is not necessarily try. consider the following two examples
```
output = MultiheadedMaskedAtten(input, mask)
output = LSTM(input, seq_length)
```
both length and mask are tensors that won't requires grad, currently if you try to checkpoint torch.autograd.backward will complain
```
File "/mnt/xarfuse/uid-124297/7d159c34-seed-nspid4026531836-ns-4026531840/torch/autograd/function.py
", line 87, in apply
return self._forward_cls.backward(self, *args)
File "/mnt/xarfuse/uid-124297/7d159c34-seed-nspid4026531836-ns-4026531840/torch/utils/checkpoint.py"
, line 99, in backward
torch.autograd.backward(outputs, args)
File "/mnt/xarfuse/uid-124297/7d159c34-seed-nspid4026531836-ns-4026531840/torch/autograd/__init__.py
", line 132, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: element 1 of tensors does not require grad and does not have a grad_fn
```
this diff allows skipping the non-grad-requiring tensor when running autograd.backward.
added documentation for this feature as well.
Test Plan: added unit test to make sure partial tensor grads can be used in checkpoint().
Differential Revision: D24094764
fbshipit-source-id: 6557e8e74132d5a392526adc7b57b6998609ed12
Summary:
Fix typos in torch.utils/_benchmark/README.md
Add empty __init__.py to examples folder to make example invocations from README.md correct
Fixed uniform distribution logic generation when mixval and maxval are None
Fixes https://github.com/pytorch/pytorch/issues/42984
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42960
Reviewed By: seemethere
Differential Revision: D23095399
Pulled By: malfet
fbshipit-source-id: 0546ce7299b157d9a1f8634340024b10c4b7e7de
Summary:
See https://discuss.pytorch.org/t/training-with-gradient-checkpoints-torch-utils-checkpoint-appears-to-reduce-performance-of-model/78102/3?u=jwl for details.
Updated the docs to warn users about issues with checkpointing models that use `detach()` or `torch.no_grad()` to freeze their model layers/weights during training. When they do this, training with `checkpoint` will fail as it forces the outputs to require gradients when the model itself does not. Hence, during the backward pass it will output the error:
```
[4]<stderr>:RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```
Maybe it is possible to fix this directly in the code, but I am not sure how in the current codebase.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37266
Differential Revision: D21262558
Pulled By: mrshenli
fbshipit-source-id: 529cf370534504baf8937ef17dac5d6916fbf5ae
Summary:
To support variadic inputs of `checkpoint_sequential` was deprecated at https://github.com/pytorch/pytorch/issues/21006. This case should be warned with `DeprecationWarning` for PyTorch 1.2, but it should be simply failed with `TypeError` since PyTorch 1.3. This patch removes the `DeprecationWarning` for PyTorch 1.2.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25985
Differential Revision: D18809875
Pulled By: albanD
fbshipit-source-id: e84dd8629c04979c4b2dc63e8ada94292e8cedd0
Summary:
I've reported inconsistency between `checkpoint_sequential` and `nn.Sequential` at https://github.com/pytorch/pytorch/issues/19260. Both should provide the same input signature but they don't. I think the consistency is important and I agree with apaszke that `nn.Sequential`'s semantics should be kept instead of `checkpoint_sequential`.
I hope `checkpoint_sequential` raises `TypeError` on variadic arguments since PyTorch 1.2.0. But for now, it's okay just to warn as `DeprecationWarning`. I've talked about this approach with soumith.
Please review this pull request. Any comment will be my pleasure.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21006
Differential Revision: D15530801
Pulled By: soumith
fbshipit-source-id: 0ceb2cc6a17dcc547d0d00ebaf9df8603be53183
Summary:
Currently, we cannot run a checkpointed function with None argument.
```python
out = torch.utils.checkpoint.checkpoint(run_fn, input_var, None)
```
```
File "/home/tunz/anaconda3/envs/torchdev/lib/python3.7/site-packages/torch/utils/checkpoint.py", line 14, in detach_variable
x = inp.detach()
AttributeError: 'NoneType' object has no attribute 'detach'
```
This PR makes checkpoint function to safely handle None argument.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17969
Differential Revision: D14475148
Pulled By: ezyang
fbshipit-source-id: 9afe9e9aac511a6df1e1620e9ac341536890d451
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14278
In this commit, we make checkpoint_sequential work for models with multiple tensor inputs. Previously, it only processed the first tensor and ignored the rest.
We introduce a new test in test/test_utils.py that replicates the issue referenced in this [GitHub issue](https://github.com/pytorch/pytorch/issues/11093), and we make sure that the test passes by changing the behavior of checkpoint_sequential to process all input tensors.
Reviewed By: ezyang
Differential Revision: D13144672
fbshipit-source-id: 24f58233a65a0f5b80b89c8d8cbced6f814004f7
Summary:
This issue was noticed, and fix proposed, by raulpuric.
Checkpointing is implemented by rerunning a forward-pass segment for each checkpointed segment during backward. This can result in the RNG state advancing more than it would without checkpointing, which can cause checkpoints that include dropout invocations to lose end-to-end bitwise accuracy as compared to non-checkpointed passes.
The present PR contains optional logic to juggle the RNG states such that checkpointed passes containing dropout achieve bitwise accuracy with non-checkpointed equivalents.** The user requests this behavior by supplying `preserve_rng_state=True` to `torch.utils.checkpoint` or `torch.utils.checkpoint_sequential`.
Currently, `preserve_rng_state=True` may incur a moderate performance hit because restoring MTGP states can be expensive. However, restoring Philox states is dirt cheap, so syed-ahmed's [RNG refactor](https://github.com/pytorch/pytorch/pull/13070#discussion_r235179882), once merged, will make this option more or less free.
I'm a little wary of the [def checkpoint(function, *args, preserve_rng_state=False):](https://github.com/pytorch/pytorch/pull/14253/files#diff-58da227fc9b1d56752b7dfad90428fe0R75) argument-passing method (specifically, putting a kwarg after a variable argument list). Python 3 seems happy with it.
Edit: It appears Python 2.7 is NOT happy with a [kwarg after *args](https://travis-ci.org/pytorch/pytorch/builds/457706518?utm_source=github_status&utm_medium=notification). `preserve_rng_state` also needs to be communicated in a way that doesn't break any existing usage. I'm open to suggestions (a global flag perhaps)?
**Batchnorm may still be an issue, but that's a battle for another day.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14253
Differential Revision: D13166665
Pulled By: soumith
fbshipit-source-id: 240cddab57ceaccba038b0276151342344eeecd7