pytorch/torch/distributed/_composable/checkpoint_activation.py
Chien-Chin Huang d52f121dba [Composable API]Common _State parent class for composable and wrapper FSDP (#89147)
**Why this PR?**

When implementing the composable APIs, internal APIs sometimes have access only to a local module rather than the application (FSDP, DDP) root module. One example is the state_dict/optimizer_state_dict implementation of FSDP. These APIs are designed to start from the root module of the model, so it is tricky for them to tell whether an arbitrary submodule is managed by DDP or FSDP.

It would be useful to have APIs like the following (see the usage sketch below):
- `_get_module_state(module)`: returns the composable API state if the module is managed by a composable API.
- `_get_module_fsdp_state(module)`: returns the FSDP state if the module is managed by FSDP.
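
As a rough illustration (not part of this PR's diff), a caller such as a state_dict helper could use these lookups as sketched below. The import paths and the `_is_*_managed` helper names are assumptions for illustration and may differ across PyTorch versions:

```python
import torch.nn as nn

# Hedged sketch: these import paths are assumptions and may differ by version.
from torch.distributed._composable_state import _get_module_state
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state


def _is_fsdp_managed(submodule: nn.Module) -> bool:
    # True only if an FSDP instance has registered a state for this submodule.
    return _get_module_fsdp_state(submodule) is not None


def _is_composable_managed(submodule: nn.Module) -> bool:
    # True if any composable API (or wrapper FSDP) has registered a state.
    return _get_module_state(submodule) is not None
```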

**What does this PR propose?**
1. Move `_State` out of the `_composable` module so that `FullyShardedDataParallel` can inherit from it.
2. Add a global `_module_state_mapping: Dict[nn.Module, _State]` that maps every submodule (not just the root module) to its state.
3. Create `_get_module_state(module)` to look up `_module_state_mapping`.
4. Create `_get_module_fsdp_state(module)`, which uses `_get_module_state(module)` to get the state and then verifies that the state is a `_FSDPState` (see the sketch after this list).
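
A minimal, self-contained sketch of the proposed mechanism follows. This is not the PR's actual code: the class bodies are stubbed out and the `_insert_module_state` helper name is illustrative.

```python
from typing import Dict, Optional, cast

import torch.nn as nn


class _State:
    # Common parent class shared by composable-API state and wrapper-FSDP state.
    pass


class _FSDPState(_State):
    # FSDP-specific state; the real class carries FSDP metadata.
    pass


# Global mapping from every managed submodule (not just the root) to its state.
_module_state_mapping: Dict[nn.Module, _State] = {}


def _insert_module_state(module: nn.Module, state: _State) -> None:
    # Registered when a composable API or FSDP takes ownership of a module.
    _module_state_mapping[module] = state


def _get_module_state(module: nn.Module) -> Optional[_State]:
    # Return the composable state if this module is managed, else None.
    return _module_state_mapping.get(module, None)


def _get_module_fsdp_state(module: nn.Module) -> Optional[_FSDPState]:
    # Reuse the generic lookup, then narrow the result to FSDP.
    state = _get_module_state(module)
    if state is None or not isinstance(state, _FSDPState):
        return None
    return cast(_FSDPState, state)
```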

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89147
Approved by: https://github.com/awgu
2022-12-13 23:58:01 +00:00

254 lines
9.3 KiB
Python

from contextlib import contextmanager
from functools import partial
from typing import Any, List, Optional, Tuple
from weakref import ReferenceType, WeakKeyDictionary, ref

import torch
import torch.nn as nn
from torch.utils.checkpoint import detach_variable

from .contract import contract


@contextmanager
def _no_hook(module: nn.Module):
    r"""
    Disable hooks installed by checkpoint to avoid unintentional recursion
    during backward recomputation.
    """
    orig_enable_hook = checkpoint.state(module).enable_hook
    checkpoint.state(module).enable_hook = False
    try:
        yield
    except Exception:
        raise
    finally:
        checkpoint.state(module).enable_hook = orig_enable_hook


class _ModuleHookCheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, module: nn.Module, output: Any, *inputs: Any) -> Any:  # type: ignore[override]
        ctx.module = module
        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, inp in enumerate(inputs):
            if torch.is_tensor(inp):
                tensor_inputs.append(inp)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(inp)

        ctx.save_for_backward(*tensor_inputs)
        return output

    @staticmethod
    def backward(ctx, output_grads: Tuple[Optional[torch.Tensor]]) -> Any:  # type: ignore[override]
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an "
                "`inputs` parameter is passed to .backward(). Please use "
                ".backward() and do not pass its `inputs` argument."
            )

        # Copy the list to avoid modifying the original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors

        # Fill in inputs with the appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        detached_inputs = detach_variable(tuple(inputs))
        with torch.enable_grad(), _no_hook(ctx.module):
            outputs = ctx.module(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        if isinstance(output_grads, torch.Tensor):
            output_grads = (output_grads,)

        # Run backward() with only the tensors that require grad.
        outputs_requires_grad: List[torch.Tensor] = []
        output_grad_tensors: List[torch.Tensor] = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_requires_grad.append(outputs[i])
                assert (
                    output_grads[i] is not None
                ), f"expecting grad for output at index {i}, but got None."
                output_grad_tensors.append(output_grads[i])  # type: ignore[arg-type]
        if len(outputs_requires_grad) == 0:
            raise RuntimeError(
                "none of the outputs has requires_grad=True,"
                " this checkpoint() is not necessary"
            )

        torch.autograd.backward(outputs_requires_grad, output_grad_tensors)

        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None
            for inp in detached_inputs
        )
        # The two Nones are for the forward arguments module and output, respectively.
        return (None, None) + grads


class _Holder:
    pass


def _pack(
    x: torch.Tensor,
    *,
    weak_holder_list: List[ReferenceType],
) -> _Holder:
    # Return an empty placeholder instead of saving ``x``; the activation is
    # discarded here and recomputed by ``_unpack`` during the backward pass.
    res = _Holder()
    weak_holder_list.append(ref(res))
    return res


def _unpack(
    holder: _Holder,
    *,
    storage: WeakKeyDictionary,
    weak_holder_list: List[ReferenceType],
    module: nn.Module,
    inputs: Tuple[Any],
) -> torch.Tensor:
    holder_index = 0
    if len(storage) == 0:

        def inner_pack(inner: torch.Tensor):
            nonlocal holder_index
            if weak_holder_list[holder_index]() is None:
                # If the holder went out of scope, the SavedVariable is dead
                # and so the value will never be read from the storage. Skip
                # filling it.
                pass
            else:
                # Use detach here to ensure we don't keep the temporary
                # autograd graph created during the second forward
                storage[weak_holder_list[holder_index]()] = inner.detach()
            holder_index += 1
            return

        def inner_unpack(holder: _Holder):
            raise RuntimeError(
                "You are calling backwards on a tensor that is never exposed. "
                "Please open an issue."
            )

        with _no_hook(
            module
        ), torch.enable_grad(), torch.autograd.graph.saved_tensors_hooks(
            inner_pack, inner_unpack
        ):
            _unused = module(*inputs)

    if holder not in storage:
        raise RuntimeError(
            "Attempt to retrieve a tensor saved by autograd multiple times "
            "without checkpoint recomputation being triggered in between, this "
            "is not currently supported. Please open an issue with details on "
            "your use case so that we can prioritize adding this."
        )

    return storage[holder]


@contract()
def checkpoint(module: nn.Module, *, use_reentrant: bool = True) -> nn.Module:
    r"""
    This is a composable activation checkpointing API. Unlike functional
    activation checkpointing APIs, this one does not require changing model
    source code. Unlike ``nn.Module`` wrapper activation checkpointing APIs,
    this one does not modify model structure or fully-qualified names either.
    Under the hood, it registers activation checkpointing logic as pre- and
    post-forward hooks. Hence, this API can be easily applied to any model or
    sub-modules in the model.

    Args:
        module (nn.Module): the target model or sub-module to apply activation
            checkpointing to.
        use_reentrant (bool): Apply activation checkpointing using reentrant
            autograd.

    Example::
        >>> import torch.nn as nn
        >>>
        >>> class MyModel(nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.l1 = nn.Linear(10, 10)
        >>>         self.l2 = nn.Linear(10, 10)
        >>>
        >>>     def forward(self, x):
        >>>         return self.l2(self.l1(x))
        >>>
        >>> model = MyModel()
        >>> checkpoint(model.l1)  # apply activation checkpointing only to l1
        >>> model(torch.zeros(2, 10)).sum().backward()
    """

    def forward_pre_hook(module: nn.Module, inputs: Tuple[Any]) -> None:
        if checkpoint.state(module).enable_hook:
            checkpoint.state(module).orig_grad_enabled = torch.is_grad_enabled()
            if checkpoint.state(module).use_reentrant:
                torch.set_grad_enabled(False)
            else:
                # The Holder object for each of the saved objects is saved
                # directly on the SavedVariable and is cleared when reset_data()
                # is called on it. We MUST make sure that this is the only
                # object having an owning reference to ensure that the Tensor
                # stored in storage is deleted as soon as the corresponding
                # SavedVariable data is cleared.
                storage: WeakKeyDictionary = WeakKeyDictionary()
                weak_holder_list: List[ReferenceType] = []
                saved_tensor_hooks = torch.autograd.graph.saved_tensors_hooks(
                    partial(_pack, weak_holder_list=weak_holder_list),
                    partial(
                        _unpack,
                        storage=storage,
                        weak_holder_list=weak_holder_list,
                        module=module,
                        inputs=inputs,
                    ),
                )
                saved_tensor_hooks.__enter__()
                checkpoint.state(module).saved_tensor_hooks = saved_tensor_hooks

    def forward_hook(module: nn.Module, inputs: Tuple[Any], output: Any) -> Any:
        if checkpoint.state(module).enable_hook:
            torch.set_grad_enabled(checkpoint.state(module).orig_grad_enabled)
            if checkpoint.state(module).use_reentrant:
                return _ModuleHookCheckpointFunction.apply(
                    module, output, *inputs
                )
            else:
                checkpoint.state(module).saved_tensor_hooks.__exit__()
                checkpoint.state(module).saved_tensor_hooks = None

        return output

    # This hook does the following things:
    # 1. detach outputs from the autograd graph to discard activations
    # 2. insert an autograd.Function after the forward pass to recompute
    #    activations during the backward pass.
    checkpoint.state(module).enable_hook = True
    checkpoint.state(module).use_reentrant = use_reentrant
    module.register_forward_pre_hook(forward_pre_hook)
    # Use prepend to make sure we restore the original grad enabled state right
    # after the module forward invocation.
    module.register_forward_hook(forward_hook, prepend=True)
    return module