Removes reentrant support for the composable checkpoint, as non-reentrant is the recommended approach and should be used when rolling out the composable checkpoint API. Also removes the standalone non-reentrant implementation and instead uses the generator from the diff below to reuse the original implementation. Differential Revision: [D47451375](https://our.internmc.facebook.com/intern/diff/D47451375/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/105176 Approved by: https://github.com/awgu, https://github.com/fegin
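
A minimal sketch of usage after this change (hypothetical example, not part of the diff; assumes ``checkpoint`` is importable from ``torch.distributed._composable``): ``checkpoint(module)`` now always takes the non-reentrant path, with no reentrant option to select.

    import torch
    import torch.nn as nn
    from torch.distributed._composable import checkpoint

    model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
    checkpoint(model[0])  # always non-reentrant after this change
    model(torch.zeros(2, 10)).sum().backward()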
95 lines
3.3 KiB
Python
from contextlib import contextmanager, nullcontext
from typing import Any, Tuple

import torch
import torch.nn as nn
from torch.utils.checkpoint import (
    _checkpoint_without_reentrant_generator,
    _DEFAULT_DETERMINISM_MODE,
)

from .contract import contract
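
# Note: ``_checkpoint_without_reentrant_generator`` is the shared non-reentrant
# implementation in ``torch.utils.checkpoint``; per the commit message, this
# module now drives it from forward hooks instead of keeping a standalone copy.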


@contextmanager
def _no_hook(module: nn.Module):
    r"""
    Disable hooks installed by checkpoint to avoid unintentional recursion
    during backward recomputation.
    """
    orig_enable_hook = checkpoint.state(module).enable_hook
    checkpoint.state(module).enable_hook = False
    try:
        yield
    finally:
        checkpoint.state(module).enable_hook = orig_enable_hook
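

# ``@contract()`` (from ``.contract``) is what gives ``checkpoint`` its
# ``state(module)`` accessor, used throughout this file to stash per-module
# state such as ``enable_hook`` and ``_ac_generator``.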
@contract()
def checkpoint(module: nn.Module) -> nn.Module:
    r"""
    This is a composable activation checkpointing API. Unlike functional
    activation checkpointing APIs, this one does not require changing model
    source code. Unlike ``nn.Module`` wrapper activation checkpointing APIs,
    this one does not modify model structure or fully-qualified names either.
    Under the hood, it registers activation checkpointing logic as pre- and
    post-forward hooks. Hence, this API can be easily applied to any model or
    sub-modules in the model.

    Args:
        module (nn.Module): the target model or sub-module to apply activation
            checkpointing.

    Example::
        >>> # xdoctest: +SKIP
        >>> import torch.nn as nn
        >>>
        >>> class MyModel(nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.l1 = nn.Linear(10, 10)
        >>>         self.l2 = nn.Linear(10, 10)
        >>>
        >>>     def forward(self, x):
        >>>         return self.l2(self.l1(x))
        >>>
        >>> model = MyModel()
        >>> checkpoint(model.l1)  # apply activation checkpointing only to l1
        >>> model(torch.zeros(2, 10)).sum().backward()

    """
    torch._C._log_api_usage_once("torch.distributed.checkpoint")
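
    # The pre-forward hook builds the non-reentrant checkpoint generator and
    # advances it to its first yield, which performs setup (e.g. installing
    # saved-tensor hooks) before the wrapped forward runs. The positional
    # arguments here correspond to preserve_rng_state=True, the context
    # functions, the determinism check mode, and debug=False.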
    def forward_pre_hook(module: nn.Module, inputs: Tuple[Any, ...]) -> None:
        if checkpoint.state(module).enable_hook:

            def context_fns():
                return nullcontext(), _no_hook(module)

            checkpoint.state(
                module
            )._ac_generator = _checkpoint_without_reentrant_generator(
                module, True, context_fns, _DEFAULT_DETERMINISM_MODE, False, *inputs
            )
            next(checkpoint.state(module)._ac_generator)
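
    # The post-forward hook resumes the generator, which is expected to be
    # exhausted (StopIteration) now that the wrapped forward has finished.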
    def forward_hook(module: nn.Module, inputs: Tuple[Any, ...], output: Any) -> Any:
        if checkpoint.state(module).enable_hook:
            try:
                next(checkpoint.state(module)._ac_generator)
            except StopIteration:
                pass
            else:
                raise RuntimeError(
                    "Expected non-reentrant activation checkpoint generator to be exhausted, but it was not!"
                )

        # Ensure that we no longer hold on to the generator. always_call=True helps ensure we
        # clear this even in the case of exception in fwd pass.
        checkpoint.state(module)._ac_generator = None
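
    # prepend=True runs forward_hook before any other forward hooks on the
    # module; always_call=True (see the note above) runs it even when the
    # forward pass raises.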
    checkpoint.state(module).enable_hook = True
    module.register_forward_pre_hook(forward_pre_hook)
    module.register_forward_hook(forward_hook, prepend=True, always_call=True)
    return module