Summary:
First, this does not change any existing behaviour: all the
default kwarg values were previously hardcoded into the ``_checkpoint_without_reentrant_generator`` call.
Second, it is needed to unlock the full potential of composable
checkpointing, making it equivalent to ``torch.utils.checkpoint.checkpoint(use_reentrant=False)``.
Finally, as an added benefit, composable checkpointing can now be used under ``FakeTensorMode`` by
passing ``preserve_rng_state=False``.
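For illustration, a minimal sketch of that use case (the ``torch.distributed._composable.checkpoint`` import path and the toy module are assumptions for the example, not part of this PR):
```python
import torch
from torch import nn
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._composable import checkpoint  # assumed import path

with FakeTensorMode():
    module = nn.Linear(8, 8)  # parameters are fake tensors under this mode
    # Forward kwargs through the composable API; skipping RNG-state capture
    # is what makes checkpointing usable under FakeTensorMode.
    checkpoint(module, preserve_rng_state=False)
    out = module(torch.randn(2, 8))
```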
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128516
Approved by: https://github.com/awgu
Fixes #91654.
Currently, the `hook` parameters of `nn.Module.register_forward_pre_hook` and `nn.Module.register_forward_hook` are typed as `Callable[..., None]`, which (1) prevents validation of the signature of `hook` and (2) incorrectly restricts the return type of `hook`, even though the docstrings of these methods state that it can be non-`None`.
Typing the first parameter of `hook` as `TypeVar("T", bound="Module")` allows binding a `Callable` whose first parameter is a subclass of `Module`.
---
Here are some examples of:
1. forward hooks and forward pre-hooks being accepted by mypy under the new type hints
2. mypy raising errors due to incorrect `hook` signatures
3. false negatives where pre-hooks are accepted as forward hooks
4. false negatives where hooks with kwargs are accepted irrespective of the value provided for `with_kwargs`
```python
from typing import Any, Dict, Tuple

import torch
from torch import nn


def forward_pre_hook(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
) -> None:
    ...


def forward_pre_hook_return_input(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
) -> Tuple[torch.Tensor, ...]:
    ...


def forward_pre_hook_with_kwargs(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
) -> None:
    ...


def forward_pre_hook_with_kwargs_return_input(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]:
    ...


def forward_hook(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    output: torch.Tensor,
) -> None:
    ...


def forward_hook_return_output(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    output: torch.Tensor,
) -> torch.Tensor:
    ...


def forward_hook_with_kwargs(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
    output: torch.Tensor,
) -> None:
    ...


def forward_hook_with_kwargs_return_output(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
    output: torch.Tensor,
) -> torch.Tensor:
    ...


model = nn.Module()

# OK
model.register_forward_pre_hook(forward_pre_hook)
model.register_forward_pre_hook(forward_pre_hook_return_input)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs, with_kwargs=True)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs_return_input, with_kwargs=True)
model.register_forward_hook(forward_hook)
model.register_forward_hook(forward_hook_return_output)
model.register_forward_hook(forward_hook_with_kwargs, with_kwargs=True)
model.register_forward_hook(forward_hook_with_kwargs_return_output, with_kwargs=True)

# mypy(error): [arg-type]
model.register_forward_pre_hook(forward_hook)
model.register_forward_pre_hook(forward_hook_return_output)
model.register_forward_pre_hook(forward_hook_with_kwargs)
model.register_forward_pre_hook(forward_hook_with_kwargs_return_output)
model.register_forward_hook(forward_pre_hook)
model.register_forward_hook(forward_pre_hook_return_input)

# false negatives
model.register_forward_hook(forward_pre_hook_with_kwargs)
model.register_forward_hook(forward_pre_hook_with_kwargs_return_input)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs, with_kwargs=False)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs_return_input, with_kwargs=False)
...
```
---
Though mypy 0.991 does not yet support it, the ideal typing of these methods would use [`typing.Literal`](https://mypy.readthedocs.io/en/stable/literal_types.html#literal-types):
```python
T = TypeVar("T", bound="Module")


class Module:
    @overload
    def register_forward_hook(
        self,
        hook: Callable[[T, Tuple[Any, ...], Any], Optional[Any]],
        *,
        prepend: bool = ...,
        with_kwargs: Literal[False] = ...,
    ) -> RemovableHandle:
        ...

    @overload
    def register_forward_hook(
        self,
        hook: Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],
        *,
        prepend: bool = ...,
        with_kwargs: Literal[True] = ...,
    ) -> RemovableHandle:
        ...

    def register_forward_hook(...):
        ...
```
which would:
1. validate the signature of `hook` according to the literal value provided for `with_kwargs` (fixing the false-negative examples above)
2. implicitly define the [fallback `bool` signature](https://github.com/python/mypy/issues/6113#issuecomment-1266186192), e.g. to handle the case where a non-literal value is provided for `with_kwargs`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92061
Approved by: https://github.com/albanD
**Why this PR?**
In the composable API implementation, internal APIs sometimes have access only to a local module rather than the application's (FSDP, DDP) root module. One example is the state_dict/optimizer_state_dict implementation of FSDP: these APIs are designed to start from the root module of the model, and it is tricky for them to tell whether an arbitrary submodule is managed by DDP or FSDP.
It will be useful to have APIs like:
`_get_module_state(module)`: return the composable state if this module is managed by a composable API.
`_get_module_fsdp_state(module)`: return the FSDP state if this module is managed by FSDP.
**What does this PR propose?**
1. Move `_State` out of the `_composable` module so that `FullyShardedDataParallel` can inherit from it.
2. Add a global `_module_state_mapping: Dict[nn.Module, _State]` that maps every managed submodule (not just the root module) to its state.
3. Create `_get_module_state(module)` to look up `_module_state_mapping`.
4. Create `_get_module_fsdp_state(module)` that uses `_get_module_state(module)` to get the state and then verifies that the state is an `_FSDPState` (a sketch of these helpers follows below).
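A minimal sketch of the proposed lookup helpers (the class bodies here are illustrative placeholders, not the actual implementation):
```python
from typing import Dict, Optional

import torch.nn as nn


class _State:
    """Base class for composable-API state; FSDP's state inherits from it."""


class _FSDPState(_State):
    """Illustrative placeholder for FSDP's state class."""


# Maps every managed submodule (not just the root) to its composable state.
_module_state_mapping: Dict[nn.Module, _State] = {}


def _get_module_state(module: nn.Module) -> Optional[_State]:
    """Return the composable state if `module` is managed by a composable API."""
    return _module_state_mapping.get(module)


def _get_module_fsdp_state(module: nn.Module) -> Optional[_FSDPState]:
    """Return the FSDP state if `module` is managed by FSDP."""
    state = _get_module_state(module)
    return state if isinstance(state, _FSDPState) else None
```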
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89147
Approved by: https://github.com/awgu
This is a composable activation checkpointing API. Unlike functional
activation checkpointing APIs, this one does not require changing
model source code. Unlike ``nn.Module`` wrapper activation checkpointing
APIs, this one does not modify model structure or fully-qualified names
either. Under the hood, it registers activation checkpointing logic as pre-
and post-forward hooks.
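A minimal usage sketch (the ``torch.distributed._composable`` import path is an assumption for the example):
```python
import torch
from torch import nn
from torch.distributed._composable import checkpoint  # assumed import path

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
checkpoint(model[0])  # applied in place: no wrapper, structure unchanged

# Fully-qualified parameter names are untouched.
assert "0.weight" in dict(model.named_parameters())

out = model(torch.randn(2, 8))
out.sum().backward()  # model[0]'s forward is recomputed during backward
```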
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87664
Approved by: https://github.com/zhaojuanmao