torch.utils.checkpoint
======================

.. note::
    Checkpointing is implemented by rerunning a forward-pass segment for
    each checkpointed segment during backward. This can cause persistent
    states like the RNG state to be more advanced than they would be without
    checkpointing. By default, checkpointing includes logic to juggle
    the RNG state such that checkpointed passes making use of RNG
    (through dropout, for example) have deterministic output as
    compared to non-checkpointed passes. The logic to stash and restore
    RNG states can incur a moderate performance hit depending on the runtime
    of checkpointed operations. If deterministic output compared to
    non-checkpointed passes is not required, supply ``preserve_rng_state=False``
    to ``checkpoint`` or ``checkpoint_sequential`` to omit stashing and
    restoring the RNG state during each checkpoint.

    The stashing logic saves and restores the RNG state for the CPU and for
    one other device type, which ``_infer_device_type`` infers from the
    non-CPU Tensor arguments passed to ``run_fn``. If the arguments span
    multiple device types, device state is saved only for devices of a single
    device type, and the remaining devices are ignored. Consequently, if any
    checkpointed function involves randomness, this may result in incorrect
    gradients. (If CUDA devices are among the devices detected, CUDA is
    prioritized; otherwise, the first device type encountered is selected.)
    If none of the Tensor arguments is on a non-CPU device, the RNG state of
    the default device type is saved and restored (the default is ``cuda``;
    it can be changed through ``DefaultDeviceType``).
    However, the logic has no way to anticipate whether the user will move
    Tensors to a new device within ``run_fn`` itself. Therefore, if you move
    Tensors to a new device ("new" meaning not belonging to the set of
    [current device + devices of Tensor arguments]) within ``run_fn``, deterministic
    output compared to non-checkpointed passes is never guaranteed.

.. currentmodule:: torch.utils.checkpoint
.. autofunction:: checkpoint
.. autofunction:: checkpoint_sequential
.. autofunction:: set_checkpoint_debug_enabled
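
The RNG handling described in the note at the top of this page can be checked
with a small experiment. The sketch below compares the input gradient of a toy
model containing dropout when it is run plainly, through ``checkpoint``, and
through ``checkpoint_sequential``, each time starting from the same seed. The
helpers ``make_model`` and ``input_grad``, the mode strings, the layer sizes,
and the seeds are all hypothetical; the sketch assumes CPU tensors and the
non-reentrant implementation.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential


def make_model() -> nn.Sequential:
    # Hypothetical toy model; the dropout layers consume random numbers.
    torch.manual_seed(0)  # identical weights on every call
    return nn.Sequential(
        nn.Linear(8, 16),
        nn.Dropout(p=0.5),
        nn.Linear(16, 16),
        nn.Dropout(p=0.5),
        nn.Linear(16, 4),
    )


def input_grad(mode: str) -> torch.Tensor:
    model = make_model()
    model.train()  # keep dropout active
    x = torch.ones(2, 8, requires_grad=True)
    torch.manual_seed(1234)  # same RNG stream for every mode
    if mode == "plain":
        out = model(x)
    elif mode == "checkpoint":
        # preserve_rng_state defaults to True.
        out = checkpoint(model, x, use_reentrant=False)
    else:
        # Run the model as 2 segments through checkpoint_sequential.
        out = checkpoint_sequential(model, 2, x, use_reentrant=False)
    out.sum().backward()
    return x.grad


reference = input_grad("plain")
matches_checkpoint = torch.allclose(reference, input_grad("checkpoint"))
matches_sequential = torch.allclose(reference, input_grad("sequential"))
```

With the default ``preserve_rng_state=True``, both comparisons are expected to
hold, because the dropout masks drawn during recomputation repeat the ones
drawn in forward. Passing ``preserve_rng_state=False`` to either call skips
the stash and restore, so the recomputed masks would generally differ and the
match should not be relied on.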