Commit Graph

81 Commits

Author SHA1 Message Date
Yu, Guangye
cdc66e9dc3 refactor autocast python APIs (#124479)
# Motivation
Refactor autocast usage in `torch/amp/autocast_mode.py` and `torch/utils/checkpoint.py` to fix a bug: a naming-convention conflict between `torch.xxx.get_autocast_xxx_dtype` defined in `autocast_mode.py` and `torch.xxx.get_autocast_dtype` defined in `checkpoint.py`.

# Solution
Use device-agnostic APIs such as `torch.get_autocast_dtype`, etc., instead; a brief usage sketch follows.
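
A minimal sketch of the device-agnostic pattern, assuming the generic APIs introduced in the companion PR (#124359); the device string and the surrounding flow are illustrative, not the exact code in `checkpoint.py`:

```
import torch

# Illustrative only: query autocast state through the generic, device-keyed
# APIs instead of per-backend helpers.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
autocast_enabled = torch.is_autocast_enabled(device_type)
autocast_dtype = torch.get_autocast_dtype(device_type)
print(device_type, autocast_enabled, autocast_dtype)
```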

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124479
Approved by: https://github.com/jgong5, https://github.com/gujinghui, https://github.com/EikanWang, https://github.com/albanD
ghstack dependencies: #124359
2024-04-25 14:33:33 +00:00
Yu, Guangye
25f321b84f Refactor autocast C++ APIs to be device-agnostic (#124359)
# Motivation
This PR aims to refactor the autocast **C++** APIs to be device-agnostic and to deprecate the device-specific autocast **C++** APIs.
On the C++ side:
- `is_enabled()` -> `is_enabled(device_type)`
- `set_enabled(new_enabled)` -> `set_enabled(device_type, new_enabled)`
- `get_autocast_dtype()` -> `get_autocast_dtype(device_type)`
- `set_autocast_dtype(dtype)` -> `set_autocast_dtype(device_type, dtype)`

The following C++ APIs are deprecated and should be removed in PyTorch 2.5:
- `is_cpu_enabled`
- `set_cpu_enabled`
- `get_autocast_cpu_dtype`
- `set_autocast_cpu_dtype`
- `is_xpu_enabled`
- `set_xpu_enabled`
- `get_autocast_xpu_dtype`
- `set_autocast_xpu_dtype`
- `is_ipu_enabled`
- `set_ipu_enabled`
- `get_autocast_ipu_dtype`
- `set_autocast_ipu_dtype`
- `is_hpu_enabled`
- `set_hpu_enabled`
- `get_autocast_hpu_dtype`
- `set_autocast_hpu_dtype`
- `is_xla_enabled`
- `set_xla_enabled`
- `get_autocast_xla_dtype`
- `set_autocast_xla_dtype`
- `is_privateuseone_enabled`
- `set_privateuseone_enabled`
- `get_autocast_privateuseone_dtype`
- `set_autocast_privateuseone_dtype`

On the Python side,
this PR provides four generic autocast APIs (a usage sketch follows the list):
- `torch.is_autocast_enabled(device_type)`
- `torch.set_autocast_enabled(device_type, new_enabled)`
- `torch.get_autocast_dtype(device_type)`
- `torch.set_autocast_dtype(device_type, dtype)`
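
A hedged round-trip sketch of these four APIs; the device type string and dtype values are illustrative:

```
import torch

device_type = "cpu"  # any supported device type string, e.g. "cuda", "xpu"

# Save the current autocast state for this device type.
prev_enabled = torch.is_autocast_enabled(device_type)
prev_dtype = torch.get_autocast_dtype(device_type)

# Temporarily change it, then restore the saved state.
torch.set_autocast_enabled(device_type, True)
torch.set_autocast_dtype(device_type, torch.bfloat16)

torch.set_autocast_enabled(device_type, prev_enabled)
torch.set_autocast_dtype(device_type, prev_dtype)
```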

# Additional Context
We will submit another PR to refactor autocast **Python** APIs based on this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124359
Approved by: https://github.com/jgong5, https://github.com/albanD
2024-04-23 10:38:50 +00:00
Xuehai Pan
93e249969b [BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261)
Remove useless parentheses in `raise` statements if the exception type is raised with no argument.
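
A minimal illustration of the pattern this rule family (flake8-raise, as enforced by ruff) flags; the function names are made up:

```
def check_before(x):
    if x < 0:
        raise ValueError()  # flagged: redundant parentheses, no arguments passed


def check_after(x):
    if x < 0:
        raise ValueError  # equivalent and preferred when there are no arguments
```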

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124261
Approved by: https://github.com/albanD
2024-04-17 19:29:34 +00:00
Edward Z. Yang
26a9b05bce Set stacklevel on checkpoint warning (#123717)
Partially addresses https://github.com/pytorch/pytorch/issues/123626
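
A hedged sketch of the general technique (not the exact call site touched in this PR): setting `stacklevel` so the warning points at the user's code rather than the library helper.

```
import warnings


def deprecated_checkpoint_helper():
    # stacklevel=2 attributes the warning to the caller of this helper
    # instead of this line; the message text is illustrative.
    warnings.warn(
        "please pass use_reentrant explicitly",
        UserWarning,
        stacklevel=2,
    )


deprecated_checkpoint_helper()
```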

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123717
Approved by: https://github.com/Skylion007
2024-04-10 17:25:06 +00:00
Andrew Gu
1d6fc0d4de Fixed _infer_device_type warning in checkpoint (#122726)
Previously, we were checking `len(device_types)` where `device_types` is a `list`. This meant that if there were multiple inputs, we would see something like `device_types = ["cuda", "cuda"]` and a false positive warning. We should check `len(set(device_types))`.
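
An illustrative before/after of the check:

```
device_types = ["cuda", "cuda"]  # two CUDA inputs -> duplicates in the list

assert len(device_types) == 2        # old check: > 1, so it warned spuriously
assert len(set(device_types)) == 1   # new check: one device type, no warning
```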
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122726
Approved by: https://github.com/soulitzer
2024-03-27 18:38:42 +00:00
Pritam Damania
512251c8f3 Use tree_map to get device ids and device types for activation checkpointing (#121462)
`get_device_states` doesn't recursively look into nested lists/dicts to find tensors. As a result, activation checkpointing with such inputs silently produces incorrect results: `get_device_states` returns an empty result, so no RNG state is saved here: https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L188 (because `fwd_device_states` is empty).

Fixed this by using `tree_map` for both `get_device_states` and `_infer_device_type`. Also added appropriate unit tests.
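
A hedged sketch of the approach, assuming `torch.utils._pytree.tree_map` is used to visit tensors nested inside lists/dicts; the helper name is illustrative, not the exact code in `checkpoint.py`:

```
import torch
from torch.utils._pytree import tree_map


def collect_device_types(*args):
    device_types = []

    def visit(x):
        if isinstance(x, torch.Tensor):
            device_types.append(x.device.type)
        return x

    tree_map(visit, args)
    return set(device_types)


# Nested containers are now traversed, so the tensors inside are not missed.
print(collect_device_types({"a": torch.randn(2)}, [torch.randn(2), 3]))  # {'cpu'}
```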
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121462
Approved by: https://github.com/soulitzer
2024-03-20 21:09:21 +00:00
Catherine Lee
4f5785b6b3 Enable possibly-undefined error code (#118533)
Fixes https://github.com/pytorch/pytorch/issues/118129

Suppressions automatically added with

```
import re

with open("error_file.txt", "r") as f:
    errors = f.readlines()

error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Co-authored-by: Catherine Lee <csl@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2024-01-30 21:07:01 +00:00
PyTorch MergeBot
40ece2e579 Revert "Enable possibly-undefined error code (#118533)"
This reverts commit 4f13f69a45.

Reverted https://github.com/pytorch/pytorch/pull/118533 on behalf of https://github.com/clee2000 due to sorry i'm trying to figure out a codev merge conflict, if this works i'll be back to rebase and merge ([comment](https://github.com/pytorch/pytorch/pull/118533#issuecomment-1917695185))
2024-01-30 19:00:34 +00:00
Edward Z. Yang
4f13f69a45 Enable possibly-undefined error code (#118533)
Fixes https://github.com/pytorch/pytorch/issues/118129

Suppressions automatically added with

```
import re

with open("error_file.txt", "r") as f:
    errors = f.readlines()

error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2024-01-30 05:08:10 +00:00
soulitzer
5866284d4a Make not passing use_reentrant back to warning instead of erroring and clarify docs (#116710)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116710
Approved by: https://github.com/albanD
ghstack dependencies: #116523
2024-01-09 20:58:49 +00:00
Ghassene Jebali
e728ebb66d Small docstring fix (#116947)
Fix a small typo in the docstring of the checkpoint function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116947
Approved by: https://github.com/Skylion007, https://github.com/kit1980
2024-01-08 23:51:59 +00:00
soulitzer
4d6a1ad400 Activation checkpoint and checkpoint_sequential errors if use_reentrant not passed explicitly (#115868)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115868
Approved by: https://github.com/albanD
ghstack dependencies: #115438
2023-12-20 15:23:44 +00:00
Wanchao Liang
dd367b7c8f check tensor subclass when using torch.compile + SAC (#115960)
As titled: when using SAC + torch.compile, the check currently only looks for
functional tensors, not for tensor subclasses in general, so SAC under
torch.compile would ignore tensor types such as tensor subclasses. Fixed in
this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115960
Approved by: https://github.com/bdhirsh
2023-12-18 17:49:06 +00:00
Will Feng
495054545c Allow preserve_rng_state=True when torch.compile + selective checkpointing + CUDA (#113718)
Fixes https://github.com/pytorch/pytorch/issues/113717.

When `preserve_rng_state=True`, we let AOTAutograd trace through the `torch.random.fork_rng` op, and that tracing doesn't work under CUDA, hence the original error reported in the issue.

But since we are already doing RNG functionalization at the Inductor level, we don't actually need to trace this `fork_rng` op. So we should just rewrite `preserve_rng_state` to False when we are using torch.compile (and let Inductor do the RNG functionalization it has already been doing).
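
A usage-level sketch of what the fix enables, assuming a CUDA build; the checkpointed function is illustrative and this is a sketch rather than the test added in the PR:

```
import torch
from torch.utils.checkpoint import checkpoint


def block(x):
    return torch.nn.functional.dropout(torch.sin(x), p=0.1, training=True)


@torch.compile
def fn(x):
    # preserve_rng_state=True no longer errors under torch.compile; Inductor's
    # RNG functionalization handles RNG state instead of tracing fork_rng.
    return checkpoint(block, x, use_reentrant=False, preserve_rng_state=True)


if torch.cuda.is_available():
    out = fn(torch.randn(8, device="cuda", requires_grad=True))
    out.sum().backward()
```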

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113718
Approved by: https://github.com/wanchaol
2023-12-09 01:47:25 +00:00
soulitzer
a7bcc78bff Make it clearer that current selective AC is PT2-only and private (#115081)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115081
Approved by: https://github.com/albanD
2023-12-04 23:01:22 +00:00
soulitzer
c1d9d4a2b5 checkpoint_sequential warns if use_reentrant not passed explicitly (#114158)
Use warning text for deprecation message.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114158
Approved by: https://github.com/albanD
2023-11-20 23:08:44 +00:00
ChanBong
5e10dd2c78 fix docstring issues in torch.utils (#113335)
Fixes #112634

Fixes all the issues listed except those in `torch/utils/_pytree.py`, as the file no longer exists.

### Error counts

| File | Count Before | Count After |
| ---- | ---- | ---- |
| `torch/utils/collect_env.py` | 39 | 25 |
| `torch/utils/cpp_extension.py` | 51 | 13 |
| `torch/utils/flop_counter.py` | 25 | 8 |
| `torch/utils/_foreach_utils.py` | 2 | 0 |
| `torch/utils/_python_dispatch.py` | 26 | 25 |
| `torch/utils/backend_registration.py` | 15 | 4 |
| `torch/utils/checkpoint.py` | 29 | 21 |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113335
Approved by: https://github.com/ezyang
2023-11-13 19:37:25 +00:00
Brian Hirsh
7064fbf1ea Fix selective activation checkpointing with subclasses that override sizes() (#113380)
The problem is that we have a subclass (`FunctionalTensor`) that overrides size/stride calls, causing them to go through `__torch_dispatch__`.

But when SAC is active, we have `_CachingTorchDispatchMode.__torch_dispatch__` active, which intercepts those size/stride calls first and does something different with them instead of letting `FunctionalTensor.__torch_dispatch__` handle them.

This PR updates the SAC torch dispatch mode to know to not handle metadata calls, and let its tensor arguments handle them directly.

Right now, `FunctionalTensor` has a hardcoded list of metadata ops, but we should probably put them somewhere more general.

I'll add better testing before landing this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113380
Approved by: https://github.com/yf225, https://github.com/wanchaol
2023-11-10 04:12:50 +00:00
soulitzer
c9eb8d8d90 Add set_checkpoint_debug_enabled that overrides local setting (#110728)
People access activation checkpointing through many layers of configuration, and it is not always guaranteed that all the layers wrapping checkpoint properly propagate every kwarg, e.g. debug mode. This context manager offers an alternative way to enable debug mode that bypasses the need for all layers to propagate kwargs.
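
A usage sketch, assuming the context manager is exposed as `torch.utils.checkpoint.set_checkpoint_debug_enabled`; the checkpointed function is illustrative:

```
import torch
from torch.utils.checkpoint import checkpoint, set_checkpoint_debug_enabled

# Force debug mode for every checkpoint() call in this block, regardless of
# what kwargs the intermediate wrappers forward.
with set_checkpoint_debug_enabled(True):
    x = torch.randn(4, requires_grad=True)
    out = checkpoint(torch.sin, x, use_reentrant=False)
    out.sum().backward()
```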
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110728
Approved by: https://github.com/albanD
ghstack dependencies: #110673, #110674, #110675, #110676
2023-10-11 02:12:31 +00:00
Brian Hirsh
b457e3f79a Reland attempt 2 of "Update AOTAutograd to use FunctionalTensorMode instead of C++ functionalization (#106406)" (#109906)" (#110079)
The first reland broke internal (failing diff: D49617462).

The major error looks like it's because there's an internal-only higher order op that needs a new functionalization rule. I'm going to land an internal diff for that and confirm tests pass before relanding this PR.

Also confirmed that the issue from https://github.com/pytorch/pytorch/issues/110121 is fixed, and added a test.

This reverts commit 1b90f07f5a.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110079
Approved by: https://github.com/ezyang
2023-10-03 18:50:25 +00:00
PyTorch MergeBot
1b90f07f5a Revert "Reland "Update AOTAutograd to use FunctionalTensorMode instead of C++ functionalization (#106406)" (#109906)"
This reverts commit d0fe8fa5db.

Reverted https://github.com/pytorch/pytorch/pull/109906 on behalf of https://github.com/atalman due to Breaks internal tests ([comment](https://github.com/pytorch/pytorch/pull/109906#issuecomment-1735416852))
2023-09-26 12:10:25 +00:00
Brian Hirsh
d0fe8fa5db Reland "Update AOTAutograd to use FunctionalTensorMode instead of C++ functionalization (#106406)" (#109906)
I'm pretty sure this is fixed, but I'll run inductor and trunk CI. The previously failing trunk test was due to the recently landed selective activation checkpointing code assuming it can detect whether or not AOTAutograd is running by checking whether the inputs to SAC are C++ `FunctionalTensorWrapper`s.

The previous land broke some inductor trunk tests.

This reverts commit 629a628cc8.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109906
Approved by: https://github.com/ezyang
2023-09-25 14:53:54 +00:00
Will Feng
3f3e353885 torch.compile + selective activation checkpointing (#105489)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105489

NOTE: this PR is tagged "not user facing", because it's not ready to be announced externally yet.

This PR implements torch.compile + selective activation checkpoint (SAC) integration, by using `TagActivationCheckpoint` (same backend as torch.compile + full activation checkpoint integration).

The TorchDispatchMode-based implementation cannot support in-place ops in the checkpointed region at the moment (the reason for this needs investigation), and there is also no way to ban them (because TorchDispatchMode now only sees "after-functionalization" ops, so it can't detect whether an op is in-place). Hence we hide torch.compile + SAC behind a flag (`torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint`) and will only use it internally for cases that are known to not have in-place ops. This state won't last too long, because in-place ops will at least be detectable after Brian's mode reordering and related functionalization changes.
So next steps after this PR:
1. Wait for Brian's mode reordering and related functionalization changes to land, and then try to enable the "inplace ops" unit test for torch.compile + selective activation checkpoint (if it doesn't work, investigate why).
2. Unify selective- and full-checkpoint under TorchDispatchMode based implementation.

Differential Revision: D47497145

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105489
Approved by: https://github.com/anijain2305
2023-09-21 16:24:11 +00:00
soulitzer
884c03d240 Improve activation checkpoint docs wording (#107296)
This helps eliminate some confusion around "intermediates" and whether module outputs are handled as well. See this internal post https://fb.workplace.com/groups/1405155842844877/permalink/7327505913943144/
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107296
Approved by: https://github.com/albanD
2023-08-16 17:36:52 +00:00
Rohan Varma
5d70fe0165 [Composable] Use non-reentrant generator, remove reentrant (#105176)
Removes reentrant support from the composable checkpoint, as
non-reentrant is the recommended approach and is what we should use when rolling
out the composable checkpoint API.

Also removes the standalone non-reentrant implementation and instead uses
the generator from the diff below to reuse the original implementation.

Differential Revision: [D47451375](https://our.internmc.facebook.com/intern/diff/D47451375/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105176
Approved by: https://github.com/awgu, https://github.com/fegin
2023-07-26 07:03:03 +00:00
Justin Chu
abc1cadddb [BE] Enable ruff's UP rules and autoformat utils/ (#105424)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105424
Approved by: https://github.com/ezyang, https://github.com/malfet
2023-07-18 20:17:25 +00:00
Rohan Varma
b7b44e766b [Checkpoint] Separate implementation into generator (#105101)
Separates the non-reentrant AC implementation into a generator so that
other APIs, such as the composable checkpoint API, can use the generator as
pre- and post-forward logic.

Differential Revision: [D47419387](https://our.internmc.facebook.com/intern/diff/D47419387/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105101
Approved by: https://github.com/soulitzer
2023-07-14 06:27:13 +00:00
soulitzer
91dcc3b272 Fix activation checkpoint for mps (#104787)
Fixes https://github.com/pytorch/pytorch/issues/104478

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104787
Approved by: https://github.com/albanD
2023-07-08 14:57:05 +00:00
Animesh Jain
8c191d8eef [dynamo][ac] Reland #104397 - Remove disable monkeypatching of utils.checkpoint (#104665)
NO CHANGE from before. The ancestor diff was reverted, so this diff got reverted as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104665
Approved by: https://github.com/wconstab
2023-07-06 00:48:02 +00:00
PyTorch MergeBot
40f53912cf Revert "[dynamo][ac] Remove disable monkeypatching of utils.checkpoint (#104397)"
This reverts commit 537a6c0651.

Reverted https://github.com/pytorch/pytorch/pull/104397 on behalf of https://github.com/huydhn due to This has been reverted internally by D47216591, so I need to also revert it on OSS to keep them in sync ([comment](https://github.com/pytorch/pytorch/pull/104397#issuecomment-1621086360))
2023-07-05 06:11:08 +00:00
Animesh Jain
537a6c0651 [dynamo][ac] Remove disable monkeypatching of utils.checkpoint (#104397)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104397
Approved by: https://github.com/wconstab
2023-06-30 02:27:06 +00:00
soulitzer
73c927f901 Improve debuggability of activation checkpoint (#103859)
This PR makes some improvements to the debuggability of checkpointing (a usage sketch follows the list):
- improved error messages that are more understandable
- errors are now `CheckpointError`, which subclasses `RuntimeError` (only `CheckpointError` triggers the debug message, see below)
- stricter error checking by default:
   - shapes, dtypes, and devices are compared
   - we also now error when more tensors are being saved for backward during recompute
   - NOTE: checks are relaxed if it is detected that you are doing backward within forward
- shape, dtype, and device checking can be disabled by passing `determinism_check="none"`
- new debug flag: more helpful error message when `debug=True`
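
A usage-level sketch of the new knobs; the checkpointed function is illustrative:

```
import torch
from torch.utils.checkpoint import checkpoint

x = torch.randn(4, requires_grad=True)
out = checkpoint(
    torch.cos,
    x,
    use_reentrant=False,
    determinism_check="none",  # skip the shape/dtype/device comparison
    debug=True,                # richer error message if recompute diverges
)
out.sum().backward()
```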

Note:
- cpp stack trace is only included for x86 linux machines
- the error message if cpp stack trace is included can be quite long. For a function checkpointed with 8 operators, the log was around 1300 lines! (should this be hidden behind a flag?)

[Error message when debug='True' (python stack trace only)](https://gist.github.com/soulitzer/3d5e19c7cceae8e22f9bdd625ec39dd4)

[Error message when debug='True' (with python and cpp stacktrace)](https://gist.github.com/soulitzer/ff8fd8c3ccbb2c90dfe3df6d7713b167)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103859
Approved by: https://github.com/albanD
2023-06-22 03:57:36 +00:00
Rohan Varma
60547fcbee Autoformat torch/utils/checkpoint (#101649)
Per title

Differential Revision: [D45933467](https://our.internmc.facebook.com/intern/diff/D45933467/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101649
Approved by: https://github.com/Skylion007, https://github.com/soulitzer
2023-05-18 21:55:05 +00:00
soulitzer
70ef0bb45a Fix checkpoint doc small formatting issue (#101419)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101419
Approved by: https://github.com/albanD
2023-05-15 21:33:56 +00:00
soulitzer
98f6b815b7 [BE] Make some simplifications to torch.utils.checkpoint logic (#101193)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101193
Approved by: https://github.com/albanD
2023-05-12 04:35:22 +00:00
shibo
6aeb85add8 add checkpoint support for custom device (#99626)
Fixes #ISSUE_NUMBER
1. Add checkpoint support for custom devices.
2. Add a device argument. I wanted to add a `device="cuda"` parameter to the `forward` function of `CheckpointFunction` so that the device type can be specified when using it, but the `apply` function of `torch.autograd.Function` does not support `kwargs`, so I added a variable named `_device` instead.
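
A hedged toy sketch of the workaround, not the actual `CheckpointFunction`: since `Function.apply` only accepts positional arguments, the device type is threaded through as an extra positional argument. All names below are illustrative.

```
import torch


class ToyCheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, _device, *args):
        # _device is passed positionally because apply() rejects kwargs.
        ctx.run_function = run_function
        ctx._device = _device
        ctx.save_for_backward(*args)
        with torch.no_grad():
            return run_function(*args)

    @staticmethod
    def backward(ctx, *grad_outputs):
        inputs = ctx.saved_tensors
        detached = [t.detach().requires_grad_(t.requires_grad) for t in inputs]
        with torch.enable_grad():
            outputs = ctx.run_function(*detached)
        torch.autograd.backward(outputs, grad_outputs)
        return (None, None) + tuple(t.grad for t in detached)


x = torch.randn(3, requires_grad=True)
y = ToyCheckpointFunction.apply(torch.sin, "cpu", x)  # device type passed positionally
y.sum().backward()
```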
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99626
Approved by: https://github.com/soulitzer
2023-05-04 00:23:42 +00:00
soulitzer
e552b91286 torch.utils.checkpoint warns if user does not pass use_reentrant explicitly (#100551)
Now that we have updated all internal callsites, per https://fb.workplace.com/groups/pytorch.oss.dev/permalink/1635183750239493/ we should raise a warning for 2.1 when use_reentrant is not explicitly passed.

Deprecation note:
- Not passing use_reentrant explicitly is now deprecated and will raise a warning. In the future the default value of use_reentrant will be False. To preserve the existing behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False.
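
A minimal usage sketch of passing the argument explicitly; the checkpointed function is illustrative:

```
import torch
from torch.utils.checkpoint import checkpoint

x = torch.randn(4, requires_grad=True)

# Explicit use_reentrant avoids the deprecation warning; False is recommended.
out = checkpoint(torch.tanh, x, use_reentrant=False)
out.sum().backward()
```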
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100551
Approved by: https://github.com/Skylion007
2023-05-03 20:48:07 +00:00
Kazuaki Ishizaki
622a11d512 Fix typos under torch/utils directory (#97516)
This PR fixes typos in comments and messages of `.py` files under the `torch/utils` directory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97516
Approved by: https://github.com/ezyang
2023-03-24 16:53:39 +00:00
soulitzer
7a8b691388 Make early stop the default for checkpoint and expose a way to disable (#96866)
Why did I choose a context manager instead of a per-call argument? Early stopping is not part of the model definition, and depending on how a particular model is used (e.g., with PT2 or not), we may or may not want to disable early stopping.
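
A usage sketch, assuming the opt-out is the `torch.utils.checkpoint.set_checkpoint_early_stop` context manager; the checkpointed function is illustrative:

```
import torch
from torch.utils.checkpoint import checkpoint, set_checkpoint_early_stop

# Early stopping of the recompute is on by default; disable it for the
# checkpoint() calls made inside this block.
with set_checkpoint_early_stop(False):
    x = torch.randn(4, requires_grad=True)
    out = checkpoint(torch.sigmoid, x, use_reentrant=False)

out.sum().backward()
```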
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96866
Approved by: https://github.com/albanD
2023-03-22 20:03:56 +00:00
soulitzer
89d116d961 [BE][docs]Improve and update checkpoint documentation (#96862)
Updates:
- ~recommend user to use non-reentrant, mention that reentrant will be deprecated in the future~
- merges all the warnings into a single list of non-reentrant improvements over reentrant
- adds an additional entry to the list about allowing backward inside checkpointed region

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96862
Approved by: https://github.com/albanD
2023-03-22 16:53:29 +00:00
soulitzer
f3db2a6341 Expose API to specify custom context manager for checkpoint (#96783)
Per [design](https://docs.google.com/document/d/1v-yqRqiWA6dIUOw5OpqFs2PqSQIbDEkwRPGk9FcYnxg/edit) we want (1) to allow the user to pass in a function that returns two context managers, (2) a per-call API only for now, and (3) not to upstream selective checkpoint for the short term.
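
A minimal sketch of the hook: `context_fn` returns two context managers, one entered around the original forward and one around the recomputation. The no-op managers here are purely illustrative.

```
import contextlib

import torch
from torch.utils.checkpoint import checkpoint


def context_fn():
    # (forward context, recompute context)
    return contextlib.nullcontext(), contextlib.nullcontext()


x = torch.randn(4, requires_grad=True)
out = checkpoint(torch.sin, x, use_reentrant=False, context_fn=context_fn)
out.sum().backward()
```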
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96783
Approved by: https://github.com/albanD
2023-03-15 20:37:33 +00:00
soulitzer
d30db9a251 Replace non-reentrant checkpoint with a rewrite that can be nested and contain grad (#90105)
Changes:
- bc-breaking change: The main difference between this and the old non-reentrant impl that it replaces is that we clear recomputed tensors on backward immediately upon unpack, even if retain_graph=True. This has the following additional implications:
   - Accessing _saved_tensors multiple times will silently recompute forward multiple times.
   - Accessing `ctx.saved_tensors` twice in the same backward will now raise an error.
- To avoid dealing with the potential consequences, early stopping has been hidden behind a global flag that is False by default and can be enabled via a context manager. We can remove this in a follow-up. As a result, some features of nesting do not work by default.

Before land:
- import to check for more bc-breakingness
- implement any workarounds for the bc-breaking-ness, if we decide on any
- update docs to reflect new lifetime of recomputed variables
- update docs to mention the early stop feature

Follow ups:
- enable early-stopping by default
- update docs/tutorial to feature nested use cases

Related docs:
  - code comment: https://github.com/pytorch/pytorch/pull/90105/files#diff-9dcd955620b52ce128e18e3567be88edbb238810460d1288a86fabc20e483b30R448
  - design doc: https://docs.google.com/document/d/1UDLhTNv6_kvuDTRlsjfj9WdqtNaQNr8ahrvdBIB6914/edit#
  - retains_grad <> checkpoint https://docs.google.com/document/d/1maiGmuFUxysQL0AdYUU88kngAaXh_L0XpDcLDh_5Ors/edit

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90105
Approved by: https://github.com/albanD
2023-03-14 20:38:36 +00:00
Aaron Gokaslan
67d9790985 [BE] Apply almost all remaining flake8-comprehension checks (#94676)
Applies the remaining flake8-comprehension fixes and checks. This change replaces all remaining unnecessary generator expressions with list/dict/set comprehensions, which are more succinct, performant, and better supported by our torch.jit compiler. It also removes useless generators such as `set(a for a in b)`, resolving them into just the `set()` call.
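
A small illustration of the two patterns being cleaned up; the data is made up:

```
items = ["a", "b", "a"]

# Useless generator inside a call -> just call the constructor directly.
unique_before = set(x for x in items)
unique_after = set(items)

# Generator passed to list() -> list comprehension.
upper_before = list(x.upper() for x in items)
upper_after = [x.upper() for x in items]

assert unique_before == unique_after and upper_before == upper_after
```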

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
2023-02-12 01:01:25 +00:00
Rohan Varma
d93b1b9c4e Address feedback from previous PR (#86622)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86622
Approved by: https://github.com/albanD
2022-10-10 18:53:41 +00:00
Rohan Varma
7a411952fb CheckpointSequential support non-reentrant (#86331)
Closes https://github.com/pytorch/pytorch/issues/86328

Adds `use_reentrant` argument to `checkpoint_sequential`.
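
A usage sketch of the new argument; the model and segment count are illustrative:

```
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU())
x = torch.randn(2, 8, requires_grad=True)

# Checkpoint the sequential model in 2 segments with the non-reentrant impl.
out = checkpoint_sequential(model, 2, x, use_reentrant=False)
out.sum().backward()
```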

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86331
Approved by: https://github.com/zhaojuanmao, https://github.com/albanD
2022-10-06 23:10:18 +00:00
joncrall
4618371da5 Integrate xdoctest - Rebased (#82797)
This is a new version of #15648 based on the latest master branch.

Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.

In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)
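
A tiny illustration of the skip directive mentioned above, placed inside a docstring; the function and the skipped call are made up:

```
def flaky_example():
    """Docstring whose example is excluded from the doctest run.

    >>> # xdoctest: +SKIP
    >>> run_expensive_gpu_call()  # illustrative, not a real API
    """
```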

Fixes https://github.com/pytorch/pytorch/issues/71105

@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
2022-08-12 02:08:01 +00:00
albanD
7dd795cbed Prevent ref cycle creation in inner hook (#82776)
Towards fixing https://github.com/pytorch/pytorch/issues/82482

This PR fixes two things:

## 1) memory leak
The .detach() call prevents a true memory leak in some cases where the user function uses multiple ops in a row that save their inputs. The following chain of objects keeps each other alive:
- the `storage` object
- a recomputed Tensor y
- y's grad_fn FooBackward (in c++)
- FooBackward's SavedVariables (in c++)
- SavedVariable Hook
- the `inner_pack` function, which captures `storage`

Since part of this cycle is in c++, the python gc is not able to break it.
Should THPCppFunction_traverse actually visit its SavedVariables, which in turn should visit their hooks? I think the answer is yes, but I haven't dived into which python object is traversing what; if there is non-unique ownership of the c++ object, it makes the traversal a lot trickier. @ezyang do you think we should dive into this more?

In this case, this can be easily solved anyway by storing `y.detach()` in the `storage` object, as we don't care about the temporary backward graph that gets created during the second forward call.
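
A heavily simplified sketch of that idea using `torch.autograd.graph.saved_tensors_hooks`; all names are illustrative, not the actual `checkpoint.py` internals. The pack hook stores a `detach()`'d tensor, so the stored value no longer points back (through its grad_fn and that graph's saved-variable hooks) to the storage that holds it.

```
import torch

storage = []


def inner_pack(t):
    storage.append(t.detach())  # detach() drops grad_fn, breaking the ref cycle
    return len(storage) - 1


def inner_unpack(idx):
    return storage[idx]


with torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
    x = torch.randn(3, requires_grad=True)
    y = torch.sin(x * 2)

y.sum().backward()  # unpack returns the detached values; gradients still flow
```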

## 2) Lifetime of the recomputed buffers
The new storage system is such that the lifetime of the recomputed buffer is directly linked to the SavedVariable c++ object, meaning that this buffer gets deleted if and only if the SavedVariable is cleared.
This means that we now get exactly the same behavior as the version without the saved variable hook, where Tensors are saved directly on the SavedVariable object.

This is great as this solves all the cases where the non-checkpoint version used to work but the checkpoint version does not (even double access or retain_graph=True).

The one drawback of this approach, though, is that the buffers do NOT get cleared when the user passes in `retain_graph=True`! The next backward won't even re-run the forward, as it already has all the buffers available. Is this a problem that you think we would need to find a solution for, @rohan-varma, or is it niche enough that we don't care for now?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82776
Approved by: https://github.com/ezyang, https://github.com/rohan-varma
2022-08-06 00:31:22 +00:00
ProGamerGov
8def154e00 Fix multiple docstring type mistakes (#82474)
### Description

* Docstrings using `(tuple of ints)` shows up as `(tuple of python:ints)`, so I fixed them by making the `int` no longer plural. Example: https://pytorch.org/docs/stable/generated/torch.permute.html#torch.permute
* A docstring type in JIT had one of its types incorrectly highlighted as code. Example: https://pytorch.org/docs/stable/generated/torch.jit.script.html#torch.jit.script
* I found some docstring type usages of `string` that had not yet been converted to `str` after #82410
* Some docstrings incorrectly listed their defaults inside the docstring types.
* I also found a docstring that was missing its type

### Testing
No testing should be required.

---

In the developer guidelines, there should probably be standards listed for the docstring types.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82474
Approved by: https://github.com/albanD
2022-07-29 17:45:37 +00:00
Rohan Varma
98cad3d305 [Checkpoint] Fix autocasting (#81766)
Add support for correct autocasting in the non-reentrant checkpoint, matching the behavior of the reentrant version.

This was noticed by @awgu.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81766
Approved by: https://github.com/albanD
2022-07-22 21:33:56 +00:00
Rohan Varma
e14941ef79 Add kwarg support for no_reentrant checkpoint (#80987)
Supports kwargs input to the wrapped function when using `torch.utils.checkpoint` with use_reentrant=False. This is required to unblock T5 activation checkpointing and MetaSeq use cases.
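
A usage sketch of the new capability; the wrapped function and its keyword argument are illustrative:

```
import torch
from torch.utils.checkpoint import checkpoint


def block(x, *, scale=1.0):
    return torch.relu(x) * scale


x = torch.randn(4, requires_grad=True)

# Keyword arguments are forwarded to the wrapped function when use_reentrant=False.
out = checkpoint(block, x, scale=0.5, use_reentrant=False)
out.sum().backward()
```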

Closes https://github.com/pytorch/pytorch/issues/79887
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80987
Approved by: https://github.com/zhaojuanmao
2022-07-09 05:07:13 +00:00