pytorch/torch/distributed/_composable/contract.py
Chien-Chin Huang d52f121dba [Composable API]Common _State parent class for composable and wrapper FSDP (#89147)
**Why this PR?**

For the composable APIs implementation, the internal APIs sometimes do not have access to the application (FSDP, DDP) root module, only to a local module. One example is the state_dict/optimizer_state_dict implementation of FSDP. These APIs are designed to start from the root module of the model, so it is tricky for them to tell whether an arbitrary submodule is managed by DDP or FSDP.

It will be useful to have APIs like:
`_get_module_state(module)`: return the composable state if this module is managed by composable API.
`_get_module_fsdp_state(module)`: return the FSDP state if this module is managed by FSDP.

**What does this PR propose?**
1. Move `_State` out of the `_composable` module so that `FullyShardedDataParallel` can inherit from it.
2. A global `_module_state_mapping: Dict[nn.Module, _State]` that keeps the mapping of all submodules (not just root module) to the state.
3. Create `_get_module_state(module)` to look up `_module_state_mapping`.
4. Create `_get_module_fsdp_state(module)`, which uses `_get_module_state(module)` to get the state and then checks whether that state is an `_FSDPState`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89147
Approved by: https://github.com/awgu
2022-12-13 23:58:01 +00:00

185 lines
7.0 KiB
Python

from collections import OrderedDict
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Type

import torch.nn as nn
from torch.distributed._composable_state import _State
# use state_slot as key for module.__dict__ to avoid coliding with other
# properties.
# TODO: since all composable distributed features can share the same slot.
class _StateKey:
# implement operator < to avoid breaking dir()
def __lt__(self, other: Any) -> bool:
return True if isinstance(other, str) else id(self) < id(other)
STATE_KEY = _StateKey()
REGISTRY_KEY = _StateKey()
# TODO: we can add additional info to RegistryItem to share across APIs. E.g.,
# we can add args and kwargs here, and then we can detect whether fully_shard
# is combined with reentrant activation checkpointing and error out with a clear
# message.
class RegistryItem:
    r"""
    Entry recorded in a module's composable-API registry, one per applied API
    name. It currently carries no data.
    """
def _check_fqn(orig_fqns: List[str], new_fqns: List[str]) -> None:
    r"""
    Raise ``RuntimeError`` if ``new_fqns`` is not identical to ``orig_fqns``.

    Used by :func:`contract` to verify that a composable API did not add,
    remove, rename, or reorder fully-qualified names (FQNs).

    Args:
        orig_fqns: FQNs collected before the composable API ran.
        new_fqns: FQNs collected after the composable API ran.

    Raises:
        RuntimeError: if some names appear on only one side (the message lists
            those names), or if both sides hold the same set of names but the
            sequences still differ, i.e. in order or in how often a name
            repeats (the message lists both full sequences).
    """
    if orig_fqns == new_fqns:
        return
    orig_fqn_set, new_fqn_set = set(orig_fqns), set(new_fqns)
    orig_only = orig_fqn_set - new_fqn_set
    new_only = new_fqn_set - orig_fqn_set
    if orig_only or new_only:
        raise RuntimeError(
            "Composable distributed API implementations cannot modify "
            "FQNs.\n"
            f"Only in original FQNs: {orig_only},\n"
            f"Only in new FQNs: {new_only}"
        )
    # ``orig_only`` and ``new_only`` are both empty on this path, so report the
    # full sequences; printing the set differences would only show ``set()``.
    raise RuntimeError(
        "Composable distributed API implementations cannot modify "
        "the order of FQNs.\n"
        f"Original FQNs: {orig_fqns}\n"
        f"New FQNs: {new_fqns}"
    )


def contract(state_cls: Type[_State] = _State):
    r"""
    Decorate a function as a composable distributed API, where the first
    argument of the function must be an :class:`nn.Module` instance. The
    decorator verifies that the wrapped function does not modify parameter,
    buffer or sub-module fully-qualified names (FQN).

    When a function ``func`` is decorated by ``@contract()``, a
    ``.state(module: nn.Module)`` method will be installed to the decorated
    function. Then you can retrieve and modify the state on a module by calling
    ``func.state(module)``.

    Args:
        state_cls: class instantiated with no arguments to hold the state of
            the decorated API on each module it is applied to. Defaults to
            :class:`_State`.

    Returns:
        A decorator. Applied to ``func``, it returns a wrapper that:

        * asserts ``func`` has not already been applied to the module;
        * records a fresh ``state_cls()`` and a registry entry on the module;
        * calls ``func(module, *args, **kwargs)``;
        * treats a ``None`` result as ``module``;
        * checks that the result is an ``nn.Module`` whose parameter, buffer,
          and submodule FQNs are unchanged.

    Example::
        >>> import torch
        >>> import torch.nn as nn
        >>>
        >>> class MyModel(nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.l1 = nn.Linear(10, 10)
        >>>         self.l2 = nn.Linear(10, 10)
        >>>
        >>>     def forward(self, x):
        >>>         return self.l2(self.l1(x))
        >>>
        >>> @contract()
        >>> def my_feature(module: nn.Module) -> nn.Module:
        >>>     my_feature.state(module).some_state = "any value"
        >>>     return module
        >>>
        >>> model = MyModel()
        >>> my_feature(model.l1)
        >>> assert my_feature.state(model.l1).some_state == "any value"
        >>> my_feature(model.l2)
        >>> model(torch.randn(2, 10)).sum().backward()
    """

    def inner(func):
        @wraps(func)
        def wrapper(module: nn.Module, *args, **kwargs) -> Optional[nn.Module]:
            # get existing global states
            default_all_state: Dict[Callable, _State] = OrderedDict()
            all_state: Dict[Callable, _State] = module.__dict__.setdefault(  # type: ignore[call-overload]
                STATE_KEY, default_all_state
            )
            assert isinstance(
                all_state, dict
            ), "Distributed composable API states corrupted"

            # get global registry
            default_registry: Dict[str, RegistryItem] = OrderedDict()
            registry: Dict[str, RegistryItem] = module.__dict__.setdefault(  # type: ignore[call-overload]
                REGISTRY_KEY, default_registry
            )
            assert isinstance(
                registry, dict
            ), "Distributed composable API registry corrupted"

            # make sure the API func has not been applied to the input module yet.
            assert func not in all_state and func.__name__ not in registry, (
                "Each distinct composable distributed API can only be applied to a "
                f"module once. {func.__name__} has already been applied to the "
                f"following module.\n{module}"
            )

            # install states specific to the wrapped ``func``
            all_state.setdefault(func, state_cls())
            # register ``func`` in the global registry by name
            registry.setdefault(func.__name__, RegistryItem())

            orig_named_params = OrderedDict(module.named_parameters())
            orig_named_buffers = OrderedDict(
                module.named_buffers(remove_duplicate=False)
            )
            orig_named_modules = OrderedDict(
                module.named_modules(remove_duplicate=False)
            )

            updated = func(module, *args, **kwargs)
            if updated is None:
                updated = module

            # Validate the return type before calling nn.Module methods on it,
            # so a bad return value produces this message rather than an
            # AttributeError from ``named_parameters``.
            assert isinstance(updated, nn.Module), (
                "Output of composable distributed APIs must be either None or "
                f"nn.Module, but got {type(updated)}"
            )

            new_named_params = OrderedDict(updated.named_parameters())
            new_named_buffers = OrderedDict(
                updated.named_buffers(remove_duplicate=False)
            )
            new_named_modules = OrderedDict(
                updated.named_modules(remove_duplicate=False)
            )

            _check_fqn(list(orig_named_params.keys()), list(new_named_params.keys()))
            _check_fqn(
                list(orig_named_buffers.keys()), list(new_named_buffers.keys())
            )
            _check_fqn(
                list(orig_named_modules.keys()), list(new_named_modules.keys())
            )

            # TODO: a stricter verification should also reject changing module
            # types and monkey-patching forward() method implementations.

            # TODO: verify that installed distributed paradigms are compatible with
            # each other.

            return updated

        def get_state(module: nn.Module) -> Optional[_State]:
            # Read-only lookup: return the state ``func`` installed on
            # ``module``, or None. Unlike ``setdefault``, this does not attach
            # an empty state dict to modules the API was never applied to.
            return module.__dict__.get(STATE_KEY, {}).get(func)  # type: ignore[call-overload]

        wrapper.state = get_state  # type: ignore[attr-defined]

        return wrapper

    return inner
def _get_registry(module: nn.Module) -> Dict[str, RegistryItem]:
    r"""
    Return the registry of composable APIs applied to ``module``.

    The registry is an ``OrderedDict`` keyed by API name. If ``module`` has no
    registry yet, an empty ``OrderedDict`` is created, stored in
    ``module.__dict__``, and returned, so later calls see the same object.
    """
    module_dict = module.__dict__
    if REGISTRY_KEY not in module_dict:
        module_dict[REGISTRY_KEY] = OrderedDict()  # type: ignore[index]
    return module_dict[REGISTRY_KEY]  # type: ignore[index]