import torch
import contextlib
from typing import Callable, Any


class saved_tensors_hooks():
    """Context-manager that sets a pair of pack / unpack hooks for saved tensors.

    Use this context-manager to define how intermediary results of an operation
    should be packed before saving, and unpacked on retrieval.

    In that context, the ``pack_hook`` function will be called every time an
    operation saves a tensor for backward (this includes intermediary results
    saved using
    :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
    also those recorded by a PyTorch-defined operation). The output of
    ``pack_hook`` is then stored in the computation graph instead of the
    original tensor.

    The ``unpack_hook`` is called when the saved tensor needs to be accessed,
    namely when executing :func:`torch.Tensor.backward()` or
    :func:`torch.autograd.grad()`. It takes as argument the *packed* object
    returned by ``pack_hook`` and should return a tensor which has the same
    content as the original tensor (passed as input to the corresponding
    ``pack_hook``).

    The hooks should have the following signatures:

        pack_hook(tensor: Tensor) -> Any

        unpack_hook(Any) -> Tensor

    where the return value of ``pack_hook`` is a valid input to ``unpack_hook``.

    In general, you want ``unpack_hook(pack_hook(t))`` to be equal to ``t`` in terms
    of value, size, dtype and device.

    Example::

        >>> def pack_hook(x):
        ...     print("Packing", x)
        ...     return x
        >>>
        >>> def unpack_hook(x):
        ...     print("Unpacking", x)
        ...     return x
        >>>
        >>> a = torch.ones(5, requires_grad=True)
        >>> b = torch.ones(5, requires_grad=True) * 2
        >>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
        ...     y = a * b
        Packing tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Packing tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
        >>> y.sum().backward()
        Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
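    As another illustration (a sketch only; it assumes the saved tensors are
    floating point and that the precision loss of ``float16`` is acceptable
    for your use case), the hooks can keep saved tensors in half precision to
    reduce the memory held by the graph::

        >>> def pack_fp16(tensor):
        ...     # remember the original dtype so unpack can restore it
        ...     return (tensor.dtype, tensor.to(torch.float16))
        >>>
        >>> def unpack_fp16(packed):
        ...     dtype, tensor = packed
        ...     return tensor.to(dtype)
        >>>
        >>> with torch.autograd.graph.saved_tensors_hooks(pack_fp16, unpack_fp16):
        ...     y = a * b
        >>> y.sum().backward()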

    .. warning ::
        Performing an in-place operation on the input to either hook may lead
        to undefined behavior.

    .. warning ::
        Only one pair of hooks is allowed at a time. When recursively nesting this
        context-manager, only the inner-most pair of hooks will be applied.
    """
    def __init__(self, pack_hook: Callable[[torch.Tensor], Any], unpack_hook: Callable[[Any], torch.Tensor]):
        self.pack_hook = pack_hook
        self.unpack_hook = unpack_hook

    def __enter__(self):
        torch._C._autograd._push_saved_tensors_default_hooks(self.pack_hook, self.unpack_hook)

    def __exit__(self, *args: Any):
        torch._C._autograd._pop_saved_tensors_default_hooks()


class save_on_cpu(saved_tensors_hooks):
    """Context-manager under which tensors saved by the forward pass will be
    stored on CPU, then retrieved for backward.

    When performing operations within this context manager, intermediary
    results saved in the graph during the forward pass will be moved to CPU,
    then copied back to the original device when needed for the backward pass.
    If the graph was already on CPU, no tensor copy is performed.

    Use this context-manager to trade compute for GPU memory usage (e.g.
    when your model doesn't fit in GPU memory during training).

    Args:
        pin_memory (bool): If ``True``, tensors will be saved to CPU pinned memory
                           during packing and copied to GPU asynchronously during unpacking.
                           Defaults to ``False``.
                           Also see :ref:`cuda-memory-pinning`.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> a = torch.randn(5, requires_grad=True, device="cuda")
        >>> b = torch.randn(5, requires_grad=True, device="cuda")
        >>> c = torch.randn(5, requires_grad=True, device="cuda")
        >>>
        >>> def f(a, b, c):
        ...     prod_1 = a * b           # a and b are saved on GPU
        ...     with torch.autograd.graph.save_on_cpu():
        ...         prod_2 = prod_1 * c  # prod_1 and c are saved on CPU
        ...     y = prod_2 * a           # prod_2 and a are saved on GPU
        ...     return y
        >>>
        >>> y = f(a, b, c)
        >>> del a, b, c  # for illustration only
        >>> # the contents of a, b, and prod_2 are still alive on GPU
        >>> # the contents of prod_1 and c only live on CPU
        >>> y.sum().backward()  # all CPU tensors are moved back to GPU, for backward
        >>> # all intermediary tensors are released (deleted) after the call to backward
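
    A typical pattern (a sketch; ``MyModel`` and ``inputs`` are placeholders,
    not defined here) is to wrap the forward pass of a model whose activations
    do not fit in GPU memory::

        >>> # xdoctest: +SKIP
        >>> model = MyModel().cuda()
        >>> with torch.autograd.graph.save_on_cpu(pin_memory=True):
        ...     loss = model(inputs.cuda()).sum()
        >>> loss.backward()  # saved activations are copied back to GPU as needed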

    """
    def __init__(self, pin_memory=False):
        def pack_to_cpu(tensor):
            if not pin_memory:
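                # No pinning requested: a plain copy to pageable CPU memory is
                # enough; record the original device so unpack_from_cpu knows
                # where to move the tensor back.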
                return (tensor.device, tensor.cpu())
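
            # Allocate the CPU copy in pinned (page-locked) memory so that
            # unpack_from_cpu can copy it back to the original device
            # asynchronously; pinning is skipped for sparse tensors and when
            # CUDA is unavailable.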
            packed = torch.empty(
                tensor.size(),
                dtype=tensor.dtype,
                layout=tensor.layout,
                pin_memory=(torch.cuda.is_available() and not tensor.is_sparse))
            packed.copy_(tensor)
            return (tensor.device, packed)

        def unpack_from_cpu(packed):
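            # Move the tensor back to the device it was saved from; the copy
            # can only be asynchronous when the CPU copy lives in pinned memory.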
            device, tensor = packed
            return tensor.to(device, non_blocking=pin_memory)

        super().__init__(pack_to_cpu, unpack_from_cpu)


@contextlib.contextmanager
def _disable_saved_tensors_hooks(error_message):
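    """Context-manager that disables the saved tensors default hooks feature.

    While the context is active, registering saved tensors hooks (for example
    via :class:`saved_tensors_hooks` or :class:`save_on_cpu`) fails with an
    error carrying ``error_message``. The disabled state is thread-local, and
    any previously set message is restored on exit.
    """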
    try:
        maybe_prev_message = torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
        torch._C._autograd._saved_tensors_hooks_disable(error_message)
        yield
    finally:
        # See NOTE: [disabled_error_message invariant]
        if maybe_prev_message is None:
            torch._C._autograd._saved_tensors_hooks_enable()
        else:
            torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)