pytorch/torch/distributed/_composable/checkpoint_activation.py
Rohan Varma 5d70fe0165 [Composable] Use non-reentrant generator, remove reentrant (#105176)
Removes reentrant support for the composable checkpoint, as
non-reentrant is the recommended approach and is what we should use when
rolling out the composable checkpoint API.

Also removes the standalone non-reentrant implementation and instead uses
the generator from the diff below to reuse the original implementation.

Differential Revision: [D47451375](https://our.internmc.facebook.com/intern/diff/D47451375/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105176
Approved by: https://github.com/awgu, https://github.com/fegin
2023-07-26 07:03:03 +00:00
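
For context, the non-reentrant path is the same one the functional API takes
when called with use_reentrant=False. A minimal sketch of that recommended
usage (the layer and input names here are illustrative, not from this diff):

    import torch
    from torch.utils.checkpoint import checkpoint as functional_checkpoint

    layer = torch.nn.Linear(10, 10)
    x = torch.zeros(2, 10, requires_grad=True)
    # Activations inside `layer` are recomputed during backward instead of stored.
    out = functional_checkpoint(layer, x, use_reentrant=False)
    out.sum().backward()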

95 lines
3.3 KiB
Python

from contextlib import contextmanager, nullcontext
from typing import Any, Tuple

import torch
import torch.nn as nn
from torch.utils.checkpoint import (
    _checkpoint_without_reentrant_generator,
    _DEFAULT_DETERMINISM_MODE,
)

from .contract import contract


@contextmanager
def _no_hook(module: nn.Module):
r"""
Disable hooks installed by checkpoint to avoid unintentional recursion
during backward recomputation.
"""
orig_enable_hook = checkpoint.state(module).enable_hook
checkpoint.state(module).enable_hook = False
try:
yield
finally:
        checkpoint.state(module).enable_hook = orig_enable_hook


@contract()
def checkpoint(module: nn.Module) -> nn.Module:
r"""
This is a composable activation checkpointing API. Unlike functional
activation checkpointing APIs, this one does not require changing model
source code. Unlike ``nn.Module`` wrapper activation checkpointing APIs,
this one does not modify model structure or fully-qualified names either.
Under the hood, it registers activation checkpointing logic as pre- and
post-forward hooks. Hence, this API can be easily applied to any model or
    sub-modules in the model.

    Args:
        module (nn.Module): the target model or sub-module to apply activation
            checkpointing.

    Example::
>>> # xdoctest: +SKIP
>>> import torch.nn as nn
>>>
        >>> class MyModel(nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.l1 = nn.Linear(10, 10)
        >>>         self.l2 = nn.Linear(10, 10)
        >>>
        >>>     def forward(self, x):
        >>>         return self.l2(self.l1(x))
>>>
>>> model = MyModel()
>>> checkpoint(model.l1) # apply activation checkpointing only to l1
>>> model(torch.zeros(2, 10)).sum().backward()
"""
torch._C._log_api_usage_once("torch.distributed.checkpoint")
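
    # The pre-forward hook creates a fresh checkpoint generator for this
    # forward invocation and advances it once to enter the checkpoint context.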
def forward_pre_hook(module: nn.Module, inputs: Tuple[Any, ...]) -> None:
        if checkpoint.state(module).enable_hook:

def context_fns():
return nullcontext(), _no_hook(module)
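
            # Positional args follow the generator's signature:
            # (fn, preserve_rng_state, context_fn, determinism_check, debug, *args).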
checkpoint.state(
module
)._ac_generator = _checkpoint_without_reentrant_generator(
module, True, context_fns, _DEFAULT_DETERMINISM_MODE, False, *inputs
)
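            # Run the generator up to its yield: this enters the saved-tensor
            # hooks that swap stored activations for recompute logic.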
next(checkpoint.state(module)._ac_generator)
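
    # The post-forward hook advances the generator once more, which exits the
    # checkpoint context; the generator must be exhausted at that point.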
def forward_hook(module: nn.Module, inputs: Tuple[Any, ...], output: Any) -> Any:
if checkpoint.state(module).enable_hook:
try:
next(checkpoint.state(module)._ac_generator)
except StopIteration:
pass
else:
raise RuntimeError(
"Expected non-reentrant activation checkpoint generator to be exhausted, but it was not!"
)

        # Ensure that we no longer hold on to the generator. always_call=True
        # helps ensure we clear this even if the forward pass raises an
        # exception.
checkpoint.state(module)._ac_generator = None
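
    # Enable the hooks and register them. prepend=True runs the post-forward
    # hook ahead of other forward hooks, and always_call=True guarantees it
    # fires even if forward raises, so _ac_generator is always cleared.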
checkpoint.state(module).enable_hook = True
module.register_forward_pre_hook(forward_pre_hook)
module.register_forward_hook(forward_hook, prepend=True, always_call=True)
return module