pytorch/docs/source
soulitzer 1877b7896c [checkpoint] Clean up selective activation checkpoint and make public (#125795)
### BC-breaking for existing users of the private API:
- Existing policy functions must now return a [CheckpointPolicy](c0b40ab42e/torch/utils/checkpoint.py (L1204-L1230)) enum value instead of a bool.
   - To restore the previous behavior, return `PREFER_RECOMPUTE` instead of `False`, and `{PREFER,MUST}_SAVE` instead of `True`, depending on whether you want the compiler to be able to override your policy.
- The policy function now receives a `ctx` object instead of `mode` as its first argument.
   - To restore the previous behavior, use `mode = "recompute" if ctx.is_recompute else "forward"`.
- Existing calls to `_pt2_selective_checkpoint_context_fn_gen` must be renamed to `create_selective_checkpoint_contexts`. The way you use the API remains the same (see the migration sketch after this list). It would have been nice to do something different (e.g. not make the user go through `functools.partial`), but this was the easiest to compile (it is debatable whether that should actually be a constraint).
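
A minimal migration sketch of the renamed API. The toy function, tensor shapes, and the choice of `torch.ops.aten.mm.default` as an op to save are purely illustrative, not part of this PR:

```python
import functools
import torch
from torch.utils.checkpoint import (
    checkpoint,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)

# Hypothetical choice of ops whose outputs we keep instead of recomputing.
ops_to_save = [torch.ops.aten.mm.default]

def policy_fn(ctx, op, *args, **kwargs):
    # Old private API: return a bool; new public API: return a CheckpointPolicy.
    if op in ops_to_save:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def fn(x):
    # Toy region to checkpoint: the mm outputs are saved, sigmoid is recomputed.
    return torch.sigmoid(torch.mm(torch.mm(x, x), x))

x = torch.randn(4, 4, requires_grad=True)
# create_selective_checkpoint_contexts replaces _pt2_selective_checkpoint_context_fn_gen;
# as before, it is passed to checkpoint() through functools.partial as context_fn.
context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
out.sum().backward()
```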

Related doc: https://docs.google.com/document/d/1BKyizkZPdri9mHqdDOLAUpkI7SbbKfLHRFVVpK9ZWqo/edit

Memory considerations:
- As with the existing SAC, cached values are cleared upon first use.
- We raise an error if the user tries to backward a second time through a region that was forwarded with SAC enabled (see the sketch after this list).
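
A sketch of the second point, reusing `fn`, `x`, and `context_fn` from the migration sketch above (the exact error raised is not spelled out here):

```python
out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
out.sum().backward(retain_graph=True)  # first backward consumes the cached values
out.sum().backward()  # expected to raise: SAC regions do not support a second backward
```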

In-place:
- We use version counting to raise an error if any cached tensor has been mutated. In-place operations that do not mutate cached tensors are allowed.
- `allow_cache_entry_mutation=True` can be passed to disable this check (useful in the case of auto AC, where the policy cleverly also saves the output of the in-place op); a sketch follows this list.
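
A sketch of opting out of that check, reusing `policy_fn`, `fn`, and `x` from the migration sketch above; here the policy is assumed to intentionally save tensors that are later mutated in place:

```python
import functools
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts

# allow_cache_entry_mutation=True disables the version-counter check on cached tensors.
context_fn = functools.partial(
    create_selective_checkpoint_contexts,
    policy_fn,
    allow_cache_entry_mutation=True,
)
out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
```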

Randomness, views
- Currently, in this PR, we don't do anything special for randomness or views; the author of the policy function is expected to handle them properly. (Would it be beneficial to error? We either want to save all random tensors or recompute all of them.)

Tensor object preservation
- ~~We guarantee that if a tensor does not require grad and it is saved, then what you get out is the same tensor object.~~ UPDATE: We guarantee that if a tensor is of non-differentiable dtype AND it is not a view, and it is saved, then what you get out is the same tensor object. This is a nice guarantee for nested tensors, which care about the object identity of the offsets tensor.

Policy function
- Enum values are `{MUST,PREFER}_{SAVE,RECOMPUTE}` (bikeshed welcome). The alternative was `{SAVE,RECOMPUTE}_{NON_,}OVERRIDABLE`. The former was preferred because it seemed clearer that two clashing `MUST` policies should error, whereas it is ambiguous whether two stacked `NON_OVERRIDABLE` policies should silently ignore each other or error.
- On the usage of the enum today: there is actually NO API to stack SAC policies today; the only thing the enum should matter for in the near term is the compiler. Stacking SAC policies would be useful if someone wants to implement something like simple FSDP, but it is not perfect because with a policy of `PREFER_SAVE` you actually save more than autograd would normally save (this would be fixed with AC v3).
- The number of times we call the policy_fn should be documented as part of the public API. We call the policy function for all ops except ~~detach~~ UPDATE: the metadata ops listed in `torch.utils.checkpoint.SAC_IGNORED_OPS`, because these ops may be called a different number of times by AC itself between forward and recompute.
- The policy function can be a stateful object (we do NOT make separate copies of this object for forward/recompute; the user is expected to handle that via `is_recompute`, as in the sketch after this list).
- The policy function's signature takes a `ctx` object as its first argument. `ctx` is an object encapsulating info that may be useful to the user; it currently only holds `is_recompute`. Adding this indirection gives us flexibility to add more attributes later if necessary.
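
A sketch of a stateful policy object; the `mm` op and the save-every-other heuristic are illustrative only, and a plain wrapper is used instead of `functools.partial` so that each checkpointed call gets a fresh instance:

```python
import torch
from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts

class SaveEveryOtherMm:
    def __init__(self):
        # The same object is called during both forward and recompute, so keep one
        # counter per pass, keyed on ctx.is_recompute, to keep the two passes consistent.
        self.counts = {False: 0, True: 0}

    def __call__(self, ctx, op, *args, **kwargs):
        if op == torch.ops.aten.mm.default:
            self.counts[ctx.is_recompute] += 1
            if self.counts[ctx.is_recompute] % 2 == 1:
                return CheckpointPolicy.MUST_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE

def context_fn():
    # A fresh policy instance per checkpointed call keeps counters from leaking
    # across separate invocations of checkpoint().
    return create_selective_checkpoint_contexts(SaveEveryOtherMm())
```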

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125795
Approved by: https://github.com/Chillee, https://github.com/fmassa
2024-06-18 18:18:50 +00:00
_static [docs] Update PT2+Profiler docs (#122272) 2024-03-28 17:52:28 +00:00
_templates Remove sdp_kernel and replace with sdpa_kernel in attention namespace (#114689) 2024-01-24 22:28:04 +00:00
community [ONNX] Update 'person_of_interest.rst', 'CODEOWNERS' and 'merge_rules.yaml' (#126364) 2024-06-16 04:52:16 +00:00
elastic DOC: add docstring to construct_and_record_rdzv_event() (#128189) 2024-06-10 22:17:33 +00:00
notes Adding a note for Getting Started with PyTorch on Intel GPUs (#127872) 2024-06-14 14:24:28 +00:00
rpc
scripts [5/N][Easy] fix typo for usort config in pyproject.toml (kown -> known): sort torch (#127126) 2024-05-27 14:49:57 +00:00
amp.rst generalize custom_fwd&custom_bwd to be device-agnostic (#126531) 2024-05-25 06:48:16 +00:00
autograd.rst Add torch.library.register_autograd (#124071) 2024-04-18 12:47:59 +00:00
backends.rst Revert "[cuDNN][SDPA] Remove TORCH_CUDNN_SDPA_ENABLED=1, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 (#125343)" 2024-06-12 18:47:52 +00:00
benchmark_utils.rst Adding Compare in torch.utils.benchmark documentation (#125009) 2024-05-03 00:50:54 +00:00
bottleneck.rst
checkpoint.rst [checkpoint] Clean up selective activation checkpoint and make public (#125795) 2024-06-18 18:18:50 +00:00
complex_numbers.rst Document complex optimizer semantic behavior (#121667) 2024-03-16 00:43:47 +00:00
cond.rst Fix typo under docs directory (#119657) 2024-02-15 21:14:34 +00:00
conf.py Retire torch.distributed.pipeline (#127354) 2024-06-07 08:11:58 +00:00
config_mod.rst
cpp_extension.rst
cpp_index.rst
cpu.rst Add current_device() to torch.cpu (#110987) 2023-10-11 05:13:10 +00:00
cuda_environment_variables.rst Add doc page for environment variables that effect PyTorch Runtime (#119087) 2024-02-15 21:41:38 +00:00
cuda._sanitizer.rst
cuda.rst Created docs (and example) for cudart function in torch.cuda (#128741) 2024-06-17 16:50:37 +00:00
cuda.tunable.rst [ROCm] TunableOp improvements (#124362) 2024-06-03 22:30:11 +00:00
cudnn_persistent_rnn.rst
cudnn_rnn_determinism.rst
data.rst Revert "reseed all Generators in Dataloader's _worker_loop() -- via GC (#107131)" 2023-08-23 17:08:07 +00:00
ddp_comm_hooks.rst [DOCS][DDP]Fix the simple of saving and reloading PowerSGD state and hook. (#102721) 2023-06-10 00:15:00 +00:00
debugging_environment_variables.rst Add doc page for environment variables that effect PyTorch Runtime (#119087) 2024-02-15 21:41:38 +00:00
deploy.rst
deterministic.rst Add torch.utils.deterministic.fill_uninitialized_memory flag (#111377) 2023-11-01 16:10:09 +00:00
distributed.algorithms.join.rst
distributed.checkpoint.rst [DCP] Provides default AsyncStager (#124939) 2024-05-02 19:48:54 +00:00
distributed.elastic.rst Reapply "distributed debug handlers (#126601)" (#127805) 2024-06-04 19:44:30 +00:00
distributed.optim.rst
distributed.pipelining.rst [pipelining][doc] Remove duplicated words (#128368) 2024-06-11 04:52:57 +00:00
distributed.rst Retire torch.distributed.pipeline (#127354) 2024-06-07 08:11:58 +00:00
distributed.tensor.parallel.rst [tp] doc fixes (#121431) 2024-03-08 17:46:44 +00:00
distributions.rst Add inverse gamma distribution and fix sign bug in PowerTransform. (#104501) 2023-11-01 02:26:25 +00:00
dlpack.rst
docutils.conf
export.ir_spec.rst [export] Remove torch._export.export (#119095) 2024-02-08 21:22:04 +00:00
export.rst [export] provide refine function for automatically accepting dynamic shapes suggested fixes (#127436) 2024-06-07 03:29:06 +00:00
fft.rst
fsdp.rst [FSDP][state_dict] Expose optimizer state_dict config (#105949) 2023-08-21 07:29:49 +00:00
func.api.rst
func.batch_norm.rst
func.migrating.rst
func.rst
func.ux_limitations.rst
func.whirlwind_tour.rst
future_mod.rst Add swap_tensors path to nn.Module._apply (#117167) 2024-02-07 18:55:44 +00:00
futures.rst
fx.experimental.rst Traceable wrapper subclass support for deferred runtime asserts (#126198) 2024-05-21 01:21:46 +00:00
fx.rst Implement Graph Transform Observer (#127427) 2024-06-02 06:49:47 +00:00
hub.rst
index.rst Retire torch.distributed.pipeline (#127354) 2024-06-07 08:11:58 +00:00
jit_builtin_functions.rst
jit_language_reference_v2.rst
jit_language_reference.rst
jit_python_reference.rst
jit_unsupported.rst Add support for torch.Generator type in TorchScript (#110413) 2023-11-21 23:07:21 +00:00
jit_utils.rst
jit.rst Doc test non packages (#110568) 2023-10-06 14:16:01 +00:00
library.rst New Custom Ops Documentation landing page (#127400) 2024-05-30 01:06:04 +00:00
linalg.rst
logging.rst Change classification to beta for TORCH_LOGS (#118682) 2024-01-31 21:50:55 +00:00
masked.rst Doc test non packages (#110568) 2023-10-06 14:16:01 +00:00
math-quantizer-equation.png
meta.rst Add documentation for meta device (#119119) 2024-02-04 01:05:22 +00:00
miscellaneous_environment_variables.rst Add doc page for environment variables that effect PyTorch Runtime (#119087) 2024-02-15 21:41:38 +00:00
mobile_optimizer.rst
model_zoo.rst
module_tracker.rst Add module tracker (#125352) 2024-05-04 18:33:35 +00:00
monitor.rst
mps.rst Add support in Python API for the recommended max working set size. (#128289) 2024-06-12 16:03:57 +00:00
mtia.rst [MTIA] Add set_device support (#128040) 2024-06-10 23:42:52 +00:00
multiprocessing.rst Doc test non packages (#110568) 2023-10-06 14:16:01 +00:00
name_inference.rst [docs] Properly link register_post_accumulate_grad_hook docs (#108157) 2023-08-29 22:13:33 +00:00
named_tensor.rst fixing named tensor unflatten example (#106921) 2023-08-22 18:00:10 +00:00
nested.rst Replace master with main in links and docs/conf.py (#100176) 2023-05-02 18:20:32 +00:00
nn.attention.bias.rst Remove sdp_kernel and replace with sdpa_kernel in attention namespace (#114689) 2024-01-24 22:28:04 +00:00
nn.attention.rst Remove sdp_kernel and replace with sdpa_kernel in attention namespace (#114689) 2024-01-24 22:28:04 +00:00
nn.functional.rst Add RMSNorm module (#121364) 2024-03-29 18:05:28 +00:00
nn.init.rst
nn.rst Cleanup some duplicated placeholder py:module docs (#123244) 2024-04-05 03:18:53 +00:00
onnx_dynamo_onnxruntime_backend.rst Follow-up #108379 (#108905) 2023-09-09 01:38:36 +00:00
onnx_dynamo.rst [ez][doc] Fix sample code in onnx_dynamo.rst (#114770) 2023-11-29 19:27:52 +00:00
onnx_torchscript_supported_aten_ops.rst Refactor torch.onnx documentation (#108379) 2023-09-08 18:23:48 +00:00
onnx_torchscript.rst Follow-up #108379 (#108905) 2023-09-09 01:38:36 +00:00
onnx.rst fix pytorch version for onnx in doc (#124182) 2024-04-17 18:05:15 +00:00
optim.rst Added example regarding weight_decay distinction with per-parameter API (#117436) 2024-01-22 21:26:02 +00:00
package.rst Doc test non packages (#110568) 2023-10-06 14:16:01 +00:00
profiler.rst Doc test non packages (#110568) 2023-10-06 14:16:01 +00:00
quantization-accuracy-debugging.rst
quantization-backend-configuration.rst
quantization-support.rst [quant][pt2e] Add model_is_exported util function (#119726) 2024-02-16 19:29:36 +00:00
quantization.rst Cleanup some duplicated placeholder py:module docs (#123244) 2024-04-05 03:18:53 +00:00
random.rst
rpc.rst [BE] RPC is missing RRef docs (#106902) 2023-08-10 16:26:27 +00:00
signal.rst
size.rst Added a docstring for torch.Size.numel. (#124186) 2024-04-19 09:23:02 +00:00
sparse.rst Fix typo in sparse.rst (#121826) 2024-03-19 00:17:19 +00:00
special.rst
storage.rst
tensor_attributes.rst Include the scalar tensor auto-transfer in the doc (#119967) 2024-02-15 22:37:39 +00:00
tensor_view.rst
tensorboard.rst
tensors.rst add xpu to torch.tensors (#127280) 2024-06-11 18:13:01 +00:00
testing.rst
threading_environment_variables.rst Add doc page for environment variables that effect PyTorch Runtime (#119087) 2024-02-15 21:41:38 +00:00
torch_cuda_memory.rst Fix typo under docs directory (#110359) 2023-10-03 16:36:05 +00:00
torch_environment_variables.rst [c10d][doc] add a doc page for NCCL ENVs (#128235) 2024-06-09 16:08:38 +00:00
torch_nccl_environment_variables.rst [c10d][doc] add a doc page for NCCL ENVs (#128235) 2024-06-09 16:08:38 +00:00
torch.ao.ns._numeric_suite_fx.rst
torch.ao.ns._numeric_suite.rst
torch.compiler_aot_inductor.rst [AOTI] docs: add suggestion to turn on freezing on CPU (#128010) 2024-06-07 08:57:02 +00:00
torch.compiler_api.rst [torch.export] Support is_compiling() flag for non-strict mode (#119602) 2024-02-29 05:52:51 +00:00
torch.compiler_best_practices_for_backends.rst Restructure torch.compile docs (#105376) 2023-07-28 20:58:57 +00:00
torch.compiler_cudagraph_trees.rst [CUDAGraph] add more docs for cudagraph trees (#127963) 2024-06-18 02:07:07 +00:00
torch.compiler_custom_backends.rst Fix a link in the compiler backend doc (#126079) 2024-05-21 20:16:04 +00:00
torch.compiler_dynamic_shapes.rst feat: Add min, max ranges to mark_dynamic API (#119737) 2024-03-07 23:26:03 +00:00
torch.compiler_dynamo_deepdive.rst Fix links rendering when surrounding code in Dynamo deepdive (#123427) 2024-04-13 04:55:15 +00:00
torch.compiler_dynamo_overview.rst Rename TorchDynamo -> Dyanamo in the dynamo tutorial doc (#123431) 2024-05-07 05:07:00 +00:00
torch.compiler_fake_tensor.rst Restructure torch.compile docs (#105376) 2023-07-28 20:58:57 +00:00
torch.compiler_faq.rst Fixed broken link and removed unfinished sentence from issue #126367 (#127938) 2024-06-05 07:37:32 +00:00
torch.compiler_fine_grain_apis.rst [torch.export] Support is_compiling() flag for non-strict mode (#119602) 2024-02-29 05:52:51 +00:00
torch.compiler_get_started.rst add xpu to torch.compile (#127279) 2024-06-13 21:15:09 +00:00
torch.compiler_inductor_profiling.rst Restructure torch.compile docs (#105376) 2023-07-28 20:58:57 +00:00
torch.compiler_ir.rst [export] torch.export landing page (#108783) 2023-09-10 01:40:42 +00:00
torch.compiler_nn_module.rst Revert "Reland 3rd try [finishing colesbury's PR 100642] Guard on nn.Module dicts and type (#109323)" + Forward fixes + test (#110964) 2023-10-11 05:16:47 +00:00
torch.compiler_performance_dashboard.rst Restructure torch.compile docs (#105376) 2023-07-28 20:58:57 +00:00
torch.compiler_profiling_torch_compile.rst [docs] Update PT2+Profiler docs (#122272) 2024-03-28 17:52:28 +00:00
torch.compiler_transformations.rst Fix typo under docs directory (#110359) 2023-10-03 16:36:05 +00:00
torch.compiler_troubleshooting.rst Add force_disable_caches to the docs (#126184) 2024-05-15 07:16:08 +00:00
torch.compiler.rst add xpu to torch.compile (#127279) 2024-06-13 21:15:09 +00:00
torch.overrides.rst Doc test non packages (#110568) 2023-10-06 14:16:01 +00:00
torch.rst torch.mtia module for MTIA device backend (#123612) 2024-04-26 16:17:54 +00:00
type_info.rst
utils.rst New swap function (#111747) 2023-12-08 18:49:35 +00:00
xpu.rst [2/2] Intel GPU Runtime Upstreaming for Generator (#118613) 2024-02-28 05:28:11 +00:00