mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
After internal discussion, we are currently preferring `fully_shard()` as the name of the composable FSDP API.
- `FullyShardedDataParallel` (FSDP) has existing brand value, so the chosen name should try to preserve that. We think this takes precedence over the fact that composable FSDP may encompass more than just the ZeRO-3 approach of _fully sharding_.
- Given the refactoring efforts, it would also not be challenging to create a new frontend API like `hybrid_shard()` that calls into the same underlying initialization and runtime except for a different `ShardingStrategy`. In other words, we do not have to coalesce all sharding strategies under `fully_shard()`.
- The other composable APIs are verbs (`replicate()`, `checkpoint()`), so the chosen name should be a verb.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88233
Approved by: https://github.com/mrshenli
77 lines
2.4 KiB
Python
77 lines
2.4 KiB
Python
from typing import Callable, cast, Iterable, Optional, Union
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.nn as nn
|
|
from torch.distributed.fsdp._common_utils import _State, FSDPState
|
|
from torch.distributed.fsdp._init_utils import (
|
|
_init_buffer_state,
|
|
_init_core_state,
|
|
_init_ignored_module_states,
|
|
_init_param_handles_from_module,
|
|
_init_prefetching_state,
|
|
_init_process_group_state,
|
|
_init_runtime_state,
|
|
_init_state_dict_state,
|
|
)
|
|
from torch.distributed.fsdp._runtime_utils import (
|
|
_register_post_forward_hooks,
|
|
_register_pre_forward_hooks,
|
|
)
|
|
from torch.distributed.fsdp.api import (
|
|
BackwardPrefetch,
|
|
CPUOffload,
|
|
MixedPrecision,
|
|
ShardingStrategy,
|
|
)
|
|
|
|
|
|
def fully_shard(
    module: nn.Module,
    process_group: Optional[dist.ProcessGroup] = None,
    sharding_strategy: Optional[ShardingStrategy] = None,
    mixed_precision: Optional[MixedPrecision] = None,
    cpu_offload: Optional[CPUOffload] = None,
    auto_wrap_policy: Optional[Callable] = None,
    ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
    device_id: Optional[Union[int, torch.device]] = None,
    param_init_fn: Optional[Callable[[nn.Module], None]] = None,
    sync_module_states: bool = False,
) -> FSDPState:
    """
    Applies ``FullyShardedDataParallel`` (FSDP) semantics to ``module``.

    Builds an ``FSDPState`` by running the FSDP initialization helpers in
    sequence. It then registers pre- and post-forward hooks on ``module``
    and on every module yielded by ``module.modules()``.

    Args:
        module: The module to which FSDP semantics are applied.
        process_group: Forwarded to ``_init_process_group_state``.
        sharding_strategy: Forwarded to ``_init_core_state``.
        mixed_precision: Forwarded to ``_init_core_state``.
        cpu_offload: Forwarded to ``_init_core_state``.
        auto_wrap_policy: Forwarded to ``_init_param_handles_from_module``.
        ignored_modules: Forwarded to ``_init_ignored_module_states``.
        device_id: Forwarded to ``_init_param_handles_from_module``.
        param_init_fn: Forwarded to ``_init_param_handles_from_module``.
        sync_module_states: Forwarded to
            ``_init_param_handles_from_module``.

    Some settings are not exposed as arguments and are fixed here:

    - ``limit_all_gathers=True`` and ``use_orig_params=True``.
    - A backward and a forward prefetch limit of 1.
    - ``BackwardPrefetch.BACKWARD_PRE`` as the backward-prefetch policy.
    - ``False`` as the last ``_init_prefetching_state`` argument
      (presumably the forward-prefetch flag; confirm against that helper).

    Returns:
        FSDPState: The fully initialized state object.
    """
    # Settings the composable API does not expose (yet); they are passed
    # positionally to ``_init_core_state`` below.
    limit_all_gathers, use_orig_params = True, True
    backward_prefetch_limit = forward_prefetch_limit = 1

    fsdp_state = cast(_State, FSDPState())
    # Each helper receives the state built so far and returns the updated
    # state. Keep the call order: later helpers may rely on fields that
    # earlier ones set.
    fsdp_state = _init_ignored_module_states(fsdp_state, module, ignored_modules)
    fsdp_state = _init_process_group_state(fsdp_state, process_group)
    fsdp_state = _init_core_state(
        fsdp_state,
        sharding_strategy,
        mixed_precision,
        cpu_offload,
        limit_all_gathers,
        use_orig_params,
        backward_prefetch_limit,
        forward_prefetch_limit,
    )
    fsdp_state = _init_runtime_state(fsdp_state)
    fsdp_state = _init_prefetching_state(
        fsdp_state, BackwardPrefetch.BACKWARD_PRE, False
    )
    fsdp_state = _init_buffer_state(fsdp_state, module)
    fsdp_state = _init_param_handles_from_module(
        fsdp_state,
        module,
        auto_wrap_policy,
        device_id,
        param_init_fn,
        sync_module_states,
    )
    fsdp_state = _init_state_dict_state(fsdp_state)

    # Materialize the module tree once so that both hook registrations
    # receive the same list of modules.
    all_modules = [*module.modules()]
    for register_hooks in (
        _register_pre_forward_hooks,
        _register_post_forward_hooks,
    ):
        register_hooks(fsdp_state, all_modules)
    return cast(FSDPState, fsdp_state)
|