```{eval-rst}
.. role:: hidden
    :class: hidden-section
```
# Automatic differentiation package - torch.autograd
```{eval-rst}
.. automodule:: torch.autograd
```
```{eval-rst}
.. currentmodule:: torch.autograd
```
```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    backward
    grad
```
(forward-mode-ad)=
## Forward-mode Automatic Differentiation
:::{warning}
This API is in beta. Even though the function signatures are very unlikely to change, improved
operator coverage is planned before we consider this stable.
:::
Please see the [forward-mode AD tutorial](https://pytorch.org/tutorials/intermediate/forward_ad_usage.html)
for detailed steps on how to use this API.
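As a minimal sketch of the API (the tutorial covers it in more depth), the example below computes a Jacobian-vector product of `torch.sin` along a chosen tangent direction:
```
import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(3)
tangent = torch.randn(3)  # the direction along which we differentiate

with fwAD.dual_level():
    # A dual tensor carries the primal value together with its tangent.
    dual_input = fwAD.make_dual(primal, tangent)
    dual_output = dual_input.sin()
    # unpack_dual returns (primal, tangent); the output tangent is the JVP.
    y, jvp = fwAD.unpack_dual(dual_output)

print(torch.allclose(jvp, primal.cos() * tangent))  # True
```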
```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    forward_ad.dual_level
    forward_ad.make_dual
    forward_ad.unpack_dual
    forward_ad.enter_dual_level
    forward_ad.exit_dual_level
    forward_ad.UnpackedDualTensor
```
(functional-api)=
## Functional higher level API
:::{warning}
This API is in beta. Even though the function signatures are very unlikely to change, major
improvements to performance are planned before we consider this stable.
:::
This section contains the higher-level API for autograd that builds on the basic API above
and allows you to compute Jacobians, Hessians, etc.
This API works with user-provided functions that take only Tensors as input and return
only Tensors.
If your function takes other arguments that are not Tensors, or Tensors that don't have `requires_grad` set,
you can use a lambda to capture them.
For example, for a function `f` that takes three inputs (a Tensor for which we want the Jacobian, another
Tensor that should be considered constant, and a boolean flag) and is called as `f(input, constant, flag=flag)`,
you can use it as `functional.jacobian(lambda x: f(x, constant, flag=flag), input)`.
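For instance, a short sketch of this pattern (the function `f` below is hypothetical):
```
import torch
from torch.autograd.functional import jacobian

# Hypothetical function mixing a differentiable Tensor, a constant Tensor,
# and a non-Tensor flag.
def f(x, constant, flag=True):
    out = x * constant
    return out.relu() if flag else out

inp = torch.randn(3)
constant = torch.randn(3)

# Capture the extra arguments in a lambda so jacobian only sees `inp`.
J = jacobian(lambda x: f(x, constant, flag=True), inp)
print(J.shape)  # torch.Size([3, 3])
```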
```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    functional.jacobian
    functional.hessian
    functional.vjp
    functional.jvp
    functional.vhp
    functional.hvp
```
(locally-disable-grad)=
## Locally disabling gradient computation
See {ref}`locally-disable-grad-doc` for more information on the differences
between no-grad and inference mode as well as other related mechanisms that
may be confused with the two. Also see {ref}`torch-rst-local-disable-grad`
for a list of functions that can be used to locally disable gradients.
(default-grad-layouts)=
## Default gradient layouts
When a non-sparse `param` receives a non-sparse gradient during
{func}`torch.autograd.backward` or {func}`torch.Tensor.backward`,
`param.grad` is accumulated as follows.
If `param.grad` is initially `None`:
1. If `param`'s memory is non-overlapping and dense, `.grad` is
created with strides matching `param` (thus matching `param`'s
layout).
2. Otherwise, `.grad` is created with rowmajor-contiguous strides.
If `param` already has a non-sparse `.grad` attribute:
3. If `create_graph=False`, `backward()` accumulates into `.grad`
in-place, which preserves its strides.
4. If `create_graph=True`, `backward()` replaces `.grad` with a
new tensor `.grad + new grad`, which attempts (but does not guarantee)
matching the preexisting `.grad`'s strides.
The default behavior (letting `.grad`s be `None` before the first
`backward()`, such that their layout is created according to 1 or 2,
and retained over time according to 3 or 4) is recommended for best performance.
Calls to `model.zero_grad()` or `optimizer.zero_grad()` will not affect `.grad`
layouts.
In fact, resetting all `.grad`s to `None` before each
accumulation phase, e.g.:
```
for iterations...
    ...
    for param in model.parameters():
        param.grad = None
    loss.backward()
```
such that they're recreated according to 1 or 2 every time,
is a valid alternative to `model.zero_grad()` or `optimizer.zero_grad()`
that may improve performance for some networks.
### Manual gradient layouts
If you need manual control over `.grad`'s strides,
assign `param.grad =` a zeroed tensor with desired strides
before the first `backward()`, and never reset it to `None`.
Case 3 above guarantees your layout is preserved as long as `create_graph=False`.
Case 4 indicates your layout is *likely* preserved even if `create_graph=True`.
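A minimal sketch of this pattern, assuming we simply want each `.grad` to match its parameter's own strides:
```
import torch

model = torch.nn.Linear(8, 8)
x = torch.randn(2, 8)

# Assign zeroed grads with the desired strides once, before the first
# backward(), and never reset them to None afterwards.
for p in model.parameters():
    p.grad = torch.zeros_like(p)  # preserve_format keeps p's strides

loss = model(x).sum()
loss.backward()  # with create_graph=False, accumulation is in-place and keeps these strides
```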
## In-place operations on Tensors
Supporting in-place operations in autograd is a hard matter, and we discourage
their use in most cases. Autograd's aggressive buffer freeing and reuse makes
it very efficient and there are very few occasions when in-place operations
actually lower memory usage by any significant amount. Unless you're operating
under heavy memory pressure, you might never need to use them.
### In-place correctness checks
All {class}`Tensor` s keep track of in-place operations applied to them, and
if the implementation detects that a tensor was saved for backward in one of
the functions, but it was modified in-place afterwards, an error will be raised
once backward pass is started. This ensures that if you're using in-place
functions and not seeing any errors, you can be sure that the computed
gradients are correct.
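A minimal illustration of the check:
```
import torch

a = torch.randn(3, requires_grad=True)
b = a.exp()   # exp() saves its output for use in its backward
b.add_(1)     # in-place modification of a tensor that was saved for backward

try:
    b.sum().backward()
except RuntimeError as e:
    print(e)  # reports that a tensor needed for backward was modified in-place
```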
## Variable (deprecated)
:::{warning}
The Variable API has been deprecated: Variables are no longer necessary to
use autograd with tensors. Autograd automatically supports Tensors with
`requires_grad` set to `True`. Below please find a quick guide on what
has changed:
- `Variable(tensor)` and `Variable(tensor, requires_grad)` still work as expected,
but they return Tensors instead of Variables.
- `var.data` is the same thing as `tensor.data`.
- Methods such as `var.backward()`, `var.detach()`, and `var.register_hook()` now work on tensors
with the same method names.
In addition, one can now create tensors with `requires_grad=True` using factory
methods such as {func}`torch.randn`, {func}`torch.zeros`, {func}`torch.ones`, and others
like the following:
`autograd_tensor = torch.randn((2, 3, 4), requires_grad=True)`
:::
## Tensor autograd functions
```{eval-rst}
.. autosummary::
    :nosignatures:

    torch.Tensor.grad
    torch.Tensor.requires_grad
    torch.Tensor.is_leaf
    torch.Tensor.backward
    torch.Tensor.detach
    torch.Tensor.detach_
    torch.Tensor.register_hook
    torch.Tensor.register_post_accumulate_grad_hook
    torch.Tensor.retain_grad
```
## {hidden}`Function`
```{eval-rst}
.. autoclass:: Function
```
```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    Function.forward
    Function.backward
    Function.jvp
    Function.vmap
```
(context-method-mixins)=
## Context method mixins
When creating a new {class}`Function`, the following methods are available to `ctx`.
```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    function.FunctionCtx.mark_dirty
    function.FunctionCtx.mark_non_differentiable
    function.FunctionCtx.save_for_backward
    function.FunctionCtx.set_materialize_grads
```
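For example, a sketch of a custom {class}`Function` that uses `ctx.save_for_backward`:
```
import torch

class Exp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        result = x.exp()
        ctx.save_for_backward(result)  # stash tensors needed by backward
        return result

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        return grad_output * result    # d/dx exp(x) = exp(x)

x = torch.randn(3, requires_grad=True)
y = Exp.apply(x)
y.sum().backward()
print(torch.allclose(x.grad, x.exp()))  # True
```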
## Custom Function utilities
Decorator for backward method.
```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    function.once_differentiable
```
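For example, a sketch of how this decorator is typically applied (the `ReLUOnce` Function below is illustrative):
```
import torch
from torch.autograd.function import once_differentiable

class ReLUOnce(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.clamp(min=0)

    @staticmethod
    @once_differentiable  # this backward is not itself differentiable
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * (x > 0).to(grad_output.dtype)

x = torch.randn(4, requires_grad=True)
ReLUOnce.apply(x).sum().backward()
print(x.grad)
```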
Base custom {class}`Function` used to build PyTorch utilities
```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    function.BackwardCFunction
    function.InplaceFunction
    function.NestedIOFunction
```
(grad-check)=
## Numerical gradient checking
```{eval-rst}
.. automodule:: torch.autograd.gradcheck
```
```{eval-rst}
.. currentmodule:: torch.autograd.gradcheck
```
```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    gradcheck
    gradgradcheck
    GradcheckError
```
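A minimal sketch of typical usage:
```
import torch
from torch.autograd import gradcheck

# gradcheck compares analytical gradients against finite-difference estimates;
# double-precision inputs are recommended for numerical stability.
inp = torch.randn(4, 4, dtype=torch.double, requires_grad=True)
ok = gradcheck(torch.sigmoid, (inp,), eps=1e-6, atol=1e-4)
print(ok)  # True if the Jacobians match within tolerance
```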
% Just to reset the base path for the rest of this file
```{eval-rst}
.. currentmodule:: torch.autograd
```
## Profiler
Autograd includes a profiler that lets you inspect the cost of different
operators inside your model, both on the CPU and GPU. There are three modes
implemented at the moment: CPU-only, using {class}`~torch.autograd.profiler.profile`;
nvprof-based (registers both CPU and GPU activity), using
{class}`~torch.autograd.profiler.emit_nvtx`;
and VTune-based, using
{class}`~torch.autograd.profiler.emit_itt`.
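A minimal sketch of the CPU-only mode:
```
import torch
from torch.autograd import profiler

model = torch.nn.Linear(64, 64)
x = torch.randn(32, 64)

with profiler.profile(record_shapes=True) as prof:
    with profiler.record_function("my_forward"):  # custom label in the trace
        model(x)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
# prof.export_chrome_trace("trace.json")  # viewable in chrome://tracing
```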
```{eval-rst}
.. autoclass:: torch.autograd.profiler.profile
```
```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    profiler.profile.export_chrome_trace
    profiler.profile.key_averages
    profiler.profile.self_cpu_time_total
    profiler.profile.total_average
    profiler.parse_nvprof_trace
    profiler.EnforceUnique
    profiler.KinetoStepTracker
    profiler.record_function
    profiler_util.Interval
    profiler_util.Kernel
    profiler_util.MemRecordsAcc
    profiler_util.StringTable
```
```{eval-rst}
.. autoclass:: torch.autograd.profiler.emit_nvtx
```
```{eval-rst}
.. autoclass:: torch.autograd.profiler.emit_itt
```
```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    profiler.load_nvprof
```
## Debugging and anomaly detection
```{eval-rst}
.. autoclass:: detect_anomaly
```
```{eval-rst}
.. autoclass:: set_detect_anomaly
```
```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    grad_mode.set_multithreading_enabled
```
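A minimal sketch of anomaly detection catching a NaN produced during backward:
```
import torch

# detect_anomaly makes the backward pass fail as soon as a NaN gradient is
# produced, and attaches the forward-pass traceback of the offending node.
with torch.autograd.detect_anomaly():
    x = torch.tensor(0.0, requires_grad=True)
    y = x.sqrt()   # sqrt'(0) is inf
    z = y * 0      # the forward value is fine, but backward computes 0 * inf = nan
    try:
        z.backward()
    except RuntimeError as e:
        print(e)   # names the backward node that returned nan values
```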
## Autograd graph
Autograd exposes methods that allow one to inspect the graph and interpose behavior during
the backward pass.
The `grad_fn` attribute of a {class}`torch.Tensor` holds a {class}`torch.autograd.graph.Node`
if the tensor is the output of an operation that was recorded by autograd (i.e., grad_mode is
enabled and at least one of the inputs required gradients), or `None` otherwise.
```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    graph.Node.name
    graph.Node.metadata
    graph.Node.next_functions
    graph.Node.register_hook
    graph.Node.register_prehook
    graph.increment_version
```
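For example, a small sketch of inspecting the graph and registering a hook on a {class}`~torch.autograd.graph.Node`:
```
import torch

a = torch.randn(2, requires_grad=True)
b = (a * 2).sum()

print(b.grad_fn.name())           # e.g. 'SumBackward0'
print(b.grad_fn.next_functions)   # edges to the nodes that produced its inputs

# A Node hook fires during backward, after the node's grad inputs are computed.
def show(grad_inputs, grad_outputs):
    print([g.shape for g in grad_inputs if g is not None])

handle = b.grad_fn.register_hook(show)
b.backward()
handle.remove()
```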
Some operations need intermediary results to be saved during the forward pass
in order to execute the backward pass.
These intermediary results are saved as attributes on the `grad_fn` and can be accessed.
For example:
```
>>> a = torch.tensor([0., 0., 0.], requires_grad=True)
>>> b = a.exp()
>>> print(isinstance(b.grad_fn, torch.autograd.graph.Node))
True
>>> print(dir(b.grad_fn))
['__call__', '__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_raw_saved_result', '_register_hook_dict', '_saved_result', 'metadata', 'name', 'next_functions', 'register_hook', 'register_prehook', 'requires_grad']
>>> print(torch.allclose(b.grad_fn._saved_result, b))
True
```
You can also define how these saved tensors should be packed / unpacked using hooks.
A common application is to trade compute for memory by saving those intermediary results
to disk or to CPU instead of leaving them on the GPU. This is especially useful if you
notice your model fits on GPU during evaluation, but not training.
Also see {ref}`saved-tensors-hooks-doc`.
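A minimal sketch of the hook mechanism (here the hooks only log; {class}`~torch.autograd.graph.save_on_cpu` below provides the ready-made CPU-offloading variant):
```
import torch

def pack(x):
    print("packing", tuple(x.shape))
    return x  # one could instead move x to CPU or write it to disk here

def unpack(x):
    print("unpacking", tuple(x.shape))
    return x

a = torch.randn(3, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    b = a.exp()       # exp's saved output passes through pack()
b.sum().backward()    # ...and is retrieved through unpack() during backward
```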
```{eval-rst}
.. autoclass:: torch.autograd.graph.saved_tensors_hooks
```
```{eval-rst}
.. autoclass:: torch.autograd.graph.save_on_cpu
```
```{eval-rst}
.. autoclass:: torch.autograd.graph.disable_saved_tensors_hooks
```
```{eval-rst}
.. autoclass:: torch.autograd.graph.register_multi_grad_hook
```
```{eval-rst}
.. autoclass:: torch.autograd.graph.allow_mutation_on_saved_tensors
```
```{eval-rst}
.. autoclass:: torch.autograd.graph.GradientEdge
```
```{eval-rst}
.. autofunction:: torch.autograd.graph.get_gradient_edge
```
```{eval-rst}
.. autofunction:: torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch
```
% This module needs to be documented. Adding here in the meantime
% for tracking purposes
```{eval-rst}
.. py:module:: torch.autograd.anomaly_mode
```
```{eval-rst}
.. py:module:: torch.autograd.forward_ad
```
```{eval-rst}
.. py:module:: torch.autograd.function
```
```{eval-rst}
.. py:module:: torch.autograd.functional
```
```{eval-rst}
.. py:module:: torch.autograd.grad_mode
```
```{eval-rst}
.. py:module:: torch.autograd.graph
```
```{eval-rst}
.. py:module:: torch.autograd.profiler
```
```{eval-rst}
.. py:module:: torch.autograd.profiler_legacy
```
```{eval-rst}
.. py:module:: torch.autograd.profiler_util
```
```{eval-rst}
.. py:module:: torch.autograd.variable
```