Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
This PR makes a second pass over the constructor. The logic has been grouped into `_init_<...>` functions based on intent (e.g. `_init_prefetching_state()` or `_init_runtime_state()`). This makes the initialization code for composable FSDP much cleaner than having to re-write the same sequences of lower-level helper calls. This PR also moves `_ExecOrderData` into its own file `_exec_order_utils.py`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/87923 Approved by: https://github.com/mrshenli
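As a rough illustration of the intent-based grouping the PR description refers to, here is a minimal, hypothetical sketch. Only `_init_prefetching_state()` and `_init_runtime_state()` are names taken from the PR text; everything else (the class name, helper bodies, attributes) is invented for illustration and is not the actual FSDP constructor.

import torch.nn as nn

# Hypothetical sketch of grouping constructor logic into `_init_<...>` helpers.
# The real FullyShardedDataParallel constructor differs; this only shows the shape
# of the refactor: each helper owns one cohesive slice of initialization state,
# so composable FSDP can reuse the helpers instead of repeating lower-level calls.
class _SketchFSDP(nn.Module):
    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self._init_prefetching_state()  # named in the PR description
        self._init_runtime_state()      # named in the PR description
        self.module = module

    def _init_prefetching_state(self) -> None:
        # Placeholder: the real helper would set up forward/backward prefetch bookkeeping.
        self._prefetch_state: dict = {}

    def _init_runtime_state(self) -> None:
        # Placeholder: the real helper would set up streams, handles, and exec-order data.
        self._runtime_state: dict = {}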
56 lines · 2.2 KiB · Python
import functools
import warnings
from typing import Any, Dict

from torch.distributed.fsdp._utils import (
    _contains_batchnorm,
    _override_batchnorm_mixed_precision,
)
from torch.distributed.fsdp.wrap import (
    _or_policy,
    _recursive_wrap,
    _wrap_batchnorm_individually,
)


def _auto_wrap(
    auto_wrap_kwargs: Dict[str, Any],
    fsdp_kwargs: Dict[str, Any],
    module_wrapper_cls: Any,  # e.g. `FullyShardedDataParallel`
) -> None:
    """
    Recursively auto wraps the root module given by the key "module" in
    ``auto_wrap_kwargs`` with the arguments in ``auto_wrap_kwargs`` and
    ``fsdp_kwargs``.

    Precondition: ``auto_wrap_policy`` contains the arguments expected by
    ``_recursive_wrap()``, where ``auto_wrap_policy`` is not ``None``.
    ``fsdp_kwargs`` contains all FSDP arguments except ``module``.
    """
    auto_wrap_policy = auto_wrap_kwargs["auto_wrap_policy"]
    root_module = auto_wrap_kwargs["module"]
    assert auto_wrap_policy is not None
    # For auto wrapping, submodules should not already be wrapped with FSDP
    # since double wrapping is not supported
    for module_name, module in root_module.named_modules():
        if isinstance(module, module_wrapper_cls):
            raise ValueError(
                f"Expected {module_name} to NOT be FullyShardedDataParallel "
                "if using an `auto_wrap_policy`"
            )
    mixed_precision = fsdp_kwargs["mixed_precision"]
    if mixed_precision is not None and _contains_batchnorm(root_module):
        _override_batchnorm_mixed_precision(root_module)
        auto_wrap_policy = functools.partial(
            _or_policy, policies=[_wrap_batchnorm_individually, auto_wrap_policy]
        )
        warnings.warn(
            "Both mixed precision and an `auto_wrap_policy` were specified "
            "for FSDP, where the wrapped module has batch norm submodules. "
            "The batch norm submodules will be wrapped as separate FSDP "
            "instances with mixed precision disabled since some batch norm "
            "kernels do not support low precision."
        )
    auto_wrap_kwargs["auto_wrap_policy"] = auto_wrap_policy
    _recursive_wrap(**auto_wrap_kwargs, **fsdp_kwargs)
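For context, `_auto_wrap` is internal and is reached through the public `FullyShardedDataParallel` constructor. Below is a hedged usage sketch of a call that should exercise the batch-norm branch above; it assumes an already-initialized process group (e.g. launched via `torchrun`) and an available CUDA device, and the choice of policy and thresholds is illustrative, not prescribed by this file.

import functools

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Assumes torch.distributed is already initialized (e.g. via torchrun).
model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8), nn.Linear(8, 8)).cuda()

# Passing both an `auto_wrap_policy` and `mixed_precision` for a model that
# contains batch norm should take the branch above: batch norm submodules are
# wrapped as separate FSDP instances with mixed precision disabled, and the
# warning from `_auto_wrap` is emitted.
fsdp_model = FSDP(
    model,
    auto_wrap_policy=functools.partial(size_based_auto_wrap_policy, min_num_params=1),
    mixed_precision=MixedPrecision(param_dtype=torch.float16),
)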