Commit Graph

108 Commits

Author SHA1 Message Date
Xuehai Pan
d40aaa42ee [BE][16/16] fix typos in torch/ (torch/utils/) (#156606)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156606
Approved by: https://github.com/albanD
ghstack dependencies: #156318, #156320, #156602, #156604
2025-07-02 22:55:29 +00:00
soulitzer
3580b8dde4 [BE] Mention debug=True in AC error messages (#155593)
See https://github.com/pytorch/pytorch/issues/155171#issuecomment-2949415407
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155593
Approved by: https://github.com/janeyx99
2025-06-11 00:32:41 +00:00
drisspg
80703ca332 [FlexAttention] Allow dispatch to SAC for flex (#150080)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150080
Approved by: https://github.com/zou3519
2025-06-05 04:34:27 +00:00
soulitzer
b040d63ce4 Prevent SAC cache from being kept alive by reference cycle (#154651)
Fixes https://github.com/pytorch/pytorch/issues/154642
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154651
Approved by: https://github.com/xmfan
2025-05-29 22:27:35 +00:00
Simon Fan
5aea57d653 [ca][dynamo] always run eager checkpoint region's recomputation in eager (#153300)
I slap disable on the recomputation hook, otherwise the partitioner may save less/more activations and mismatch with the expected eager count in checkpoint. See code comment `Note: [compiled autograd and checkpoint unpack hook]`.

This fixes all non-nested checkpointing tests. I also wrap nested checkpointing tests, and a few of them still fail.

This also seems to fix all PYTORCH_TEST_WITH_DYNAMO checkpointing tests except for `TestAutograd.test_checkpointing_without_reentrant_custom_function_works`. For those tests, it looks like we fail to HOPify the checkpointed region and when the backward executes the unpack hooks, dynamo tried to trace them. This messed up the internal state tracking of checkpointing, some raising the _StopRecomputationError and others raising the same count mismatch error as CA.

FIXES https://github.com/pytorch/pytorch/issues/127115

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153300
Approved by: https://github.com/jansel
2025-05-16 01:37:48 +00:00
PyTorch MergeBot
236b08cbf8 Revert "[ca][dynamo] always run eager checkpoint region's recomputation in eager (#153300)"
This reverts commit 4863e5c843.

Reverted https://github.com/pytorch/pytorch/pull/153300 on behalf of https://github.com/malfet due to Looks like it breaks rocm, see fa8543454a/1 ([comment](https://github.com/pytorch/pytorch/pull/153300#issuecomment-2884489459))
2025-05-15 16:58:52 +00:00
Simon Fan
4863e5c843 [ca][dynamo] always run eager checkpoint region's recomputation in eager (#153300)
I slap disable on the recomputation hook, otherwise the partitioner may save less/more activations and mismatch with the expected eager count in checkpoint. See code comment `Note: [compiled autograd and checkpoint unpack hook]`.

This fixes all non-nested checkpointing tests. I also wrap nested checkpointing tests, and a few of them still fail.

This also seems to fix all PYTORCH_TEST_WITH_DYNAMO checkpointing tests except for `TestAutograd.test_checkpointing_without_reentrant_custom_function_works`. For those tests, it looks like we fail to HOPify the checkpointed region and when the backward executes the unpack hooks, dynamo tried to trace them. This messed up the internal state tracking of checkpointing, some raising the _StopRecomputationError and others raising the same count mismatch error as CA.

FIXES https://github.com/pytorch/pytorch/issues/127115

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153300
Approved by: https://github.com/jansel
2025-05-15 08:10:35 +00:00
Brian Hirsh
5abe74857a SAC: fix recompute tag propagation for ops with list[tensor] inputs (#152195)
There's an "are we compiling" check in SAC, which we rely on to know when to propagate recompute tags during tracing.

This check was a bit brittle, and missed cases where input ops accept list of tensors - I updated it to check if a `FunctionalTensorMode` is active, which should be a 100% reliable way to know if AOTDispatcher is in the middle of running.

There is a long-standing followup here around unifying `torch.compiler.is_compiling()` to work in all cases. We should probably just update it to always check if FakeMode/FunctionalMode are active and use it there. This has a bit of BC risk though so I opted for the more local fix to SAC.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152195
Approved by: https://github.com/soulitzer
2025-05-05 17:21:00 +00:00
Aaron Orenstein
45ef3309e3 [BE] typing for decorators (#144161)
Summary:
Untyped decorators strip annotations from the decorated items.

- _compile
- _inductor/fx_passes/post_grad
- _inductor/lowering
- _library/custom_ops
- _meta_registrations
- _ops
- _refs/nn/functional
- ao/quantization/quantizer/xnnpack_quantizer_utils
- distributed/_composable/contract
- fx/experimental/graph_gradual_typechecker
- fx/experimental/migrate_gradual_types/constraint_generator
- optim/optimizer
- signal/windows/windows
- testing/_internal/common_device_type
- torch/_inductor/decomposition
- utils/flop_counter

Test Plan: unit tests

Differential Revision: D62302684

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144161
Approved by: https://github.com/Skylion007, https://github.com/albanD
2025-01-04 16:40:09 +00:00
Ramana Sundararaman
be4b7e8131 Param fixes in docstring (#136097)
Fixes wrong param names in docstrings. cc: @kit1980

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136097
Approved by: https://github.com/ezyang
2024-09-21 18:56:34 +00:00
soulitzer
a23dae22d5 Update AC pass use_reentrant message (#134472)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134472
Approved by: https://github.com/albanD
2024-08-26 21:57:38 +00:00
Li Yu (ads)
94155ce31b [Torch] Support meta device in checkpoint (#132684)
Summary:
## Why
utils.checkpoint doesn't support meta device:

```
  File "/Users/lyu1/torchdev/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 490, in checkpoint
    next(gen)
  File "/Users/lyu1/torchdev/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 1359, in _checkpoint_without_reentrant_generator
    device_module = _get_device_module(device)
  File "/Users/lyu1/torchdev/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 98, in _get_device_module
    device_module = getattr(torch, device)
  File "/Users/lyu1/torchdev/lib/python3.9/site-packages/torch/__init__.py", line 1938, in __getattr__
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
AttributeError: module 'torch' has no attribute 'meta'
```

This blocks us from running model with checkpoint enabled in meta mode.

## What
This diff handles the case of meta device in checkpoint.py.

(in checkpoint.py, device module is manily used when preserve_rng_state=true, which doesn't apply to meta case. So a more elgant fix might be set preserve_rng_state=false when detecting args are on meta device. But I didn't find where to do this check in the minimum way. Let me know if you have ideas.)

Test Plan: Tested with toy model which has checkpoint on its module: P1513716944

Differential Revision: D60749427

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132684
Approved by: https://github.com/kit1980
2024-08-06 20:45:50 +00:00
Aaron Orenstein
5a0068cc69 [BE] mypy: disallow untyped decorators (#131428)
Untyped decorators strip the types from their decorated function so even if the underlying function is fully typed then callers to it don't get any benefit from type annotations.

Step 1 - Enable the error and override in all the offending files.

#131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131428
Approved by: https://github.com/justinchuby, https://github.com/oulgen
2024-07-23 21:50:55 +00:00
Xuehai Pan
973037be6a [BE][Easy] apply autofix for ruff rules unnecessary-collection-call (C408): list() / tuple() / dict() (#130199)
This PR changes the empty collection factory call to Python literals:

- `list()` -> `[]`
- `tuple()` -> `()`
- `dict()` -> `{}`

The Python literals are more performant and safer. For example, the bytecode for building an empty dictionary:

```bash
$ python3 -m dis - <<EOS
import collections

d1 = {}
d2 = dict()

dict = collections.OrderedDict
d3 = dict()
EOS
```

```text
  0           0 RESUME                   0

  1           2 LOAD_CONST               0 (0)
              4 LOAD_CONST               1 (None)
              6 IMPORT_NAME              0 (collections)
              8 STORE_NAME               0 (collections)

  3          10 BUILD_MAP                0
             12 STORE_NAME               1 (d1)

  4          14 PUSH_NULL
             16 LOAD_NAME                2 (dict)
             18 CALL                     0
             26 STORE_NAME               3 (d2)

  6          28 LOAD_NAME                0 (collections)
             30 LOAD_ATTR                8 (OrderedDict)
             50 STORE_NAME               2 (dict)

  7          52 PUSH_NULL
             54 LOAD_NAME                2 (dict)
             56 CALL                     0
             64 STORE_NAME               5 (d3)
             66 RETURN_CONST             1 (None)
```

The dict literal `{}` only has one bytecode `BUILD_MAP`, while the factory call `dict()` has three `PUSH_NULL + LOAD_NAME + CALL`. Also, the factory call is not safe if users override the `dict` name in `locals` or `globals` (see the example of replacing with `OrderedDict` above).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130199
Approved by: https://github.com/malfet
2024-07-11 17:30:28 +00:00
soulitzer
eeef68671d [autograd] Do not detach when unpacking tensors that do not require grad (#127959)
In this PR:
- Ensure that if a tensor not requiring grad is saved for backward unpacking does not trigger a detach (unless the user installs a saved tensor pack hook that returns a tensor requiring grad).
- Update non-reentrant checkpoint to also no longer detach for this case.

Alternatives:
- For custom autograd Function, you could directly save on ctx to work around this, but that would not work for when we switch to using custom ops.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127959
Approved by: https://github.com/YuqingJ
ghstack dependencies: #125795, #128545, #129262
2024-07-01 21:57:36 +00:00
Dmitry Rogozhkin
321bdcb372 Fix device propagation for checkpointing (#128671)
Fixes: #128478

In backward() implementation checkpointing code was quering device type from the rng_state tensors saved on forward(). These tensors are CPU only tensors and don't carry device information with them. As a result CUDA device was assumed as a default. Which is not correct if user runs on some other device. For example, on XPU.

This patch saves full device information on forward() and uses it on backward() to get device type. Previously forward save only device index.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128671
Approved by: https://github.com/guangyey, https://github.com/soulitzer
2024-06-27 17:14:13 +00:00
Will Feng
575bc1e3af [Reopen #114036] Allow "must recompute" in torch.compile + selective checkpointing (SAC) (#129295)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129295
Approved by: https://github.com/Chillee
2024-06-25 23:47:08 +00:00
soulitzer
c89a9f5d17 Allow SAC policy_fn to return bool for backward compatibility (#129262)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129262
Approved by: https://github.com/Chillee, https://github.com/fmassa
ghstack dependencies: #125795, #128545
2024-06-24 13:54:30 +00:00
soulitzer
1877b7896c [checkpoint] Clean up selective activation checkpoint and make public (#125795)
### bc-breaking for existing users of the private API:
- Existing policy functions must now change their return value to be [CheckpointPolicy](c0b40ab42e/torch/utils/checkpoint.py (L1204-L1230))  Enum instead of bool.
   - To restore previous behavior, return `PREFER_RECOMPUTE` instead of `False` and `{PREFER,MUST}_SAVE` instead of `True` depending whether you prefer the compiler to override your policy.
- Policy function now accepts a `ctx` object instead of `mode` for its first argument.
   - To restore previous behavior, `mode = "recompute" if ctx.is_recompute else "forward"`.
- Existing calls to `_pt2_selective_checkpoint_context_fn_gen` must be renamed to `create_selective_checkpoint_contexts `. The way you use the API remains the same. It would've been nice to do something different (not make the user have to use functools.partial?), but this was the easiest to compile (idk if this should actually be a constraint).

Related doc: https://docs.google.com/document/d/1BKyizkZPdri9mHqdDOLAUpkI7SbbKfLHRFVVpK9ZWqo/edit

Memory considerations:
- As with the existing SAC, cached values are cleared upon first use.
- We error if the user wishes to backward a second time on a region forwarded with SAC enabled.

In-place:
- We use version counting to enforce that if any cached tensor has been mutated. In-place operations not mutating cached tensors are allowed.
- `allow_cache_entry_mutation=True` can be passed to disable this check (useful in the case of auto AC where the user is cleverly also saves the output of the in-place)

Randomness, views
- Currently in this PR, we don't do anything special for randomness or views, the author of the policy function is expected to handle them properly. (Would it would be beneficial to error? - we either want to save all or recompute all random tensors)

Tensor object preservation
- ~We guarantee that if a tensor does not requires grad, and it is saved, then what you get out is the same tensor object.~ UPDATE: We guarantee that if a tensor is of non-differentiable dtype AND it is not a view, and it is saved, then what you get out is the same tensor object. This is a nice guarantee for nested tensors which care about the object identity of of the offsets tensor.

Policy function
- Enum values are `{MUST,PREFER}_{SAVE,RECOMPUTE}` (bikeshed welcome). Alternatively there was `{SAVE,RECOMPUTE}_{NON_,}OVERRIDABLE`. The former was preferred bc it seemed clearer that two `MUST` clashing should error, versus it is ambiguous whether two `NON_OVERRIDABLE` being stacked should silently ignore or error.
- The usage of Enum today. There actually is NO API to stack SAC policies today. The only thing the Enum should matter for in the near term is the compiler. The stacking SAC policy would be useful if someone wants to implement something like simple FSDP, but it is not perfect because with a policy of `PREFER_SAVE` you are actually saving more than autograd would save normally (would be fixed with AC v3).
- The number of times we call the policy_fn is something that should be documented as part of public API. We call the policy function for all ops except ~~detach~~ UPDATE :  metadata ops listed in `torch.utils.checkpoint.SAC_IGNORED_OPS`) because these ops may be called a different number of times by AC itself between forward and recompute.
- The policy function can be a stateful object (we do NOT make separate copies of this object for forward/recompute, the user is expected to handle that via is_recompute see below).
Tensors guaranteed to be the same tensor as-is
- Policy function signature takes ctx object as its first argument. The ctx function is an object encapsulating info that may be useful to the user, it currently only holds "is_recompute". Adding this indirection gives us flexibility to add more attrs later if necessary.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125795
Approved by: https://github.com/Chillee, https://github.com/fmassa
2024-06-18 18:18:50 +00:00
PyTorch MergeBot
6895a5804c Revert "[checkpoint] Clean up selective activation checkpoint and make public (#125795)"
This reverts commit c472cec565.

Reverted https://github.com/pytorch/pytorch/pull/125795 on behalf of https://github.com/soulitzer due to breaking torchtitan CI ([comment](https://github.com/pytorch/pytorch/pull/125795#issuecomment-2167036157))
2024-06-14 01:14:59 +00:00
soulitzer
c472cec565 [checkpoint] Clean up selective activation checkpoint and make public (#125795)
Related doc: https://docs.google.com/document/d/1BKyizkZPdri9mHqdDOLAUpkI7SbbKfLHRFVVpK9ZWqo/edit

Memory considerations:
- As with the existing SAC, cached values are cleared upon first use.
- We error if the user wishes to backward a second time on a region forwarded with SAC enabled.

In-place:
- We use version counting to enforce that if any cached tensor has been mutated. In-place operations not mutating cached tensors are allowed.
- `allow_cache_entry_mutation=True` can be passed to disable this check (useful in the case of auto AC where the user is cleverly also saves the output of the in-place)

Randomness, views
- Currently in this PR, we don't do anything special for randomness or views, the author of the policy function is expected to handle them properly. (Would it would be beneficial to error? - we either want to save all or recompute all random tensors)

Tensor object preservation
- We guarantee that if a tensor does not requires grad, and it is saved, then what you get out is the same tensor object. If the tensor does require grad, we must detach to avoid creating a reference cycle. This is a nice guarantee for nested tensors which care about the object identity of of the offsets tensor.

Policy function
- Enum values are `{MUST,PREFER}_{SAVE,RECOMPUTE}` (bikeshed welcome). Alternatively there was `{SAVE,RECOMPUTE}_{NON_,}OVERRIDABLE`. The former was preferred bc it seemed clearer that two `MUST` clashing should error, versus it is ambiguous whether two `NON_OVERRIDABLE` being stacked should silently ignore or error.
- The usage of Enum today. There actually is NO API to stack SAC policies today. The only thing the Enum should matter for in the near term is the compiler. The stacking SAC policy would be useful if someone wants to implement something like simple FSDP, but it is not perfect because with a policy of `PREFER_SAVE` you are actually saving more than autograd would save normally (would be fixed with AC v3).
- The number of times we call the policy_fn is something documented part of public API. We call the policy function for all ops except detach because detach is itself called a different number of times by AC between forward and recompute.
- The policy function can be a stateful object (we do NOT make separate copies of this object for forward/recompute, the user is expected to handle that via is_recompute see below).
Tensors guaranteed to be the same tensor as-is
- Policy function signature takes ctx object as its first argument. The ctx function is an object encapsulating info that may be useful to the user, it currently only holds "is_recompute". Adding this indirection gives us flexibility to add more attrs later if necessary.

"bc-breaking" for existing users of the private API:
- Existing policy functions must now change their return value to use the Enum.
- Existing calls to `_pt2_selective_checkpoint_context_fn_gen` must be renamed to `gen_selective_checkpoint_context_fn`. The way you use the API remains the same. It would've been nice to do something different (not make the user have to use functools.partial?), but this was the easiest to compile (idk if this should actually be a constraint).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125795
Approved by: https://github.com/Chillee, https://github.com/fmassa
2024-06-12 23:57:33 +00:00
Andrew Gu
8c1247cffb [BE] Fixed CPU autocast warning (#127774)
This PR fixes
```
/data/users/andgu/pytorch/torch/utils/checkpoint.py:1398: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127774
Approved by: https://github.com/soulitzer, https://github.com/Skylion007, https://github.com/tianyu-l
2024-06-11 21:33:35 +00:00
Aaron Orenstein
8db9dfa2d7 Flip default value for mypy disallow_untyped_defs [9/11] (#127846)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127846
Approved by: https://github.com/ezyang
ghstack dependencies: #127842, #127843, #127844, #127845
2024-06-08 18:50:06 +00:00
Wanchao Liang
d937d0db0f [SAC] fix ignored ops in eager mode to recompute (#126751)
as titled. I found that there're some issues in the eager mode SAC where
sometimes we would have recompute pop from storage of ops that are
missing, these ops are detach ops. So this PR refactors the two modes,
so that they would always recompute ignored ops
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126751
Approved by: https://github.com/yf225
2024-05-22 06:47:22 +00:00
Yu, Guangye
d17be10df1 make torch.amp.autocast more generic (#125103)
# Motivation
As discussed in [#124479](https://github.com/pytorch/pytorch/pull/124479), `torch.amp.autocast` can NOT be completely equivalent to `torch.cuda.amp.autocast` and `torch.cpu.amp.autocast` since `torch.amp.autocast` has NOT the default `dtype` for CPU (`torch.bfloat16` by default) and CUDA (`torch.float16` by default) respectively. We would like `torch.amp.autocast` to be more generic to help the developer/customer write the device-agnostic code. Because there are not enough reasons to add device-specific autocast `torch.xxx.amp.autocast` for each device backend.

# Solution
When `None` is passed to `dtype`, we should use `torch.get_autocast_dtype` to get the related dtype for each backend. Meanwhile, `torch.get_autocast_dtype` is necessary to be supported in JIT path for BC.

# Additional Context
With this PR, `torch.amp.autocast(device_type='cuda')` is equivalent to `torch.cuda.amp.autocast`.
Add two new UTs to cover this change in eager and jit path respectively.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125103
Approved by: https://github.com/albanD, https://github.com/jgong5, https://github.com/gujinghui
2024-05-08 12:13:26 +00:00
soulitzer
fab5bd5359 [checkpoint] Improve error message when use_reentrant=True is used with .grad() (#125155)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125155
Approved by: https://github.com/albanD
2024-04-29 18:57:35 +00:00
Yu, Guangye
19a83eacb5 add new API torch.amp.is_autocast_available (#124938)
# Motivation
expose `torch._is_autocast_available` to `torch.amp.is_autocast_available` as a public api.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124938
Approved by: https://github.com/albanD
2024-04-26 08:45:20 +00:00
Yu, Guangye
cdc66e9dc3 refactor autocast python APIs (#124479)
# Motivation
Refactor autocast usage scenario in `torch/amp/autocast_mode.py` and `torch/utils/checkpoint.py` to fix the bug - convention conflict between `torch.xxx.get_autocast_xxx_dtype` defined in `autocast_mode.py` and `torch.xxx.get_autocast_dtype` defined in `checkpoint.py`.

# Solution
Use device-agnostic APIs like `torch.get_autocast_dtype`, ..., instead.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124479
Approved by: https://github.com/jgong5, https://github.com/gujinghui, https://github.com/EikanWang, https://github.com/albanD
ghstack dependencies: #124359
2024-04-25 14:33:33 +00:00
Yu, Guangye
25f321b84f Refactor autocast C++ APIs to be device-agnostic (#124359)
# Motivation
This PR aims to refactor autocast **C++** APIs to be device-agnostic and deprecate the device-specific autocast  **C++** APIs.
In C++ side,
- `is_enabled()` -> `is_enabled(device_type)`.
- `set_enabled(new_enabled)` -> `set_enabled(device_type, new_enabled)`.
- `get_autocast_dtype()` -> `get_autocast_dtype(device_type)`
- `set_autocast_dtype(dtype)` -> `set_autocast_dtype(device_type, dtype)`

These following C++ APIs are deprecated and should be removed in PyTorch 2.5
- `is_cpu_enabled`
- `set_cpu_enabled`
- `get_autocast_cpu_dtype`
- `set_autocast_cpu_dtype`
- `is_xpu_enabled`
- `set_xpu_enabled`
- `get_autocast_xpu_dtype`
- `set_autocast_xpu_dtype`
- `is_ipu_enabled`
- `set_ipu_enabled`
- `get_autocast_ipu_dtype`
- `set_autocast_ipu_dtype`
- `is_hpu_enabled`
- `set_hpu_enabled`
- `get_autocast_hpu_dtype`
- `set_autocast_hpu_dtype`
- `is_xla_enabled`
- `set_xla_enabled`
- `get_autocast_xla_dtype`
- `set_autocast_xla_dtype`
- `is_privateuseone_enabled`
- `set_privateuseone_enabled`
- `get_autocast_privateuseone_dtype`
- `set_autocast_privateuseone_dtype`

In Python side,
provide 4 generic autocast APIs:
- `torch.is_autocast_enabled(device_type)`
- `torch.set_autocast_enabled(device_type, new_enabled)`
- `torch.get_autocast_dtype(device_type)`
- `torch.set_autocast_dtype(device_type, dtype)`

# Additional Context
We will submit another PR to refactor autocast **Python** APIs based on this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124359
Approved by: https://github.com/jgong5, https://github.com/albanD
2024-04-23 10:38:50 +00:00
Xuehai Pan
93e249969b [BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261)
Remove useless parentheses in `raise` statements if the exception type is raised with no argument.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124261
Approved by: https://github.com/albanD
2024-04-17 19:29:34 +00:00
Edward Z. Yang
26a9b05bce Set stacklevel on checkpoint warning (#123717)
Partially addresses https://github.com/pytorch/pytorch/issues/123626

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123717
Approved by: https://github.com/Skylion007
2024-04-10 17:25:06 +00:00
Andrew Gu
1d6fc0d4de Fixed _infer_device_type warning in checkpoint (#122726)
Previously, we were checking `len(device_types)` where `device_types` is a `list`. This meant that if there were multiple inputs, we would see something like `device_types = ["cuda", "cuda"]` and a false positive warning. We should check `len(set(device_types))`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122726
Approved by: https://github.com/soulitzer
2024-03-27 18:38:42 +00:00
Pritam Damania
512251c8f3 Use tree_map to get device ids and device types for activation checkpointing (#121462)
`get_device_states` doesn't recursively look into nested lists/dicts to find tensors. As a result, activation checkpointing for such inputs results in silent incorrect results as `get_device_states` returns an empty result and no rng is saved as a result here: https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L188 since `fwd_device_states` is empty.

Fixed this by using `tree_map` for both `get_device_states` and `_infer_device_type`. Also added appropriate unit tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121462
Approved by: https://github.com/soulitzer
2024-03-20 21:09:21 +00:00
Catherine Lee
4f5785b6b3 Enable possibly-undefined error code (#118533)
Fixes https://github.com/pytorch/pytorch/issues/118129

Suppressions automatically added with

```
import re

with open("error_file.txt", "r") as f:
    errors = f.readlines()

error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Co-authored-by: Catherine Lee <csl@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2024-01-30 21:07:01 +00:00
PyTorch MergeBot
40ece2e579 Revert "Enable possibly-undefined error code (#118533)"
This reverts commit 4f13f69a45.

Reverted https://github.com/pytorch/pytorch/pull/118533 on behalf of https://github.com/clee2000 due to sorry i'm trying to figure out a codev merge conflict, if this works i'll be back to rebase and merge ([comment](https://github.com/pytorch/pytorch/pull/118533#issuecomment-1917695185))
2024-01-30 19:00:34 +00:00
Edward Z. Yang
4f13f69a45 Enable possibly-undefined error code (#118533)
Fixes https://github.com/pytorch/pytorch/issues/118129

Suppressions automatically added with

```
import re

with open("error_file.txt", "r") as f:
    errors = f.readlines()

error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2024-01-30 05:08:10 +00:00
soulitzer
5866284d4a Make not passing use_reentrant back to warning instead of erroring and clarify docs (#116710)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116710
Approved by: https://github.com/albanD
ghstack dependencies: #116523
2024-01-09 20:58:49 +00:00
Ghassene Jebali
e728ebb66d Small docstring fix (#116947)
Fix a small typo in the docstring of checkpoint function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116947
Approved by: https://github.com/Skylion007, https://github.com/kit1980
2024-01-08 23:51:59 +00:00
soulitzer
4d6a1ad400 Activation checkpoint and checkpoint_sequential errors if use_reentrant not passed explicitly (#115868)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115868
Approved by: https://github.com/albanD
ghstack dependencies: #115438
2023-12-20 15:23:44 +00:00
Wanchao Liang
dd367b7c8f check tensor subclass when using torch.compile + SAC (#115960)
as titled, when using SAC + torch.compile, it currently only check for
functional tensor, but not checking any tensor subclasses, therefore SAC
under torch.compile would ignore the tensor types like tensor
subclasses. Fixed in this PR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115960
Approved by: https://github.com/bdhirsh
2023-12-18 17:49:06 +00:00
Will Feng
495054545c Allow preserve_rng_state=True when torch.compile + selective checkpointing + CUDA (#113718)
Fixes https://github.com/pytorch/pytorch/issues/113717.

When `preserve_rng_state=True`, we let AOTAutograd trace through `torch.random.fork_rng` op, and the tracing doesn't work under CUDA, hence the original error reported in the issue.

But since we are already doing RNG functionalization at Inductor level, we don't actually need to trace this `fork_rng` op. So we should just rewrite `preserve_rng_state` to False when we are using torch.compile (and let Inductor do its RNG functionalization which it's already been doing).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113718
Approved by: https://github.com/wanchaol
2023-12-09 01:47:25 +00:00
soulitzer
a7bcc78bff Make it clearer that current selective AC is PT2-only and private (#115081)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115081
Approved by: https://github.com/albanD
2023-12-04 23:01:22 +00:00
soulitzer
c1d9d4a2b5 checkpoint_sequential warns if use_reentrant not passed explicitly (#114158)
Use warning text for deprecation message.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114158
Approved by: https://github.com/albanD
2023-11-20 23:08:44 +00:00
ChanBong
5e10dd2c78 fix docstring issues in torch.utils (#113335)
Fixes #112634

Fixes all the issues listed except in `torch/utils/_pytree.py` as the file no longer exists.

### Error counts

|File | Count Before | Count now|
|---- | ---- | ---- |
|`torch/utils/collect_env.py` | 39 | 25|
|`torch/utils/cpp_extension.py` | 51 | 13|
|`torch/utils/flop_counter.py` | 25 | 8|
|`torch/utils/_foreach_utils.py.py` | 2 | 0|
|`torch/utils/_python_dispatch.py.py` | 26 | 25|
|`torch/utils/backend_registration.py` | 15 | 4|
|`torch/utils/checkpoint.py` | 29 | 21|

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113335
Approved by: https://github.com/ezyang
2023-11-13 19:37:25 +00:00
Brian Hirsh
7064fbf1ea Fix selective activation checkpointing with subclasses that override sizes() (#113380)
The problem is that we have a subclass (FunctionalTensor) that overrides size/stride calls, causing them to go through __torch_dispatch__.

But when SAC is active, we have _CachingTorchDispatchMode.__torch_dispatch__ active, that intercepts those size/stride calls first, and does something different with them instead of letting FunctionalTensor.__torch_dispatch__ handle them.

This PR updates the SAC torch dispatch mode to know to not handle metadata calls, and let its tensor arguments handle them directly.

Right now, `FunctionalTensor` has a hardcoded list of metadata ops, but we should probably put them somewhere more general.

I'll add better testing before landing this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113380
Approved by: https://github.com/yf225, https://github.com/wanchaol
2023-11-10 04:12:50 +00:00
soulitzer
c9eb8d8d90 Add set_checkpoint_debug_enabled that overrides local setting (#110728)
People access activation checkpoint through many layers of config and it is not always guaranteed that all the layers of wrapping around checkpoint properly propagate all the kwargs, e.g. debug mode. This context manager offers an alternative way to enable debug mode that bypasses the need for all layers to propagate kwargs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110728
Approved by: https://github.com/albanD
ghstack dependencies: #110673, #110674, #110675, #110676
2023-10-11 02:12:31 +00:00
Brian Hirsh
b457e3f79a Reland attempt 2 of "Update AOTAutograd to use FunctionalTensorMode instead of C++ functionalization (#106406)" (#109906)" (#110079)
The first reland broke internal (failing diff: D49617462).

The major error looks like it's because there's an internal-only higher order op that needs a new functionalization rule. I'm going to land an internal diff for that and confirm tests pass before relanding this PR.

Also confirmed that the issue from https://github.com/pytorch/pytorch/issues/110121 is fixed, and added a test.

This reverts commit 1b90f07f5a.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110079
Approved by: https://github.com/ezyang
2023-10-03 18:50:25 +00:00
PyTorch MergeBot
1b90f07f5a Revert "Reland "Update AOTAutograd to use FunctionalTensorMode instead of C++ functionalization (#106406)" (#109906)"
This reverts commit d0fe8fa5db.

Reverted https://github.com/pytorch/pytorch/pull/109906 on behalf of https://github.com/atalman due to Breaks internal tests ([comment](https://github.com/pytorch/pytorch/pull/109906#issuecomment-1735416852))
2023-09-26 12:10:25 +00:00
Brian Hirsh
d0fe8fa5db Reland "Update AOTAutograd to use FunctionalTensorMode instead of C++ functionalization (#106406)" (#109906)
I'm pretty sure this is fixed but I'll run inductor and trunk CI. The failing test in trunk previously was that the selective activation checkpointing code that landed recently assumes that it can detect whether or not AOTAutograd is running by seeing if the inputs to SAC are C++ `FunctionalTensorWrapper`s

previous land broke some inductor trunk tests

This reverts commit 629a628cc8.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109906
Approved by: https://github.com/ezyang
2023-09-25 14:53:54 +00:00
Will Feng
3f3e353885 torch.compile + selective activation checkpointing (#105489)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105489

NOTE: this PR is tagged "not user facing", because it's not ready to be announced externally yet.

This PR implements torch.compile + selective activation checkpoint (SAC) integration, by using `TagActivationCheckpoint` (same backend as torch.compile + full activation checkpoint integration).

TorchDispatchMode based implementation cannot support including inplace ops in the checkpointed region at the moment (the reason for this needs investigation), and there is also no way to ban them (because TorchDispatchMode now only sees "after-functionalization" ops, so can't detect if an op is in-place). Hence we hide torch.compile + SAC behind a flag (`torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint`) and will only use it internally for cases that are known to not have in-place ops. This state won't last too long, because in-place op will at least be able to be detected after Brian's mode reordering and related functionalization changes.
So next steps after this PR:
1. Wait for Brian's mode reordering and related functionalization changes to land, and then try to enable the "inplace ops" unit test for torch.compile + selective activation checkpoint (if it doesn't work, investigate why).
2. Unify selective- and full-checkpoint under TorchDispatchMode based implementation.

Differential Revision: D47497145

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105489
Approved by: https://github.com/anijain2305
2023-09-21 16:24:11 +00:00