from contextlib import contextmanager
from functools import partial
from typing import Any, List, Optional, Tuple
from weakref import ReferenceType, WeakKeyDictionary, ref

import torch
import torch.nn as nn
from torch.utils.checkpoint import detach_variable

from .contract import contract


@contextmanager
def _no_hook(module: nn.Module):
    r"""
    Disable hooks installed by checkpoint to avoid unintentional recursion
    during backward recomputation.
    """
    orig_enable_hook = checkpoint.state(module).enable_hook
    checkpoint.state(module).enable_hook = False
    try:
        yield
    finally:
        checkpoint.state(module).enable_hook = orig_enable_hook


class _ModuleHookCheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, module: nn.Module, output: Any, *inputs: Any) -> Any:  # type: ignore[override]
        ctx.module = module
        # Save non-tensor inputs in ctx and keep a placeholder None for each
        # tensor input; the tensors are filled back in during backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, inp in enumerate(inputs):
            if torch.is_tensor(inp):
                tensor_inputs.append(inp)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(inp)

        ctx.save_for_backward(*tensor_inputs)
        return output

    @staticmethod
    def backward(ctx, output_grads: Tuple[Optional[torch.Tensor]]) -> Any:  # type: ignore[override]
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an "
                "`inputs` parameter is passed to .backward(). Please use "
                ".backward() and do not pass its `inputs` argument."
            )

        # Copy the list to avoid modifying the original.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors

        # Fill in the placeholders with the saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        detached_inputs = detach_variable(tuple(inputs))
        with torch.enable_grad(), _no_hook(ctx.module):
            outputs = ctx.module(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        if isinstance(output_grads, torch.Tensor):
            output_grads = (output_grads,)

        # Run backward() only on the outputs that require grad.
        outputs_requires_grad: List[torch.Tensor] = []
        output_grad_tensors: List[torch.Tensor] = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_requires_grad.append(outputs[i])
                assert (
                    output_grads[i] is not None
                ), f"expecting grad for output at index {i}, but got None."
                output_grad_tensors.append(output_grads[i])  # type: ignore[arg-type]
        if len(outputs_requires_grad) == 0:
            raise RuntimeError(
                "none of the outputs has requires_grad=True,"
                " this checkpoint() is not necessary"
            )

        torch.autograd.backward(outputs_requires_grad, output_grad_tensors)

        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None
            for inp in detached_inputs
        )

        # The two leading Nones are for the forward arguments `module` and
        # `output`, respectively.
        return (None, None) + grads


class _Holder:
    pass


def _pack(
    x: torch.Tensor,
    *,
    weak_holder_list: List[ReferenceType],
) -> _Holder:
    res = _Holder()
    weak_holder_list.append(ref(res))
    return res


def _unpack(
    holder: _Holder,
    *,
    storage: WeakKeyDictionary,
    weak_holder_list: List[ReferenceType],
    module: nn.Module,
    inputs: Tuple[Any],
) -> torch.Tensor:
    holder_index = 0
    if len(storage) == 0:

        def inner_pack(inner: torch.Tensor):
            nonlocal holder_index
            if weak_holder_list[holder_index]() is None:
                # If the holder went out of scope, the SavedVariable is dead,
                # so the value will never be read from storage. Skip filling
                # it in.
                pass
            else:
                # Detach here to make sure we do not keep the temporary
                # autograd graph created during the second forward.
                storage[weak_holder_list[holder_index]()] = inner.detach()
            holder_index += 1
            return

        def inner_unpack(holder: _Holder):
            raise RuntimeError(
                "You are calling backwards on a tensor that is never exposed. "
                "Please open an issue."
            )

        with _no_hook(
            module
        ), torch.enable_grad(), torch.autograd.graph.saved_tensors_hooks(
            inner_pack, inner_unpack
        ):
            _unused = module(*inputs)

    if holder not in storage:
        raise RuntimeError(
            "Attempt to retrieve a tensor saved by autograd multiple times "
            "without checkpoint recomputation being triggered in between; this "
            "is not currently supported. Please open an issue with details on "
            "your use case so that we can prioritize adding this."
        )

    return storage[holder]


@contract()
def checkpoint(module: nn.Module, *, use_reentrant: bool = True) -> nn.Module:
    r"""
    This is a composable activation checkpointing API. Unlike functional
    activation checkpointing APIs, this one does not require changing model
    source code. Unlike ``nn.Module`` wrapper activation checkpointing APIs,
    this one does not modify model structure or fully-qualified names either.
    Under the hood, it registers activation checkpointing logic as pre- and
    post-forward hooks. Hence, this API can be easily applied to any model or
    sub-modules in the model.

    Args:
        module (nn.Module): the target model or sub-module to apply activation
            checkpointing to.
        use_reentrant (bool): Apply activation checkpointing using reentrant
            autograd.

    Example::
        >>> import torch.nn as nn
        >>>
        >>> class MyModel(nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.l1 = nn.Linear(10, 10)
        >>>         self.l2 = nn.Linear(10, 10)
        >>>
        >>>     def forward(self, x):
        >>>         return self.l2(self.l1(x))
        >>>
        >>> model = MyModel()
        >>> checkpoint(model.l1)  # apply activation checkpointing only to l1
        >>> model(torch.zeros(2, 10)).sum().backward()
    """

    def forward_pre_hook(module: nn.Module, inputs: Tuple[Any]) -> None:
        if checkpoint.state(module).enable_hook:
            checkpoint.state(module).orig_grad_enabled = torch.is_grad_enabled()
            if checkpoint.state(module).use_reentrant:
                torch.set_grad_enabled(False)
            else:
                # The Holder object for each saved tensor is stored directly
                # on the SavedVariable and is cleared when reset_data() is
                # called on it. We MUST make sure this is the only owning
                # reference, so that the tensor kept in `storage` is deleted
                # as soon as the corresponding SavedVariable data is cleared.
                storage: WeakKeyDictionary = WeakKeyDictionary()
                weak_holder_list: List[ReferenceType] = []
                saved_tensor_hooks = torch.autograd.graph.saved_tensors_hooks(
                    partial(_pack, weak_holder_list=weak_holder_list),
                    partial(
                        _unpack,
                        storage=storage,
                        weak_holder_list=weak_holder_list,
                        module=module,
                        inputs=inputs,
                    ),
                )
                saved_tensor_hooks.__enter__()
                checkpoint.state(module).saved_tensor_hooks = saved_tensor_hooks

    # This hook does two things:
    # 1. detach the output from the autograd graph to discard activations
    # 2. insert an autograd.Function after the forward pass to recompute the
    #    activations during the backward pass
    def forward_hook(module: nn.Module, inputs: Tuple[Any], output: Any) -> Any:
        if checkpoint.state(module).enable_hook:
            torch.set_grad_enabled(checkpoint.state(module).orig_grad_enabled)
            if checkpoint.state(module).use_reentrant:
                return _ModuleHookCheckpointFunction.apply(module, output, *inputs)
            else:
                checkpoint.state(module).saved_tensor_hooks.__exit__()
                checkpoint.state(module).saved_tensor_hooks = None

        return output

    checkpoint.state(module).enable_hook = True
    checkpoint.state(module).use_reentrant = use_reentrant
    module.register_forward_pre_hook(forward_pre_hook)
    # Use prepend to make sure the original grad-enabled state is restored
    # right after the module forward invocation.
    module.register_forward_hook(forward_hook, prepend=True)
    return module
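

# ---------------------------------------------------------------------------
# A minimal, self-contained sketch of the mechanism behind the non-reentrant
# path above: saved_tensors_hooks replaces each activation autograd wants to
# save with a lightweight placeholder at pack time, and the first unpack
# during backward re-runs the forward once to repopulate those placeholders
# with recomputed values. This is illustrative only and is not part of the
# composable API; `demo_pack_unpack_checkpoint` and `_DemoHolder` are
# hypothetical names introduced solely for this example.
if __name__ == "__main__":

    def demo_pack_unpack_checkpoint(fn, x):
        class _DemoHolder:
            pass

        storage = {}  # placeholder -> recomputed tensor (strong refs; demo only)
        holders = []  # placeholders in pack order, analogous to weak_holder_list

        def pack(tensor):
            # Drop the real tensor; autograd keeps only the lightweight holder.
            h = _DemoHolder()
            holders.append(h)
            return h

        def unpack(holder):
            if not storage:
                # First unpack during backward: recompute the forward once and
                # capture every tensor autograd would have saved, in the same
                # order as the original forward.
                idx = 0

                def inner_pack(tensor):
                    nonlocal idx
                    # Detach so the temporary recompute graph is not kept alive.
                    storage[holders[idx]] = tensor.detach()
                    idx += 1
                    return None

                def inner_unpack(_holder):
                    raise RuntimeError("unexpected unpack during recompute")

                with torch.enable_grad(), torch.autograd.graph.saved_tensors_hooks(
                    inner_pack, inner_unpack
                ):
                    fn(x)
            return storage[holder]

        with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
            return fn(x)

    lin = nn.Linear(10, 10)
    inp = torch.randn(2, 10, requires_grad=True)
    out = demo_pack_unpack_checkpoint(lambda t: lin(t).relu(), inp)
    out.sum().backward()  # triggers one recomputation of the forward
    print(inp.grad.shape)  # expected: torch.Size([2, 10])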