This PR supports nesting `replicate` in `fully_shard`.
- The PR achieves this by treating `replicate`-annotated modules as ignored modules. This means that all submodules in the `replicate`-annotated module's subtree are ignored, including nested `fully_shard`-annotated modules, which is the desired behavior. (See the usage sketch below.)
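For illustration, a minimal usage sketch, assuming the default process group is already initialized and using a placeholder toy model (neither the model nor the setup is part of this PR):
```python
import torch.nn as nn
from torch.distributed._composable import fully_shard, replicate

# Placeholder model; assumes torch.distributed is already initialized.
model = nn.Sequential(
    nn.Linear(8, 8),
    nn.Sequential(nn.Linear(8, 8), nn.ReLU()),
)

# Annotate the inner block with `replicate` first...
replicate(model[1])
# ...then apply `fully_shard` to the root. The `replicate`-annotated subtree
# is treated as ignored, so none of its submodules are sharded.
fully_shard(model)
```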
---
This PR reworks some tree traversal.
One end goal is for `state._handles` to follow the same order for both the wrapper and composable paths. This implies that `_get_fsdp_handles()` returns the same value for both paths.
- The helper function `_get_fully_sharded_module_to_states()` now follows a left-to-right DFS from each fully sharded module instead of a BFS. The left-to-right DFS follows `.modules()` order.
- The composable auto "wrap" initialization function `_init_param_handles_from_module()` follows the reverse left-to-right DFS order. As noted in the code comments, this initialization order is a valid reverse topological sort, but it differs from the wrapper path. This is the _only_ difference with respect to initialization order through the entire process.
```
mod: Module(
submod1: Submodule()
submod2: Submodule(
subsubmod: Subsubmodule(),
),
)
```
For left-to-right DFS, the order is `mod`, `submod1`, `submod2`, `subsubmod`. (For context, right-to-left DFS would be `mod`, `submod2`, `subsubmod`, `submod1`; in other words, left-to-right vs. right-to-left corresponds to iterating `.children()` vs. `reversed(.children())`, respectively.) Then, the reverse left-to-right DFS order is `subsubmod`, `submod2`, `submod1`, `mod`, which is a valid initialization order. However, the wrapper auto wrap initialization order would be `submod1`, `subsubmod`, `submod2`, `mod` since it directly follows a left-to-right DFS and initializes as part of the recursive DFS logic.
- At the end of `_init_param_handles_from_module()`, we reverse the newly populated `state._handles`, so the final order is the reverse of the reverse left-to-right DFS order, i.e. the left-to-right DFS order itself. Thus, `state._handles` has the same order for both paths.
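As a concrete check of these orderings, here is a minimal sketch using placeholder `Submodule`/`Subsubmodule` classes that mirror the example hierarchy above (none of this is code from the PR):
```python
import torch.nn as nn

class Submodule(nn.Module):
    pass

class Subsubmodule(nn.Module):
    pass

# Mirror the example hierarchy above.
mod = nn.Module()
mod.submod1 = Submodule()
mod.submod2 = Submodule()
mod.submod2.subsubmod = Subsubmodule()

# Left-to-right DFS (`.modules()` order): mod, submod1, submod2, subsubmod
dfs = list(mod.modules())
# Reverse left-to-right DFS: subsubmod, submod2, submod1, mod
# (a valid reverse topological sort, so children initialize before parents)
init_order = list(reversed(dfs))
# Reversing the populated list again recovers the left-to-right DFS order.
assert list(reversed(init_order)) == dfs
```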
Another goal is for `_get_fsdp_states()` to not traverse into any submodule that is annotated with an API that is not compatible with `fully_shard` (e.g. `replicate`). To achieve this while preserving that `_get_fsdp_states()` follows `.modules()` order, we again use a left-to-right DFS.
The reason the DFSs may look strange is that I implemented them non-recursively, which requires an explicit stack.
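For reference, a non-recursive left-to-right DFS looks roughly like the sketch below. This is illustrative only, not the actual `_traversal_utils.py` code; in particular, `_is_composable_with_fsdp` is a hypothetical stand-in for the real registry-based compatibility check that skips `replicate`-annotated subtrees.
```python
from typing import List

import torch.nn as nn


def _is_composable_with_fsdp(module: nn.Module) -> bool:
    # Hypothetical stand-in: the real check consults the composable API
    # registry to see whether `module` is annotated with an API that is
    # incompatible with `fully_shard` (e.g. `replicate`).
    return True


def left_to_right_dfs(root: nn.Module) -> List[nn.Module]:
    """Returns modules in left-to-right DFS preorder (`.modules()` order),
    skipping subtrees rooted at incompatible-API-annotated modules."""
    order: List[nn.Module] = []
    stack = [root]
    while stack:
        module = stack.pop()
        if not _is_composable_with_fsdp(module):
            continue  # do not traverse into this subtree
        order.append(module)
        # Push children reversed so the leftmost child is popped first,
        # preserving left-to-right (`.modules()`) order.
        stack.extend(reversed(list(module.children())))
    return order
```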
- `test_get_fully_sharded_module_to_states()` in `test_utils.py` checks the traversal order of `_get_fully_sharded_module_to_states()`.
- `test_policy()` in `test_fully_shard.py` checks the traversal order returned by `_get_fsdp_handles()`.
---
Due to a circular dependency issue, we must move the graph/tree traversal helpers to their own file `_traversal_utils.py`, and any usages must import the entire file like `import torch.distributed.fsdp._traversal_utils as traversal_utils` instead of `from torch.distributed.fsdp._traversal_utils import ...`.
The cycle comes from the fact that the traversals require `_composable()`, which requires `_get_registry()` from `composable/contract.py`, which, when imported, imports `composable/fully_shard.py`, which requires the traversals.
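The workaround relies on the fact that `import pkg.mod as alias` binds only the module object, so attribute lookups happen at call time, after every module in the cycle has finished importing. A sketch of the pattern (the caller function here is hypothetical):
```python
# Bind only the module object; do not pull names out of it at import time.
import torch.distributed.fsdp._traversal_utils as traversal_utils


def _example_caller(module):
    # `_get_fsdp_states` is resolved lazily here, at call time, so a partially
    # initialized module during the import cycle no longer matters.
    return traversal_utils._get_fsdp_states(module)
```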
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91044
Approved by: https://github.com/mrshenli
```python
from typing import Callable, Iterable, Optional, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.contract import contract
from torch.distributed._composable_state import _get_module_state, _insert_module_state
from torch.distributed.fsdp._common_utils import _FSDPState
from torch.distributed.fsdp._init_utils import (
    _init_buffer_state,
    _init_core_state,
    _init_ignored_module_states,
    _init_param_handles_from_module,
    _init_prefetching_state,
    _init_process_group_state,
    _init_runtime_state,
    _init_state_dict_state,
)
from torch.distributed.fsdp._runtime_utils import (
    _register_post_forward_hooks,
    _register_pre_forward_hooks,
    _register_root_pre_forward_hook,
)
from torch.distributed.fsdp._state_dict_utils import _register_all_state_dict_hooks
from torch.distributed.fsdp.api import (
    BackwardPrefetch,
    CPUOffload,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import _FSDPPolicy


@contract(state_cls=_FSDPState)
def fully_shard(
    module: nn.Module,
    *,
    process_group: Optional[dist.ProcessGroup] = None,
    policy: Optional[_FSDPPolicy] = None,
    strategy: Optional[ShardingStrategy] = None,
    mixed_precision: Optional[MixedPrecision] = None,
    cpu_offload: Optional[CPUOffload] = None,
    ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
    device_id: Optional[Union[int, torch.device]] = None,
    param_init_fn: Optional[Callable[[nn.Module], None]] = None,
    sync_module_states: bool = False,
) -> nn.Module:
    """
    Applies ``FullyShardedDataParallel`` (FSDP) semantics to ``module``.
    """
    # Enforce the new auto wrap policy
    if policy is not None and not isinstance(policy, _FSDPPolicy):
        raise ValueError(f"Expects an `_FSDPPolicy` but got {policy}")
    state = fully_shard.state(module)
    state = _init_ignored_module_states(state, module, ignored_modules)
    state = _init_process_group_state(
        state, process_group, ShardingStrategy.FULL_SHARD, policy
    )
    # Composable-path defaults that are not exposed as arguments here
    limit_all_gathers = True
    use_orig_params = True
    backward_prefetch_limit = 1
    forward_prefetch_limit = 1
    state = _init_core_state(
        state,
        strategy or ShardingStrategy.FULL_SHARD,
        mixed_precision,
        cpu_offload,
        limit_all_gathers,
        use_orig_params,
        backward_prefetch_limit,
        forward_prefetch_limit,
    )
    state = _init_runtime_state(state)
    state = _init_prefetching_state(state, BackwardPrefetch.BACKWARD_PRE, False)
    state = _init_buffer_state(state, module)
    state = _init_param_handles_from_module(
        state,
        module,
        policy,
        device_id,
        param_init_fn,
        sync_module_states,
    )
    state = _init_state_dict_state(state)
    _register_all_state_dict_hooks(state)
    modules = list(module.modules())
    _register_pre_forward_hooks(state, modules)
    _register_post_forward_hooks(state, modules)
    _register_root_pre_forward_hook(state, module)  # prepend last
    # Point every non-ignored submodule that does not already own a state to
    # this state so that it can be looked up from any managed submodule.
    for submodule in module.modules():
        if (
            submodule not in state._ignored_modules
            and _get_module_state(submodule) is None
        ):
            _insert_module_state(submodule, state)
    return module
```