add ignored_states to FSDP/fully_shard (#102056)
Add 'ignored_states', which accepts either a list of ignored parameters or a list of nn modules, to the FSDP model wrapper and the fully_shard composable API. It is recommended to use 'ignored_states' over 'ignored_modules' moving forward.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102056
Approved by: https://github.com/awgu
This commit is contained in:
parent 023bc30b17
commit 956bd03808
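For orientation, here is a minimal usage sketch of the new argument. It is not part of the commit: the toy nn.Sequential model, the CUDA device, and the assumption of an already-initialized default process group are illustrative only; the FSDP and fully_shard entry points are the ones touched by this diff.

import torch.nn as nn
from torch.distributed._composable import fully_shard
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def wrap_with_ignored_states() -> nn.Module:
    # Assumes torch.distributed.init_process_group() has already been called
    # and a CUDA device is available; the model below is a stand-in.
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2)).cuda()

    # `ignored_states` accepts an iterable of modules ...
    wrapped = FSDP(model, ignored_states=[model[0]])

    # ... or an iterable of parameters (equivalent for this model):
    # wrapped = FSDP(model, ignored_states=list(model[0].parameters()))

    # The composable API takes the same keyword:
    # fully_shard(model, ignored_states=[model[0]])
    return wrapped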
@@ -5,7 +5,9 @@ import sys
import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed._composable import fully_shard
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
CUDAInitMode,
@@ -101,12 +103,29 @@ class TestFSDPIgnoredModules(FSDPTest):
"""Tests that ignored modules' parameters are not flattened for a
transformer model with shared parameters."""
self.run_subtests(
{"use_orig_params": [False, True], "ignore_modules": [True, False]},
{
"use_orig_params": [False, True],
"ignore_modules": [True, False],
"composable": [False],
},
self._test_ignored_modules_transformer,
)

@skip_if_lt_x_gpu(2)
def test_ignored_modules_transformer_composable(self):
"""Tests that ignored modules' parameters are not flattened for a
transformer model with shared parameters."""
self.run_subtests(
{
"use_orig_params": [True],
"ignore_modules": [True, False],
"composable": [True],
},
self._test_ignored_modules_transformer,
)

def _test_ignored_modules_transformer(
self, use_orig_params: bool, ignore_modules: bool
self, use_orig_params: bool, ignore_modules: bool, composable: bool
):
# Initialize an FSDP-wrapped transformer model that has FSDP ignore
# the `nn.Transformer` module's parameters
@@ -117,19 +136,35 @@ class TestFSDPIgnoredModules(FSDPTest):
deterministic=True,
)
if ignore_modules:
wrapped_model = FSDP(
wrapped_model = (
FSDP(
model,
self.process_group,
ignored_modules=[model.transformer],
use_orig_params=use_orig_params,
)
if not composable
else fully_shard(
model,
process_group=self.process_group,
ignored_modules=[model.transformer],
)
)
else:
wrapped_model = FSDP(
wrapped_model = (
FSDP(
model,
self.process_group,
ignored_parameters=list(model.transformer.parameters()),
ignored_states=list(model.transformer.parameters()),
use_orig_params=use_orig_params,
)
if not composable
else fully_shard(
model,
process_group=self.process_group,
ignored_states=list(model.transformer.parameters()),
)
)
# Check that the wrapped model's flattened parameter does not include
# the ignored transformer module's parameters
nonwrapped_model: nn.Module = TransformerWithSharedParams.init(
@@ -144,9 +179,13 @@ class TestFSDPIgnoredModules(FSDPTest):
)
nonignored_numel = total_numel - ignored_numel
with FSDP.summon_full_params(wrapped_model):
flat_param = wrapped_model.params[0]
flat_param = (
wrapped_model.params[0]
if not composable
else _get_module_fsdp_state(wrapped_model).params[0]
)
flat_param_numel = flat_param.numel()
if use_orig_params:
if composable or use_orig_params:
# Subtract the numel contributed from alignment padding
padding_numel = sum(
numel
@@ -166,26 +205,62 @@ class TestFSDPIgnoredModules(FSDPTest):
"""Tests that passing a module with nested FSDP modules does not
error and still ignores non-FSDP modules' parameters."""
self.run_subtests(
{"use_orig_params": [False, True], "ignore_modules": [True, False]},
{
"use_orig_params": [False, True],
"ignore_modules": [True, False],
"composable": [False],
},
self._test_ignored_modules_nested,
)

def _test_ignored_modules_nested(self, use_orig_params: bool, ignore_modules: bool):
@skip_if_lt_x_gpu(2)
def test_ignored_modules_nested_composable(self):
"""Tests that passing a module with nested FSDP modules does not
error and still ignores non-FSDP modules' parameters."""
self.run_subtests(
{
"use_orig_params": [True],
"ignore_modules": [True, False],
"composable": [True],
},
self._test_ignored_modules_nested,
)

def _test_ignored_modules_nested(
self, use_orig_params: bool, ignore_modules: bool, composable: bool
):
# Initialize an FSDP-wrapped nested model that first wraps the nested
# sequential's second linear layer (`layer1[1]`) and then wraps the
# overall model while ignoring the nested sequential (`layer1`)
model = Model().cuda()
model.layer1[1] = FSDP(model.layer1[1], use_orig_params=use_orig_params)
model.layer1[1] = (
FSDP(model.layer1[1], use_orig_params=use_orig_params)
if not composable
else fully_shard(model.layer1[1])
)
if ignore_modules:
wrapped_model = FSDP(
model, ignored_modules=[model.layer1], use_orig_params=use_orig_params
wrapped_model = (
FSDP(
model,
ignored_modules=[model.layer1],
use_orig_params=use_orig_params,
)
if not composable
else fully_shard(model, ignored_modules=[model.layer1])
)
else:
wrapped_model = FSDP(
wrapped_model = (
FSDP(
model,
ignored_parameters=list(model.layer1.parameters()),
ignored_states=[model.layer1],
use_orig_params=use_orig_params,
)
if not composable
else fully_shard(
model,
ignored_states=[model.layer1],
)
)
# Check that the wrapped model's flattened parameter does not include
# the ignored nested sequential's parameters
nonwrapped_model = Model()
@@ -193,9 +268,13 @@ class TestFSDPIgnoredModules(FSDPTest):
ignored_numel = sum(p.numel() for p in nonwrapped_model.layer1.parameters())
nonignored_numel = total_numel - ignored_numel
with FSDP.summon_full_params(wrapped_model):
flat_param = wrapped_model.params[0]
flat_param = (
wrapped_model.params[0]
if not composable
else _get_module_fsdp_state(wrapped_model).params[0]
)
flat_param_numel = flat_param.numel()
if use_orig_params:
if composable or use_orig_params:
# Subtract the numel contributed from alignment padding
padding_numel = sum(
numel
@@ -212,24 +291,29 @@ class TestFSDPIgnoredModules(FSDPTest):
self._train_model(wrapped_model, optim, 3)

@skip_if_lt_x_gpu(2)
def test_ignored_modules_invalid(self):
@parametrize("composable", [True, False])
def test_ignored_modules_invalid(self, composable):
"""Tests that passing an FSDP module as an ignored module or the
top-level module itself errors."""
model = Model().cuda()
model.layer1 = FSDP(model.layer1)
wrap_cls = FSDP if composable else fully_shard
model.layer1 = wrap_cls(model.layer1)
# Passing an FSDP module as an ignored module should error
with self.assertRaises(
ValueError,
msg="`ignored_modules` should not include FSDP modules",
):
FSDP(model, ignored_modules=[model.layer1])
wrap_cls(model, ignored_modules=[model.layer1])
with self.assertWarnsRegex(
expected_warning=UserWarning,
expected_regex="Trying to ignore the top-level module passed into "
"the FSDP constructor itself will result in all parameters being "
"ignored",
):
FSDP(model, ignored_modules=[model])
# `fully_shard` does not allow to wrap the same model twice, so create
# a new local model here.
new_model = Model().cuda()
wrap_cls(new_model, ignored_modules=[new_model])

@skip_if_lt_x_gpu(2)
def test_diff_ignored_modules_across_ranks(self):
@@ -247,16 +331,21 @@ class TestFSDPIgnoredModules(FSDPTest):
{
"pass_ignored_modules_to_root": [False, True],
"ignore_modules": [True, False],
"composable": [True, False],
},
self._test_diff_ignored_modules_across_ranks,
)

def _test_diff_ignored_modules_across_ranks(
self, pass_ignored_modules_to_root: bool, ignore_modules: bool
self,
pass_ignored_modules_to_root: bool,
ignore_modules: bool,
composable: bool,
):
# To exercise different `FlatParameter` enumerations across ranks,
# we wrap `layer3` with FSDP, where `layer3` is registered as a module
# after `layer1`, which has the variable number of ignored modules
wrap_cls = FSDP if composable else fully_shard
model = ModelWithIgnoredModules(num_ignored=self.rank + 1).cuda()
layer1_ignored_modules = [
m for m in model.layer1.modules() if isinstance(m, IgnoredModule)
@@ -265,13 +354,13 @@ class TestFSDPIgnoredModules(FSDPTest):
{"ignored_modules": layer1_ignored_modules}
if ignore_modules
else {
"ignored_parameters": {
"ignored_states": {
p for m in layer1_ignored_modules for p in m.parameters()
}
}
)
model.layer1 = FSDP(model.layer1, **ignore_kwargs)
model.layer3 = FSDP(model.layer3)
model.layer1 = wrap_cls(model.layer1, **ignore_kwargs)
model.layer3 = wrap_cls(model.layer3)
model_ignored_modules = (
[m for m in model.modules() if isinstance(m, IgnoredModule)]
if pass_ignored_modules_to_root
@@ -281,18 +370,21 @@ class TestFSDPIgnoredModules(FSDPTest):
{"ignored_modules": model_ignored_modules}
if ignore_modules
else {
"ignored_parameters": {
"ignored_states": {
p for m in model_ignored_modules for p in m.parameters()
}
}
)
wrapped_model = FSDP(model, **ignore_kwargs_top)
wrapped_model = wrap_cls(model, **ignore_kwargs_top)
optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
self._train_model(wrapped_model, optim, 3)

@skip_if_lt_x_gpu(2)
@parametrize("ignore_modules", [True, False])
def test_ignored_modules_not_under_wrapped_root(self, ignore_modules: bool):
@parametrize("composable", [True, False])
def test_ignored_modules_not_under_wrapped_root(
self, ignore_modules: bool, composable: bool
):
model = Model().cuda()
ignored_modules = list(model.layer1.children())[1:]
@@ -300,19 +392,17 @@ class TestFSDPIgnoredModules(FSDPTest):
{"ignored_modules": ignored_modules}
if ignore_modules
else {
"ignored_parameters": {
p for m in ignored_modules for p in m.parameters()
}
"ignored_states": {p for m in ignored_modules for p in m.parameters()}
}
)

model.layer1 = FSDP(
wrap_cls = FSDP if composable else fully_shard

model.layer1 = wrap_cls(
model.layer1,
# sharding_strategy shouldn't matter here.
sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
**ignore_kwargs,
)
model.layer3 = FSDP(
model.layer3 = wrap_cls(
model.layer3,
# the ignored modules/parameters contains submodule under model.layer1, which
# is out of the local root model.layer3.
@@ -47,6 +47,9 @@ def fully_shard(
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
sync_module_states: bool = False,
forward_prefetch: bool = False,
ignored_states: Union[
Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
] = None,
) -> nn.Module:
"""
Applies ``FullyShardedDataParallel` (FSDP) semantics to ``module``.
@@ -56,7 +59,7 @@ def fully_shard(
if policy is not None and not isinstance(policy, _FSDPPolicy):
raise ValueError(f"Expects an `_FSDPPolicy` but got {policy}")
state = fully_shard.state(module)
state = _init_ignored_module_states(state, module, ignored_modules)
state = _init_ignored_module_states(state, module, ignored_modules, ignored_states)
state = _init_device_handle(state, module, state._ignored_params, device_id)
state = _init_process_group_state(
state, process_group, ShardingStrategy.FULL_SHARD, policy
@@ -96,7 +99,7 @@ def fully_shard(
_register_root_pre_forward_hook(state, module)  # prepend last
for submodule in module.modules():
if (
submodule not in state._ignored_modules
submodule in state._fully_sharded_module_to_handles
and _get_module_state(submodule) is None
):
_insert_module_state(submodule, state)
@@ -245,12 +245,21 @@ def _init_ignored_module_states(
state: _FSDPState,
module: nn.Module,
ignored_modules: Optional[Iterable[torch.nn.Module]],
ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
ignored_states: Union[
Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
] = None,
) -> _FSDPState:
assert (
ignored_modules is None or ignored_parameters is None
), "Can not pass `ignored_modules` and `ignored_parameters` at the same time. \
Please either pass `ignored_modules` or `ignored_parameters`."
ignored_modules is None or ignored_states is None
), "Can not pass `ignored_modules` and `ignored_states` at the same time. \
Please either pass `ignored_modules` or `ignored_states`."
ignored_parameters = None
if ignored_states:
ignored_states_set = set(ignored_states)
if isinstance(next(iter(ignored_states), None), torch.nn.Parameter):
ignored_parameters = ignored_states_set
else:
ignored_modules = ignored_states_set
state._ignored_modules = _get_ignored_modules(module, ignored_modules)
state._ignored_params = _get_ignored_params(
module,
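As context for the hunk above: `ignored_states` is normalized by peeking at the type of its first element, so a single keyword can carry either parameters or modules. A standalone sketch of that dispatch follows; the helper name and return shape are illustrative, not the library's internal API.

from typing import Iterable, Optional, Set, Tuple, Union

import torch.nn as nn


def split_ignored_states(
    ignored_states: Union[Iterable[nn.Parameter], Iterable[nn.Module], None],
) -> Tuple[Optional[Set[nn.Parameter]], Optional[Set[nn.Module]]]:
    # Mirror of the normalization in `_init_ignored_module_states`: if the
    # first element is an nn.Parameter, the whole iterable is treated as
    # ignored parameters; otherwise it is treated as ignored modules.
    ignored_parameters = None
    ignored_modules = None
    if ignored_states:
        ignored_states_set = set(ignored_states)
        if isinstance(next(iter(ignored_states_set), None), nn.Parameter):
            ignored_parameters = ignored_states_set
        else:
            ignored_modules = ignored_states_set
    return ignored_parameters, ignored_modules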
@@ -659,7 +668,7 @@ def _get_ignored_modules(
for module in ignored_root_modules:
if not isinstance(module, torch.nn.Module):
raise TypeError(msg_prefix + f"but got an iterable with {type(module)}")
if isinstance(module, fsdp_file.FullyShardedDataParallel):
if _get_module_fsdp_state(module):
# TODO: We may relax this by taking the FSDP instance's wrapped
# module to provide more flexibility to the user.
raise ValueError("`ignored_modules` should not include FSDP modules")
@@ -716,7 +725,7 @@ def _get_ignored_params(
}
all_ignored_params.update(params_in_ignored_parameters)

# Include nested FSDP modules' ignored parameters
# Always include nested FSDP modules' ignored parameters
for submodule in root_module.modules():
optional_fsdp_state = _get_module_fsdp_state(submodule)
if optional_fsdp_state is not None:
@@ -385,11 +385,13 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
forward_prefetch: bool = False,
limit_all_gathers: bool = False,
use_orig_params: bool = False,
ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
ignored_states: Union[
Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
] = None,
):
torch._C._log_api_usage_once("torch.distributed.fsdp")
super().__init__()
_init_ignored_module_states(self, module, ignored_modules, ignored_parameters)
_init_ignored_module_states(self, module, ignored_modules, ignored_states)
_init_device_handle(self, module, self._ignored_params, device_id)

# Add module annotations for Dynamo support (see function for details)