This PR adds a new `CustomPolicy` that acts like the existing `lambda_auto_wrap_policy`, except that it (1) leverages the new auto wrapping infrastructure and (2) allows overriding FSDP kwargs for particular instances. (1) gives it access to the validation checks (such as those for frozen parameters), and (2) makes it as expressive as manual wrapping. This should allow us to effectively deprecate manual wrapping if desired.
The API is as follows:
```python
def lambda_fn(module: nn.Module) -> Union[bool, Dict[str, Any]]:
...
policy = CustomPolicy(lambda_fn)
```
The `lambda_fn` can return:
- `False` or `{}` to indicate no wrapping
- `True` to indicate wrapping while inheriting the root's FSDP kwargs
- A non-empty `dict` to indicate wrapping while overriding the specified FSDP kwargs and inheriting the rest from the root (see the sketch below)
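
For illustration, here is a minimal sketch of a `lambda_fn` that overrides kwargs for one module class; the choice of `nn.TransformerEncoderLayer` and the CPU offload override are hypothetical examples, not part of this PR:

```python
from typing import Any, Dict, Union

import torch.nn as nn
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp.wrap import CustomPolicy


def lambda_fn(module: nn.Module) -> Union[bool, Dict[str, Any]]:
    # Hypothetical rule: wrap each transformer encoder layer as its own FSDP
    # unit, overriding the root's CPU offload setting just for those layers.
    if isinstance(module, nn.TransformerEncoderLayer):
        return {"cpu_offload": CPUOffload(offload_params=True)}
    # All other modules: no wrapping (returning `{}` would be equivalent).
    return False


policy = CustomPolicy(lambda_fn)
```

The resulting `policy` is a `_Policy` instance, so it can be passed as the `policy` argument of the composable `fully_shard` API shown in the file below.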
---
After this PR, the follow-up work items for auto wrapping are:
1. Add shared parameter validation
2. (Longer-term / exploratory) Add a policy that provides reasonable auto wrapping with "minimal" user input
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104986
Approved by: https://github.com/ezyang
ghstack dependencies: #104427, #104967, #104999, #104969
126 lines · 4.7 KiB · Python
```python
from typing import Callable, Iterable, Optional, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.contract import contract
from torch.distributed._composable_state import _get_module_state, _insert_module_state
from torch.distributed.fsdp._common_utils import _FSDPState
from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo
from torch.distributed.fsdp._init_utils import (
    _init_buffer_state,
    _init_core_state,
    _init_device_handle,
    _init_ignored_module_states,
    _init_param_handle_from_module,
    _init_prefetching_state,
    _init_process_group_state,
    _init_runtime_state,
    _init_state_dict_state,
    HYBRID_SHARDING_STRATEGIES,
)
from torch.distributed.fsdp._runtime_utils import (
    _register_post_forward_hook,
    _register_pre_forward_hook,
    _register_root_pre_forward_hook,
)
from torch.distributed.fsdp._state_dict_utils import _register_all_state_dict_hooks
from torch.distributed.fsdp._wrap_utils import _auto_wrap
from torch.distributed.fsdp.api import (
    BackwardPrefetch,
    CPUOffload,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import _Policy


@contract(state_cls=_FSDPState)
def fully_shard(
    module: nn.Module,
    *,
    process_group: Optional[dist.ProcessGroup] = None,
    policy: Optional[_Policy] = None,
    strategy: Optional[ShardingStrategy] = None,
    mixed_precision: Optional[MixedPrecision] = None,
    cpu_offload: Optional[CPUOffload] = None,
    ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
    device_id: Optional[Union[int, torch.device]] = None,
    param_init_fn: Optional[Callable[[nn.Module], None]] = None,
    sync_module_states: bool = False,
    forward_prefetch: bool = False,
    ignored_states: Union[
        Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
    ] = None,
) -> nn.Module:
    """
    Applies ``FullyShardedDataParallel`` (FSDP) semantics to ``module``.
    """
    torch._C._log_api_usage_once("torch.distributed.fully_shard")
    # Enforce the new auto wrap policy
    if policy is not None and not isinstance(policy, _Policy):
        raise ValueError(f"Expects a `_Policy` but got {policy}")
    state = fully_shard.state(module)
    state = _init_ignored_module_states(state, module, ignored_modules, ignored_states)
    state = _init_device_handle(state, module, state._ignored_params, device_id)
    _annotate_modules_for_dynamo(module, state._ignored_modules, True)
    state = _init_process_group_state(state, process_group, strategy, policy)
    if policy is not None:
        root_kwargs = {
            "process_group": process_group,
            "strategy": strategy,
            "mixed_precision": mixed_precision,
            "cpu_offload": cpu_offload,
            "ignored_modules": ignored_modules,
            "device_id": device_id,
            "param_init_fn": param_init_fn,
            "sync_module_states": sync_module_states,
            "forward_prefetch": forward_prefetch,
            "ignored_states": ignored_states,
        }
        if strategy in HYBRID_SHARDING_STRATEGIES:
            root_kwargs["process_group"] = (state.process_group, state._inter_node_pg)
        _auto_wrap(
            module,
            policy,
            state._ignored_modules,
            state._ignored_params,
            root_kwargs,
            fully_shard,
        )
    state = _init_core_state(
        state,
        strategy or ShardingStrategy.FULL_SHARD,
        mixed_precision,
        cpu_offload,
        limit_all_gathers=True,
        use_orig_params=True,
        backward_prefetch_limit=1,
        forward_prefetch_limit=1,
    )
    state = _init_runtime_state(state)
    state = _init_prefetching_state(
        state, BackwardPrefetch.BACKWARD_PRE, forward_prefetch=forward_prefetch
    )
    state = _init_buffer_state(state, module)
    state = _init_param_handle_from_module(
        state, module, device_id, param_init_fn, sync_module_states
    )
    state = _init_state_dict_state(state)
    _register_all_state_dict_hooks(state)
    _register_pre_forward_hook(state, module)
    _register_post_forward_hook(state, module)
    _register_root_pre_forward_hook(state, module)  # prepend last
    # Always insert the state for the passed-in module even if it has no
    # managed parameters, in which case it has no handles and does not appear
    # in `_fully_sharded_module_to_handle`
    _insert_module_state(module, state)
    for submodule in module.modules():
        if (
            submodule in state._fully_sharded_module_to_handle
            and _get_module_state(submodule) is None
        ):
            _insert_module_state(submodule, state)
    return module
```
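
For context, a minimal usage sketch of this API under stated assumptions: the toy model, the wrapping rule, and the fact that a default process group has already been initialized (e.g. via `torchrun`) are illustrative, not part of this file:

```python
import torch.nn as nn
from torch.distributed._composable import fully_shard
from torch.distributed.fsdp.wrap import CustomPolicy

# Hypothetical toy model; any nn.Module works.
model = nn.Sequential(
    nn.TransformerEncoderLayer(d_model=64, nhead=4),
    nn.TransformerEncoderLayer(d_model=64, nhead=4),
    nn.Linear(64, 10),
)

# Wrap each encoder layer as its own FSDP unit, inheriting the root's kwargs.
policy = CustomPolicy(lambda m: isinstance(m, nn.TransformerEncoderLayer))

# Assumes torch.distributed is already initialized (e.g. launched via torchrun).
fully_shard(model, policy=policy)
```

`fully_shard` applies FSDP semantics in place and returns the same module, so `model` can be used directly in the training loop afterward.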