pytorch/torch/distributed/_composable/fully_shard.py
wz337 791f8ef350 [Composable APIs] Add composable API fully_shard deprecation warning (#120929)
[`fully_shard`](https://github.com/pytorch/pytorch/blob/main/torch/distributed/_composable/fsdp/fully_shard.py) will be used by the new FSDP2, so we want to add a deprecation warning to the existing composable API's [`fully_shard`](https://github.com/pytorch/pytorch/blob/main/torch/distributed/_composable/fully_shard.py#L40).

The planned release schedule is as follows (https://dev-discuss.pytorch.org/t/release-cadence-for-year-2023-2024/1557):

Minor Version | Release branch cut | Release date | First patch release date | Second patch release date
-- | -- | -- | -- | --
2.3 | Mar 2024 | Apr 2024 | May 2024 | Jun 2024
2.4 | May 2024 | Jul 2024 | Aug 2024 | Sep 2024
2.5 | Aug 2024 | Oct 2024 | Nov 2024 | Dec 2024
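
For reference, a minimal before/after sketch of what callers see once the warning lands (not part of this PR). It assumes a single-GPU rank launched with `torchrun --nproc_per_node=1`; the toy `nn.Sequential` model is invented for illustration.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable import fully_shard  # emits the new deprecation warning
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP  # wrapper-based FSDP


def main() -> None:
    # torchrun sets the rendezvous env vars; nccl assumes a CUDA device per rank.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # Composable API: still functional through PyTorch 2.5, but now warns.
    composable_model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).cuda()  # toy model
    fully_shard(composable_model)

    # Wrapper-based FSDP: the supported alternative until FSDP2 replaces it.
    wrapped_model = FSDP(nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).cuda())

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```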

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120929
Approved by: https://github.com/awgu
2024-03-01 00:55:16 +00:00


import warnings
from typing import Callable, Iterable, Optional, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.contract import contract
from torch.distributed._composable_state import _get_module_state, _insert_module_state
from torch.distributed.fsdp._common_utils import _FSDPState
from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo
from torch.distributed.fsdp._init_utils import (
_init_buffer_state,
_init_core_state,
_init_device_handle,
_init_ignored_module_states,
_init_param_handle_from_module,
_init_prefetching_state,
_init_process_group_state,
_init_runtime_state,
_init_state_dict_state,
HYBRID_SHARDING_STRATEGIES,
)
from torch.distributed.fsdp._runtime_utils import (
_register_post_forward_hook,
_register_pre_forward_hook,
_register_root_pre_forward_hook,
)
from torch.distributed.fsdp._state_dict_utils import _register_all_state_dict_hooks
from torch.distributed.fsdp._wrap_utils import _auto_wrap
from torch.distributed.fsdp.api import (
BackwardPrefetch,
CPUOffload,
MixedPrecision,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import _Policy


@contract(state_cls=_FSDPState)
def fully_shard(
module: nn.Module,
*,
process_group: Optional[dist.ProcessGroup] = None,
policy: Optional[_Policy] = None,
strategy: Optional[ShardingStrategy] = None,
mixed_precision: Optional[MixedPrecision] = None,
cpu_offload: Optional[CPUOffload] = None,
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
device_id: Optional[Union[int, torch.device]] = None,
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
sync_module_states: bool = False,
forward_prefetch: bool = False,
ignored_states: Union[
Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
] = None,
) -> nn.Module:
"""
Applies ``FullyShardedDataParallel` (FSDP) semantics to ``module``.
"""
    warnings.warn(
        "``torch.distributed._composable.fully_shard`` is being deprecated. "
        "You can continue to use the wrapper based FSDP. "
        "See usage in: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/fully_sharded_data_parallel.py. "
        "``torch.distributed._composable.fully_shard`` will be removed after PyTorch 2.5."
    )
torch._C._log_api_usage_once("torch.distributed.fully_shard")
# Enforce the new auto wrap policy
if policy is not None and not isinstance(policy, _Policy):
raise ValueError(f"Expects a `_Policy` but got {policy}")
state = fully_shard.state(module)
state = _init_ignored_module_states(state, module, ignored_modules, ignored_states)
state = _init_device_handle(state, module, state._ignored_params, device_id)
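    # Annotate the non-ignored modules so that TorchDynamo knows they are
    # FSDP-managed with ``use_orig_params=True``.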
_annotate_modules_for_dynamo(module, state._ignored_modules, True)
state = _init_process_group_state(state, process_group, strategy, policy)
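    # With an auto wrap policy, recursively apply ``fully_shard`` to the
    # submodules selected by the policy, forwarding the same keyword arguments.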
if policy is not None:
root_kwargs = {
"process_group": process_group,
"strategy": strategy,
"mixed_precision": mixed_precision,
"cpu_offload": cpu_offload,
"ignored_modules": ignored_modules,
"device_id": device_id,
"param_init_fn": param_init_fn,
"sync_module_states": sync_module_states,
"forward_prefetch": forward_prefetch,
"ignored_states": ignored_states,
}
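        # Hybrid sharding shards within a node and replicates across nodes, so
        # it needs both the intra-node and inter-node process groups.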
if strategy in HYBRID_SHARDING_STRATEGIES:
root_kwargs["process_group"] = (state.process_group, state._inter_node_pg)
_auto_wrap(
module,
policy,
state._ignored_modules,
state._ignored_params,
root_kwargs,
fully_shard,
)
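    # Core FSDP state: default to ``FULL_SHARD`` and always use the original
    # parameters (``use_orig_params=True``) for the composable API.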
state = _init_core_state(
state,
strategy or ShardingStrategy.FULL_SHARD,
mixed_precision,
cpu_offload,
limit_all_gathers=True,
use_orig_params=True,
backward_prefetch_limit=1,
forward_prefetch_limit=1,
)
state = _init_runtime_state(state)
state = _init_prefetching_state(
state, BackwardPrefetch.BACKWARD_PRE, forward_prefetch=forward_prefetch
)
state = _init_buffer_state(state, module)
state = _init_param_handle_from_module(
state, module, device_id, param_init_fn, sync_module_states
)
state = _init_state_dict_state(state)
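    # Register the state dict hooks plus the pre/post-forward hooks on the module.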
_register_all_state_dict_hooks(state)
_register_pre_forward_hook(state, module)
_register_post_forward_hook(state, module)
_register_root_pre_forward_hook(state, module) # prepend last
    # Always insert the state for the passed-in module even if it has no
    # managed parameters, in which case it has no handles and does not appear
    # in `_fully_sharded_module_to_handle`
_insert_module_state(module, state)
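    # Also share this state with each submodule managed by this call that does
    # not already have its own state (e.g. from a nested ``fully_shard``).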
for submodule in module.modules():
if (
submodule in state._fully_sharded_module_to_handle
and _get_module_state(submodule) is None
):
_insert_module_state(submodule, state)
return module