as titled, when using SAC + torch.compile, it currently only check for
functional tensor, but not checking any tensor subclasses, therefore SAC
under torch.compile would ignore the tensor types like tensor
subclasses. Fixed in this PR
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115960
Approved by: https://github.com/bdhirsh
Fixes https://github.com/pytorch/pytorch/issues/113717.
When `preserve_rng_state=True`, we let AOTAutograd trace through `torch.random.fork_rng` op, and the tracing doesn't work under CUDA, hence the original error reported in the issue.
But since we are already doing RNG functionalization at Inductor level, we don't actually need to trace this `fork_rng` op. So we should just rewrite `preserve_rng_state` to False when we are using torch.compile (and let Inductor do its RNG functionalization which it's already been doing).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113718
Approved by: https://github.com/wanchaol
The problem is that we have a subclass (FunctionalTensor) that overrides size/stride calls, causing them to go through __torch_dispatch__.
But when SAC is active, we have _CachingTorchDispatchMode.__torch_dispatch__ active, that intercepts those size/stride calls first, and does something different with them instead of letting FunctionalTensor.__torch_dispatch__ handle them.
This PR updates the SAC torch dispatch mode to know to not handle metadata calls, and let its tensor arguments handle them directly.
Right now, `FunctionalTensor` has a hardcoded list of metadata ops, but we should probably put them somewhere more general.
I'll add better testing before landing this PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113380
Approved by: https://github.com/yf225, https://github.com/wanchaol
People access activation checkpoint through many layers of config and it is not always guaranteed that all the layers of wrapping around checkpoint properly propagate all the kwargs, e.g. debug mode. This context manager offers an alternative way to enable debug mode that bypasses the need for all layers to propagate kwargs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110728
Approved by: https://github.com/albanD
ghstack dependencies: #110673, #110674, #110675, #110676
The first reland broke internal (failing diff: D49617462).
The major error looks like it's because there's an internal-only higher order op that needs a new functionalization rule. I'm going to land an internal diff for that and confirm tests pass before relanding this PR.
Also confirmed that the issue from https://github.com/pytorch/pytorch/issues/110121 is fixed, and added a test.
This reverts commit 1b90f07f5a.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110079
Approved by: https://github.com/ezyang
I'm pretty sure this is fixed but I'll run inductor and trunk CI. The failing test in trunk previously was that the selective activation checkpointing code that landed recently assumes that it can detect whether or not AOTAutograd is running by seeing if the inputs to SAC are C++ `FunctionalTensorWrapper`s
previous land broke some inductor trunk tests
This reverts commit 629a628cc8.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109906
Approved by: https://github.com/ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105489
NOTE: this PR is tagged "not user facing", because it's not ready to be announced externally yet.
This PR implements torch.compile + selective activation checkpoint (SAC) integration, by using `TagActivationCheckpoint` (same backend as torch.compile + full activation checkpoint integration).
TorchDispatchMode based implementation cannot support including inplace ops in the checkpointed region at the moment (the reason for this needs investigation), and there is also no way to ban them (because TorchDispatchMode now only sees "after-functionalization" ops, so can't detect if an op is in-place). Hence we hide torch.compile + SAC behind a flag (`torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint`) and will only use it internally for cases that are known to not have in-place ops. This state won't last too long, because in-place op will at least be able to be detected after Brian's mode reordering and related functionalization changes.
So next steps after this PR:
1. Wait for Brian's mode reordering and related functionalization changes to land, and then try to enable the "inplace ops" unit test for torch.compile + selective activation checkpoint (if it doesn't work, investigate why).
2. Unify selective- and full-checkpoint under TorchDispatchMode based implementation.
Differential Revision: D47497145
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105489
Approved by: https://github.com/anijain2305
This PR makes some improvements for debuggability of checkpointing:
- improved error messages that are more understandable
- errors are now `CheckpointError` which subclasses `RuntimeError` (only `CheckpointError` triggers debug message, see below)
- stricter error checking by default:
- shapes, dtypes, and device are compared
- we also now error when more tensors are being saved for backward during recompute
- NOTE: checks are relaxed if it is detected that you are doing backward within forward
- shapes, dtype, and device checking can be disabled by passing `determinism_check="none"`
- new debug flag: more helpful error message when `debug=True`
Note:
- cpp stack trace is only included for x86 linux machines
- the error message if cpp stack trace is included can be quite long. For a function checkpointed with 8 operators, the log was around 1300 lines! (should this be hidden behind a flag?)
[Error message when debug='True' (python stack trace only)](https://gist.github.com/soulitzer/3d5e19c7cceae8e22f9bdd625ec39dd4)
[Error message when debug='True' (with python and cpp stacktrace)](https://gist.github.com/soulitzer/ff8fd8c3ccbb2c90dfe3df6d7713b167)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103859
Approved by: https://github.com/albanD
Fixes #ISSUE_NUMBER
1、add checkpoint support for custom device
2、add a device argument, I want to add a device="cuda" parameter to the func `forward` of `CheckpointFunction`, and I can specify the device type when using it, but the func `apply` of `torch.autograd.Function` does not support `kwargs`, so I added a variable named `_device`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99626
Approved by: https://github.com/soulitzer
Now that we have updated all internal callsites, per https://fb.workplace.com/groups/pytorch.oss.dev/permalink/1635183750239493/ we should raise a warning when use_reentrant is not explicitly passed for 2.1
Deprecation note:
- Not passing in use_reentrant explicitly is now deprecated and will raise a warning. In the future the default value of use-reentrant will be False. To preserve the existing behavior you can pass in use_reentrant=True. It is recommended that you use use_reentrant=False.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100551
Approved by: https://github.com/Skylion007
Why did I choose context manager instead of per-call? Early stopping is not part of the model definition, and depending on how a particular model is used, e.g., with PT2 or not we may or may not want to disable early stopping.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96866
Approved by: https://github.com/albanD
Updates:
- ~recommend user to use non-reentrant, mention that reentrant will be deprecated in the future~
- merges all the warnings into a single list of non-reentrant improvements over reentrant
- adds an additional entry to the list about allowing backward inside checkpointed region
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96862
Approved by: https://github.com/albanD
Changes:
- bc-breaking change: The main difference between this and the old non-reentrant impl that it replaces is that we clear recomputed tensors on backward immediately upon unpack, even if retain_graph=True. This has the following additional implications:
- Accessing _saved_tensors multiple times will silently recompute forward multiple times.
- Accessing ctx.saved_tensor twice in the same backward will now raise an error.
- To avoid dealing with the potential consequences, early stopping has been hidden behind a global flag that is by default False, and can be enabled via a context manager. We can remove this in a follow up. Some features of nesting as a result do not work by default.
Before land:
- import to check for more bc-breakingness
- implement any workarounds for the bc-breaking-ness, if we decide on any
- update docs to reflect new lifetime of recomputed variables
- update docs to mention the early stop feature
Follow ups:
- enable early-stopping by default
- update docs/tutorial to feature nested use cases
Related docs:
- code comment: https://github.com/pytorch/pytorch/pull/90105/files#diff-9dcd955620b52ce128e18e3567be88edbb238810460d1288a86fabc20e483b30R448
- design doc: https://docs.google.com/document/d/1UDLhTNv6_kvuDTRlsjfj9WdqtNaQNr8ahrvdBIB6914/edit#
- retains_grad <> checkpiont https://docs.google.com/document/d/1maiGmuFUxysQL0AdYUU88kngAaXh_L0XpDcLDh_5Ors/edit
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90105
Approved by: https://github.com/albanD
Applies the remaining flake8-comprehension fixes and checks. This changes replace all remaining unnecessary generator expressions with list/dict/set comprehensions which are more succinct, performant, and better supported by our torch.jit compiler. It also removes useless generators such as 'set(a for a in b)`, resolving it into just the set call.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
This is a new version of #15648 based on the latest master branch.
Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.
In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)
Fixes https://github.com/pytorch/pytorch/issues/71105
@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
Towards fixing https://github.com/pytorch/pytorch/issues/82482
This PR fixes two things:
## 1) memory leak
The .detach() call prevents a true memory leak in some cases where the user function is using multiple ops in a row that save their inputs. The following chain of objects keep each other alive
- the `storage` object
- a recomputed Tensor y
- y's grad_fn FooBackward (in c++)
- FooBackward's SavedVariables (in c++)
- SavedVariable Hook
- the `inner_pack` function
- captures `storage`
Since part of this cycle is in c++, the python gc is not able to break it.
Should THPCppFunction_traverse actually visit it's SavedVariables which in turn should visit their hooks? I think the answer is yes but I haven't dived into which python object is traversing what as if there is non-unique ownership of the c++ object, it makes the traversal a lot trickier. @ezyang do you think we should dive into this more?
In this case, this can be easily solved anyways by storing `y.detach()` in the `storage` object as we don't care about the temporary backward graph that gets created during the second forward call.
## 2) Lifetime of the recomputed buffers
The new storage system is now such that the lifetime of the recomputed buffer is directly linked to the SavedVariable c++ object. Meaning that this buffer will get deleted IIF the SavedVariable is cleared.
This means that we now get the exact same behavior as the version without the saved variable hook where Tensors are saved directly on the SavedVariable object.
This is great as this solves all the cases where the non-checkpoint version used to work but the checkpoint version does not (even double access or retain_graph=True).
The one drawback of this approach though is that the buffer do NOT get cleared when the user passes in `retain_graph=True`! The next backward won't even re-run the forward as it already has all the buffers available. Is this a problem that you think we would need to find a solution for @rohan-varma or it is niche enough that we don't care for now?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82776
Approved by: https://github.com/ezyang, https://github.com/rohan-varma
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69508
Original Phabricator Diff: D32704467 (e032dae329)
Reland, fix is to not test traditional checkpoint when input does not require grad as that is unsupported as documented.
Original PR body:
Resubmission of https://github.com/pytorch/pytorch/pull/62964 with the
suggestions and tests discussed in
https://github.com/pytorch/pytorch/issues/65537.
Adds a `use_reentrant=False` flag to `checkpoint` function. When
`use_reentrant=True` is specified, a checkpointing implementation that uses
SavedVariableHooks instead of re-entrant autograd is used. This makes it more
composable with things such as `autograd.grad` as well as DDP (still need to
add thorough distributed testing).
As discussed in https://github.com/pytorch/pytorch/issues/65537, the tests that we need to add are:
- [x] Gradient hooks are called once
- [x] works when input does require grads but Tensor that require grads are captures (like first layer in a nn)
- [x] works for functions with arbitrary input/output objects
- [x] distributed tests (next PR)
Note that this is only for `torch.utils.checkpoint`, if this approach overall looks good, we will do something similar for `checkpoint_sequential`.
ghstack-source-id: 144948501
Test Plan: CI
Reviewed By: zhaojuanmao
Differential Revision: D32902634
fbshipit-source-id: 2ee87006e5045e5471ff80c36a07fbecc2bea3fe
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69027
Resubmission of https://github.com/pytorch/pytorch/pull/62964 withe
suggestions and tests discussed in
https://github.com/pytorch/pytorch/issues/65537.
Adds a `use_reentrant=False` flag to `checkpoint` function. When
`use_reentrant=True` is specified, a checkpointing implementation that uses
SavedVariableHooks instead of re-entrant autograd is used. This makes it more
composable with things such as `autograd.grad` as well as DDP (still need to
add thorough distributed testing).
As discussed in https://github.com/pytorch/pytorch/issues/65537, we have added
the following tests:
-[ ] Gradient hooks are called once
ghstack-source-id: 144644859
Test Plan: CI
Reviewed By: pbelevich
Differential Revision: D32704467
fbshipit-source-id: 6eea1cce6b935ef5a0f90b769e395120900e4412