add ignored_states to FSDP/fully_shard (#102056)

Add `ignored_states`, which accepts either a list of parameters or a list of nn.Module instances to ignore, to both the FSDP model wrapper and the `fully_shard` composable API. Going forward, `ignored_states` is recommended over `ignored_modules`.
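
For reference, a minimal usage sketch of the new argument (not part of this PR's diff; it assumes a single-process NCCL setup with one CUDA device, and the toy model is illustrative only):

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Single-process setup purely for illustration; real jobs would launch via torchrun.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)
torch.cuda.set_device(0)

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).cuda()

# `ignored_states` accepts a list of parameters to exclude from sharding ...
fsdp_model = FSDP(model, ignored_states=list(model[0].parameters()))

# ... or a list of modules; `ignored_states` and `ignored_modules` are
# mutually exclusive within a single call.
other = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).cuda()
fsdp_other = FSDP(other, ignored_states=[other[0]])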

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102056
Approved by: https://github.com/awgu
Yanli Zhao 2023-05-23 22:35:13 +00:00 committed by PyTorch MergeBot
parent 023bc30b17
commit 956bd03808
4 changed files with 163 additions and 59 deletions

View File

@@ -5,7 +5,9 @@ import sys
 import torch
 import torch.nn as nn
 from torch import distributed as dist
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
+from torch.distributed._composable import fully_shard
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     CUDAInitMode,
@@ -101,12 +103,29 @@ class TestFSDPIgnoredModules(FSDPTest):
         """Tests that ignored modules' parameters are not flattened for a
         transformer model with shared parameters."""
         self.run_subtests(
-            {"use_orig_params": [False, True], "ignore_modules": [True, False]},
+            {
+                "use_orig_params": [False, True],
+                "ignore_modules": [True, False],
+                "composable": [False],
+            },
+            self._test_ignored_modules_transformer,
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_ignored_modules_transformer_composable(self):
+        """Tests that ignored modules' parameters are not flattened for a
+        transformer model with shared parameters."""
+        self.run_subtests(
+            {
+                "use_orig_params": [True],
+                "ignore_modules": [True, False],
+                "composable": [True],
+            },
             self._test_ignored_modules_transformer,
         )

     def _test_ignored_modules_transformer(
-        self, use_orig_params: bool, ignore_modules: bool
+        self, use_orig_params: bool, ignore_modules: bool, composable: bool
     ):
         # Initialize an FSDP-wrapped transformer model that has FSDP ignore
         # the `nn.Transformer` module's parameters
@@ -117,18 +136,34 @@ class TestFSDPIgnoredModules(FSDPTest):
             deterministic=True,
         )
         if ignore_modules:
-            wrapped_model = FSDP(
-                model,
-                self.process_group,
-                ignored_modules=[model.transformer],
-                use_orig_params=use_orig_params,
+            wrapped_model = (
+                FSDP(
+                    model,
+                    self.process_group,
+                    ignored_modules=[model.transformer],
+                    use_orig_params=use_orig_params,
+                )
+                if not composable
+                else fully_shard(
+                    model,
+                    process_group=self.process_group,
+                    ignored_modules=[model.transformer],
+                )
             )
         else:
-            wrapped_model = FSDP(
-                model,
-                self.process_group,
-                ignored_parameters=list(model.transformer.parameters()),
-                use_orig_params=use_orig_params,
+            wrapped_model = (
+                FSDP(
+                    model,
+                    self.process_group,
+                    ignored_states=list(model.transformer.parameters()),
+                    use_orig_params=use_orig_params,
+                )
+                if not composable
+                else fully_shard(
+                    model,
+                    process_group=self.process_group,
+                    ignored_states=list(model.transformer.parameters()),
+                )
             )
         # Check that the wrapped model's flattened parameter does not include
         # the ignored transformer module's parameters
@@ -144,9 +179,13 @@ class TestFSDPIgnoredModules(FSDPTest):
         )
         nonignored_numel = total_numel - ignored_numel
         with FSDP.summon_full_params(wrapped_model):
-            flat_param = wrapped_model.params[0]
+            flat_param = (
+                wrapped_model.params[0]
+                if not composable
+                else _get_module_fsdp_state(wrapped_model).params[0]
+            )
             flat_param_numel = flat_param.numel()
-            if use_orig_params:
+            if composable or use_orig_params:
                 # Subtract the numel contributed from alignment padding
                 padding_numel = sum(
                     numel
@@ -166,25 +205,61 @@ class TestFSDPIgnoredModules(FSDPTest):
         """Tests that passing a module with nested FSDP modules does not
         error and still ignores non-FSDP modules' parameters."""
         self.run_subtests(
-            {"use_orig_params": [False, True], "ignore_modules": [True, False]},
+            {
+                "use_orig_params": [False, True],
+                "ignore_modules": [True, False],
+                "composable": [False],
+            },
             self._test_ignored_modules_nested,
         )

-    def _test_ignored_modules_nested(self, use_orig_params: bool, ignore_modules: bool):
+    @skip_if_lt_x_gpu(2)
+    def test_ignored_modules_nested_composable(self):
+        """Tests that passing a module with nested FSDP modules does not
+        error and still ignores non-FSDP modules' parameters."""
+        self.run_subtests(
+            {
+                "use_orig_params": [True],
+                "ignore_modules": [True, False],
+                "composable": [True],
+            },
+            self._test_ignored_modules_nested,
+        )
+
+    def _test_ignored_modules_nested(
+        self, use_orig_params: bool, ignore_modules: bool, composable: bool
+    ):
         # Initialize an FSDP-wrapped nested model that first wraps the nested
         # sequential's second linear layer (`layer1[1]`) and then wraps the
         # overall model while ignoring the nested sequential (`layer1`)
         model = Model().cuda()
-        model.layer1[1] = FSDP(model.layer1[1], use_orig_params=use_orig_params)
+        model.layer1[1] = (
+            FSDP(model.layer1[1], use_orig_params=use_orig_params)
+            if not composable
+            else fully_shard(model.layer1[1])
+        )
         if ignore_modules:
-            wrapped_model = FSDP(
-                model, ignored_modules=[model.layer1], use_orig_params=use_orig_params
+            wrapped_model = (
+                FSDP(
+                    model,
+                    ignored_modules=[model.layer1],
+                    use_orig_params=use_orig_params,
+                )
+                if not composable
+                else fully_shard(model, ignored_modules=[model.layer1])
             )
         else:
-            wrapped_model = FSDP(
-                model,
-                ignored_parameters=list(model.layer1.parameters()),
-                use_orig_params=use_orig_params,
+            wrapped_model = (
+                FSDP(
+                    model,
+                    ignored_states=[model.layer1],
+                    use_orig_params=use_orig_params,
+                )
+                if not composable
+                else fully_shard(
+                    model,
+                    ignored_states=[model.layer1],
+                )
             )
         # Check that the wrapped model's flattened parameter does not include
         # the ignored nested sequential's parameters
@@ -193,9 +268,13 @@ class TestFSDPIgnoredModules(FSDPTest):
         ignored_numel = sum(p.numel() for p in nonwrapped_model.layer1.parameters())
         nonignored_numel = total_numel - ignored_numel
         with FSDP.summon_full_params(wrapped_model):
-            flat_param = wrapped_model.params[0]
+            flat_param = (
+                wrapped_model.params[0]
+                if not composable
+                else _get_module_fsdp_state(wrapped_model).params[0]
+            )
             flat_param_numel = flat_param.numel()
-            if use_orig_params:
+            if composable or use_orig_params:
                 # Subtract the numel contributed from alignment padding
                 padding_numel = sum(
                     numel
@@ -212,24 +291,29 @@ class TestFSDPIgnoredModules(FSDPTest):
         self._train_model(wrapped_model, optim, 3)

     @skip_if_lt_x_gpu(2)
-    def test_ignored_modules_invalid(self):
+    @parametrize("composable", [True, False])
+    def test_ignored_modules_invalid(self, composable):
         """Tests that passing an FSDP module as an ignored module or the
         top-level module itself errors."""
         model = Model().cuda()
-        model.layer1 = FSDP(model.layer1)
+        wrap_cls = FSDP if composable else fully_shard
+        model.layer1 = wrap_cls(model.layer1)
         # Passing an FSDP module as an ignored module should error
         with self.assertRaises(
             ValueError,
             msg="`ignored_modules` should not include FSDP modules",
         ):
-            FSDP(model, ignored_modules=[model.layer1])
+            wrap_cls(model, ignored_modules=[model.layer1])
         with self.assertWarnsRegex(
             expected_warning=UserWarning,
             expected_regex="Trying to ignore the top-level module passed into "
             "the FSDP constructor itself will result in all parameters being "
             "ignored",
         ):
-            FSDP(model, ignored_modules=[model])
+            # `fully_shard` does not allow to wrap the same model twice, so create
+            # a new local model here.
+            new_model = Model().cuda()
+            wrap_cls(new_model, ignored_modules=[new_model])

     @skip_if_lt_x_gpu(2)
     def test_diff_ignored_modules_across_ranks(self):
@@ -247,16 +331,21 @@ class TestFSDPIgnoredModules(FSDPTest):
             {
                 "pass_ignored_modules_to_root": [False, True],
                 "ignore_modules": [True, False],
+                "composable": [True, False],
             },
             self._test_diff_ignored_modules_across_ranks,
         )

     def _test_diff_ignored_modules_across_ranks(
-        self, pass_ignored_modules_to_root: bool, ignore_modules: bool
+        self,
+        pass_ignored_modules_to_root: bool,
+        ignore_modules: bool,
+        composable: bool,
     ):
         # To exercise different `FlatParameter` enumerations across ranks,
         # we wrap `layer3` with FSDP, where `layer3` is registered as a module
         # after `layer1`, which has the variable number of ignored modules
+        wrap_cls = FSDP if composable else fully_shard
         model = ModelWithIgnoredModules(num_ignored=self.rank + 1).cuda()
         layer1_ignored_modules = [
             m for m in model.layer1.modules() if isinstance(m, IgnoredModule)
@@ -265,13 +354,13 @@ class TestFSDPIgnoredModules(FSDPTest):
             {"ignored_modules": layer1_ignored_modules}
             if ignore_modules
             else {
-                "ignored_parameters": {
+                "ignored_states": {
                     p for m in layer1_ignored_modules for p in m.parameters()
                 }
             }
         )
-        model.layer1 = FSDP(model.layer1, **ignore_kwargs)
-        model.layer3 = FSDP(model.layer3)
+        model.layer1 = wrap_cls(model.layer1, **ignore_kwargs)
+        model.layer3 = wrap_cls(model.layer3)
         model_ignored_modules = (
             [m for m in model.modules() if isinstance(m, IgnoredModule)]
             if pass_ignored_modules_to_root
@@ -281,18 +370,21 @@ class TestFSDPIgnoredModules(FSDPTest):
             {"ignored_modules": model_ignored_modules}
             if ignore_modules
             else {
-                "ignored_parameters": {
+                "ignored_states": {
                     p for m in model_ignored_modules for p in m.parameters()
                 }
             }
         )
-        wrapped_model = FSDP(model, **ignore_kwargs_top)
+        wrapped_model = wrap_cls(model, **ignore_kwargs_top)
         optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
         self._train_model(wrapped_model, optim, 3)

     @skip_if_lt_x_gpu(2)
     @parametrize("ignore_modules", [True, False])
-    def test_ignored_modules_not_under_wrapped_root(self, ignore_modules: bool):
+    @parametrize("composable", [True, False])
+    def test_ignored_modules_not_under_wrapped_root(
+        self, ignore_modules: bool, composable: bool
+    ):
         model = Model().cuda()
         ignored_modules = list(model.layer1.children())[1:]
@@ -300,19 +392,17 @@ class TestFSDPIgnoredModules(FSDPTest):
             {"ignored_modules": ignored_modules}
             if ignore_modules
             else {
-                "ignored_parameters": {
-                    p for m in ignored_modules for p in m.parameters()
-                }
+                "ignored_states": {p for m in ignored_modules for p in m.parameters()}
             }
         )
-        model.layer1 = FSDP(
+        wrap_cls = FSDP if composable else fully_shard
+        model.layer1 = wrap_cls(
             model.layer1,
-            # sharding_strategy shouldn't matter here.
-            sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
             **ignore_kwargs,
         )
-        model.layer3 = FSDP(
+        model.layer3 = wrap_cls(
             model.layer3,
             # the ignored modules/parameters contains submodule under model.layer1, which
             # is out of the local root model.layer3.

View File

@@ -47,6 +47,9 @@ def fully_shard(
     param_init_fn: Optional[Callable[[nn.Module], None]] = None,
     sync_module_states: bool = False,
     forward_prefetch: bool = False,
+    ignored_states: Union[
+        Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
+    ] = None,
 ) -> nn.Module:
     """
     Applies ``FullyShardedDataParallel` (FSDP) semantics to ``module``.
@@ -56,7 +59,7 @@ def fully_shard(
     if policy is not None and not isinstance(policy, _FSDPPolicy):
         raise ValueError(f"Expects an `_FSDPPolicy` but got {policy}")
     state = fully_shard.state(module)
-    state = _init_ignored_module_states(state, module, ignored_modules)
+    state = _init_ignored_module_states(state, module, ignored_modules, ignored_states)
     state = _init_device_handle(state, module, state._ignored_params, device_id)
     state = _init_process_group_state(
         state, process_group, ShardingStrategy.FULL_SHARD, policy
@@ -96,7 +99,7 @@ def fully_shard(
     _register_root_pre_forward_hook(state, module)  # prepend last
     for submodule in module.modules():
         if (
-            submodule not in state._ignored_modules
+            submodule in state._fully_sharded_module_to_handles
             and _get_module_state(submodule) is None
         ):
             _insert_module_state(submodule, state)
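
For comparison with the wrapper example in the commit description, a hedged sketch of the same argument on the composable API (illustrative model; assumes the process group and CUDA device setup from that earlier sketch is already in place):

import torch.nn as nn
from torch.distributed._composable import fully_shard

# Illustrative model only; requires an initialized default process group.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).cuda()

# Ignore one submodule's parameters so they are not flattened into the
# `FlatParameter`; a list of modules works the same way.
fully_shard(model, ignored_states=list(model[0].parameters()))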

View File

@@ -245,12 +245,21 @@ def _init_ignored_module_states(
     state: _FSDPState,
     module: nn.Module,
     ignored_modules: Optional[Iterable[torch.nn.Module]],
-    ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
+    ignored_states: Union[
+        Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
+    ] = None,
 ) -> _FSDPState:
     assert (
-        ignored_modules is None or ignored_parameters is None
-    ), "Can not pass `ignored_modules` and `ignored_parameters` at the same time. \
-        Please either pass `ignored_modules` or `ignored_parameters`."
+        ignored_modules is None or ignored_states is None
+    ), "Can not pass `ignored_modules` and `ignored_states` at the same time. \
+        Please either pass `ignored_modules` or `ignored_states`."
+    ignored_parameters = None
+    if ignored_states:
+        ignored_states_set = set(ignored_states)
+        if isinstance(next(iter(ignored_states), None), torch.nn.Parameter):
+            ignored_parameters = ignored_states_set
+        else:
+            ignored_modules = ignored_states_set
     state._ignored_modules = _get_ignored_modules(module, ignored_modules)
     state._ignored_params = _get_ignored_params(
         module,
@@ -659,7 +668,7 @@ def _get_ignored_modules(
     for module in ignored_root_modules:
         if not isinstance(module, torch.nn.Module):
             raise TypeError(msg_prefix + f"but got an iterable with {type(module)}")
-        if isinstance(module, fsdp_file.FullyShardedDataParallel):
+        if _get_module_fsdp_state(module):
             # TODO: We may relax this by taking the FSDP instance's wrapped
             # module to provide more flexibility to the user.
             raise ValueError("`ignored_modules` should not include FSDP modules")
@@ -716,12 +725,12 @@ def _get_ignored_params(
         }
         all_ignored_params.update(params_in_ignored_parameters)

-    # Include nested FSDP modules' ignored parameters
+    # Always include nested FSDP modules' ignored parameters
     for submodule in root_module.modules():
         optional_fsdp_state = _get_module_fsdp_state(submodule)
         if optional_fsdp_state is not None:
             assert hasattr(optional_fsdp_state, "_ignored_params")
             all_ignored_params.update(optional_fsdp_state._ignored_params)

     return all_ignored_params
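
The dispatch above keys off the type of the first element of `ignored_states`. A simplified standalone sketch of that rule (the helper name `split_ignored_states` is hypothetical and not part of this change):

from typing import Iterable, Optional, Set, Tuple, Union

import torch.nn as nn


def split_ignored_states(
    ignored_states: Union[Iterable[nn.Parameter], Iterable[nn.Module], None],
) -> Tuple[Optional[Set[nn.Parameter]], Optional[Set[nn.Module]]]:
    """Mirrors the dispatch in `_init_ignored_module_states`: the type of the
    first element decides whether the whole iterable is treated as ignored
    parameters or as ignored modules."""
    ignored_parameters = None
    ignored_modules = None
    if ignored_states:
        ignored_states_set = set(ignored_states)
        if isinstance(next(iter(ignored_states), None), nn.Parameter):
            ignored_parameters = ignored_states_set
        else:
            ignored_modules = ignored_states_set
    return ignored_parameters, ignored_modules


# Example: a list of modules is routed to the "ignored modules" branch.
layer = nn.Linear(4, 4)
params, modules = split_ignored_states([layer])
assert params is None and modules == {layer}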

View File

@@ -385,11 +385,13 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         forward_prefetch: bool = False,
         limit_all_gathers: bool = False,
         use_orig_params: bool = False,
-        ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
+        ignored_states: Union[
+            Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
+        ] = None,
     ):
         torch._C._log_api_usage_once("torch.distributed.fsdp")
         super().__init__()
-        _init_ignored_module_states(self, module, ignored_modules, ignored_parameters)
+        _init_ignored_module_states(self, module, ignored_modules, ignored_states)
         _init_device_handle(self, module, self._ignored_params, device_id)
         # Add module annotations for Dynamo support (see function for details)