add ignored_states to FSDP/fully_shard (#102056)

Add `ignored_states`, which accepts either a list of ignored parameters or a list of nn.Modules, to the FSDP model wrapper and the `fully_shard` composable API. Moving forward, `ignored_states` is recommended over `ignored_modules`.
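
For reference, a minimal usage sketch of the new keyword (assuming an initialized process group and a CUDA device; `MyModel` is a hypothetical nn.Module with a `transformer` submodule):

    from torch.distributed._composable import fully_shard
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    model = MyModel().cuda()
    # `ignored_states` accepts an iterable of parameters ...
    wrapped = FSDP(model, ignored_states=list(model.transformer.parameters()))

    # ... or an iterable of modules, on the wrapper and composable APIs alike
    other = MyModel().cuda()
    fully_shard(other, ignored_states=[other.transformer])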

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102056
Approved by: https://github.com/awgu
Yanli Zhao 2023-05-23 22:35:13 +00:00 committed by PyTorch MergeBot
parent 023bc30b17
commit 956bd03808
4 changed files with 163 additions and 59 deletions

test/distributed/fsdp/test_fsdp_ignored_modules.py

@@ -5,7 +5,9 @@ import sys
import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed._composable import fully_shard
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
CUDAInitMode,
@@ -101,12 +103,29 @@ class TestFSDPIgnoredModules(FSDPTest):
"""Tests that ignored modules' parameters are not flattened for a
transformer model with shared parameters."""
self.run_subtests(
{"use_orig_params": [False, True], "ignore_modules": [True, False]},
{
"use_orig_params": [False, True],
"ignore_modules": [True, False],
"composable": [False],
},
self._test_ignored_modules_transformer,
)
@skip_if_lt_x_gpu(2)
def test_ignored_modules_transformer_composable(self):
"""Tests that ignored modules' parameters are not flattened for a
transformer model with shared parameters."""
self.run_subtests(
{
"use_orig_params": [True],
"ignore_modules": [True, False],
"composable": [True],
},
self._test_ignored_modules_transformer,
)
def _test_ignored_modules_transformer(
self, use_orig_params: bool, ignore_modules: bool
self, use_orig_params: bool, ignore_modules: bool, composable: bool
):
# Initialize an FSDP-wrapped transformer model that has FSDP ignore
# the `nn.Transformer` module's parameters
@@ -117,19 +136,35 @@ class TestFSDPIgnoredModules(FSDPTest):
deterministic=True,
)
if ignore_modules:
wrapped_model = FSDP(
wrapped_model = (
FSDP(
model,
self.process_group,
ignored_modules=[model.transformer],
use_orig_params=use_orig_params,
)
if not composable
else fully_shard(
model,
process_group=self.process_group,
ignored_modules=[model.transformer],
)
)
else:
wrapped_model = FSDP(
wrapped_model = (
FSDP(
model,
self.process_group,
ignored_parameters=list(model.transformer.parameters()),
ignored_states=list(model.transformer.parameters()),
use_orig_params=use_orig_params,
)
if not composable
else fully_shard(
model,
process_group=self.process_group,
ignored_states=list(model.transformer.parameters()),
)
)
# Check that the wrapped model's flattened parameter does not include
# the ignored transformer module's parameters
nonwrapped_model: nn.Module = TransformerWithSharedParams.init(
@@ -144,9 +179,13 @@ class TestFSDPIgnoredModules(FSDPTest):
)
nonignored_numel = total_numel - ignored_numel
with FSDP.summon_full_params(wrapped_model):
flat_param = wrapped_model.params[0]
flat_param = (
wrapped_model.params[0]
if not composable
else _get_module_fsdp_state(wrapped_model).params[0]
)
flat_param_numel = flat_param.numel()
if use_orig_params:
if composable or use_orig_params:
# Subtract the numel contributed from alignment padding
padding_numel = sum(
numel
@@ -166,26 +205,62 @@ class TestFSDPIgnoredModules(FSDPTest):
"""Tests that passing a module with nested FSDP modules does not
error and still ignores non-FSDP modules' parameters."""
self.run_subtests(
{"use_orig_params": [False, True], "ignore_modules": [True, False]},
{
"use_orig_params": [False, True],
"ignore_modules": [True, False],
"composable": [False],
},
self._test_ignored_modules_nested,
)
def _test_ignored_modules_nested(self, use_orig_params: bool, ignore_modules: bool):
@skip_if_lt_x_gpu(2)
def test_ignored_modules_nested_composable(self):
"""Tests that passing a module with nested FSDP modules does not
error and still ignores non-FSDP modules' parameters."""
self.run_subtests(
{
"use_orig_params": [True],
"ignore_modules": [True, False],
"composable": [True],
},
self._test_ignored_modules_nested,
)
def _test_ignored_modules_nested(
self, use_orig_params: bool, ignore_modules: bool, composable: bool
):
# Initialize an FSDP-wrapped nested model that first wraps the nested
# sequential's second linear layer (`layer1[1]`) and then wraps the
# overall model while ignoring the nested sequential (`layer1`)
model = Model().cuda()
model.layer1[1] = FSDP(model.layer1[1], use_orig_params=use_orig_params)
model.layer1[1] = (
FSDP(model.layer1[1], use_orig_params=use_orig_params)
if not composable
else fully_shard(model.layer1[1])
)
if ignore_modules:
wrapped_model = FSDP(
model, ignored_modules=[model.layer1], use_orig_params=use_orig_params
wrapped_model = (
FSDP(
model,
ignored_modules=[model.layer1],
use_orig_params=use_orig_params,
)
if not composable
else fully_shard(model, ignored_modules=[model.layer1])
)
else:
wrapped_model = FSDP(
wrapped_model = (
FSDP(
model,
ignored_parameters=list(model.layer1.parameters()),
ignored_states=[model.layer1],
use_orig_params=use_orig_params,
)
if not composable
else fully_shard(
model,
ignored_states=[model.layer1],
)
)
# Check that the wrapped model's flattened parameter does not include
# the ignored nested sequential's parameters
nonwrapped_model = Model()
@@ -193,9 +268,13 @@ class TestFSDPIgnoredModules(FSDPTest):
ignored_numel = sum(p.numel() for p in nonwrapped_model.layer1.parameters())
nonignored_numel = total_numel - ignored_numel
with FSDP.summon_full_params(wrapped_model):
flat_param = wrapped_model.params[0]
flat_param = (
wrapped_model.params[0]
if not composable
else _get_module_fsdp_state(wrapped_model).params[0]
)
flat_param_numel = flat_param.numel()
if use_orig_params:
if composable or use_orig_params:
# Subtract the numel contributed from alignment padding
padding_numel = sum(
numel
@@ -212,24 +291,29 @@ class TestFSDPIgnoredModules(FSDPTest):
self._train_model(wrapped_model, optim, 3)
@skip_if_lt_x_gpu(2)
def test_ignored_modules_invalid(self):
@parametrize("composable", [True, False])
def test_ignored_modules_invalid(self, composable):
"""Tests that passing an FSDP module as an ignored module or the
top-level module itself errors."""
model = Model().cuda()
model.layer1 = FSDP(model.layer1)
wrap_cls = fully_shard if composable else FSDP
model.layer1 = wrap_cls(model.layer1)
# Passing an FSDP module as an ignored module should error
with self.assertRaises(
ValueError,
msg="`ignored_modules` should not include FSDP modules",
):
FSDP(model, ignored_modules=[model.layer1])
wrap_cls(model, ignored_modules=[model.layer1])
with self.assertWarnsRegex(
expected_warning=UserWarning,
expected_regex="Trying to ignore the top-level module passed into "
"the FSDP constructor itself will result in all parameters being "
"ignored",
):
FSDP(model, ignored_modules=[model])
# `fully_shard` does not allow wrapping the same model twice, so create
# a new local model here.
new_model = Model().cuda()
wrap_cls(new_model, ignored_modules=[new_model])
@skip_if_lt_x_gpu(2)
def test_diff_ignored_modules_across_ranks(self):
@@ -247,16 +331,21 @@ class TestFSDPIgnoredModules(FSDPTest):
{
"pass_ignored_modules_to_root": [False, True],
"ignore_modules": [True, False],
"composable": [True, False],
},
self._test_diff_ignored_modules_across_ranks,
)
def _test_diff_ignored_modules_across_ranks(
self, pass_ignored_modules_to_root: bool, ignore_modules: bool
self,
pass_ignored_modules_to_root: bool,
ignore_modules: bool,
composable: bool,
):
# To exercise different `FlatParameter` enumerations across ranks,
# we wrap `layer3` with FSDP, where `layer3` is registered as a module
# after `layer1`, which has a variable number of ignored modules
wrap_cls = fully_shard if composable else FSDP
model = ModelWithIgnoredModules(num_ignored=self.rank + 1).cuda()
layer1_ignored_modules = [
m for m in model.layer1.modules() if isinstance(m, IgnoredModule)
@@ -265,13 +354,13 @@ class TestFSDPIgnoredModules(FSDPTest):
{"ignored_modules": layer1_ignored_modules}
if ignore_modules
else {
"ignored_parameters": {
"ignored_states": {
p for m in layer1_ignored_modules for p in m.parameters()
}
}
)
model.layer1 = FSDP(model.layer1, **ignore_kwargs)
model.layer3 = FSDP(model.layer3)
model.layer1 = wrap_cls(model.layer1, **ignore_kwargs)
model.layer3 = wrap_cls(model.layer3)
model_ignored_modules = (
[m for m in model.modules() if isinstance(m, IgnoredModule)]
if pass_ignored_modules_to_root
@@ -281,18 +370,21 @@ class TestFSDPIgnoredModules(FSDPTest):
{"ignored_modules": model_ignored_modules}
if ignore_modules
else {
"ignored_parameters": {
"ignored_states": {
p for m in model_ignored_modules for p in m.parameters()
}
}
)
wrapped_model = FSDP(model, **ignore_kwargs_top)
wrapped_model = wrap_cls(model, **ignore_kwargs_top)
optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
self._train_model(wrapped_model, optim, 3)
@skip_if_lt_x_gpu(2)
@parametrize("ignore_modules", [True, False])
def test_ignored_modules_not_under_wrapped_root(self, ignore_modules: bool):
@parametrize("composable", [True, False])
def test_ignored_modules_not_under_wrapped_root(
self, ignore_modules: bool, composable: bool
):
model = Model().cuda()
ignored_modules = list(model.layer1.children())[1:]
@@ -300,19 +392,17 @@ class TestFSDPIgnoredModules(FSDPTest):
{"ignored_modules": ignored_modules}
if ignore_modules
else {
"ignored_parameters": {
p for m in ignored_modules for p in m.parameters()
}
"ignored_states": {p for m in ignored_modules for p in m.parameters()}
}
)
model.layer1 = FSDP(
wrap_cls = fully_shard if composable else FSDP
model.layer1 = wrap_cls(
model.layer1,
# sharding_strategy shouldn't matter here.
sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
**ignore_kwargs,
)
model.layer3 = FSDP(
model.layer3 = wrap_cls(
model.layer3,
# the ignored modules/parameters contain submodules under model.layer1, which
# are outside the local root model.layer3.
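
In both the wrapper and composable variants above, the tests reach the flat parameter through the API-appropriate handle. Condensed, as a sketch (`wrapped_model` and `composable` as in the test bodies; `params` is internal state in both cases):

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp._common_utils import _get_module_fsdp_state

    with FSDP.summon_full_params(wrapped_model):
        # Wrapper API: flat parameters hang off the FSDP instance itself;
        # composable API: they live on the module's FSDP state object.
        flat_param = (
            wrapped_model.params[0]
            if not composable
            else _get_module_fsdp_state(wrapped_model).params[0]
        )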

torch/distributed/_composable/fully_shard.py

@@ -47,6 +47,9 @@ def fully_shard(
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
sync_module_states: bool = False,
forward_prefetch: bool = False,
ignored_states: Union[
Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
] = None,
) -> nn.Module:
"""
Applies ``FullyShardedDataParallel`` (FSDP) semantics to ``module``.
@@ -56,7 +59,7 @@ def fully_shard(
if policy is not None and not isinstance(policy, _FSDPPolicy):
raise ValueError(f"Expects an `_FSDPPolicy` but got {policy}")
state = fully_shard.state(module)
state = _init_ignored_module_states(state, module, ignored_modules)
state = _init_ignored_module_states(state, module, ignored_modules, ignored_states)
state = _init_device_handle(state, module, state._ignored_params, device_id)
state = _init_process_group_state(
state, process_group, ShardingStrategy.FULL_SHARD, policy
@@ -96,7 +99,7 @@ def fully_shard(
_register_root_pre_forward_hook(state, module) # prepend last
for submodule in module.modules():
if (
submodule not in state._ignored_modules
submodule in state._fully_sharded_module_to_handles
and _get_module_state(submodule) is None
):
_insert_module_state(submodule, state)
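
Because both entry points funnel into `_init_ignored_module_states` (next file), `ignored_states` and `ignored_modules` stay mutually exclusive on the composable API as well; passing both trips an assertion. A sketch, with `model` a fresh `MyModel()` instance as in the example at the top:

    try:
        fully_shard(
            model,
            ignored_modules=[model.transformer],
            ignored_states=list(model.transformer.parameters()),
        )
    except AssertionError:
        # `_init_ignored_module_states` rejects passing both keywords
        pass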

torch/distributed/fsdp/_init_utils.py

@@ -245,12 +245,21 @@ def _init_ignored_module_states(
state: _FSDPState,
module: nn.Module,
ignored_modules: Optional[Iterable[torch.nn.Module]],
ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
ignored_states: Union[
Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
] = None,
) -> _FSDPState:
assert (
ignored_modules is None or ignored_parameters is None
), "Can not pass `ignored_modules` and `ignored_parameters` at the same time. \
Please either pass `ignored_modules` or `ignored_parameters`."
ignored_modules is None or ignored_states is None
), "Can not pass `ignored_modules` and `ignored_states` at the same time. \
Please either pass `ignored_modules` or `ignored_states`."
ignored_parameters = None
if ignored_states:
ignored_states_set = set(ignored_states)
if isinstance(next(iter(ignored_states), None), torch.nn.Parameter):
ignored_parameters = ignored_states_set
else:
ignored_modules = ignored_states_set
state._ignored_modules = _get_ignored_modules(module, ignored_modules)
state._ignored_params = _get_ignored_params(
module,
@@ -659,7 +668,7 @@ def _get_ignored_modules(
for module in ignored_root_modules:
if not isinstance(module, torch.nn.Module):
raise TypeError(msg_prefix + f"but got an iterable with {type(module)}")
if isinstance(module, fsdp_file.FullyShardedDataParallel):
if _get_module_fsdp_state(module):
# TODO: We may relax this by taking the FSDP instance's wrapped
# module to provide more flexibility to the user.
raise ValueError("`ignored_modules` should not include FSDP modules")
@@ -716,7 +725,7 @@ def _get_ignored_params(
}
all_ignored_params.update(params_in_ignored_parameters)
# Include nested FSDP modules' ignored parameters
# Always include nested FSDP modules' ignored parameters
for submodule in root_module.modules():
optional_fsdp_state = _get_module_fsdp_state(submodule)
if optional_fsdp_state is not None:
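
The type dispatch in `_init_ignored_module_states` above peeks at one element of `ignored_states`: a `Parameter` means the whole iterable is treated as parameters, anything else as modules, so mixing the two in one iterable is unsupported. A standalone restatement of that rule (`split_ignored_states` is a hypothetical name, not library code):

    from typing import Iterable, Optional, Set, Tuple, Union

    import torch.nn as nn

    def split_ignored_states(
        ignored_states: Optional[Union[Iterable[nn.Parameter], Iterable[nn.Module]]],
    ) -> Tuple[Optional[Set[nn.Parameter]], Optional[Set[nn.Module]]]:
        params: Optional[Set[nn.Parameter]] = None
        modules: Optional[Set[nn.Module]] = None
        if ignored_states:
            states = set(ignored_states)
            # Peek at one element to decide how to interpret the iterable
            if isinstance(next(iter(states), None), nn.Parameter):
                params = states
            else:
                modules = states
        return params, modules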

torch/distributed/fsdp/fully_sharded_data_parallel.py

@@ -385,11 +385,13 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
forward_prefetch: bool = False,
limit_all_gathers: bool = False,
use_orig_params: bool = False,
ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
ignored_states: Union[
Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
] = None,
):
torch._C._log_api_usage_once("torch.distributed.fsdp")
super().__init__()
_init_ignored_module_states(self, module, ignored_modules, ignored_parameters)
_init_ignored_module_states(self, module, ignored_modules, ignored_states)
_init_device_handle(self, module, self._ignored_params, device_id)
# Add module annotations for Dynamo support (see function for details)
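
Note the migration implied by the constructor change above: the prototype `ignored_parameters` keyword is replaced by `ignored_states`, so existing callers need to rename the argument (sketch; `model` as in the earlier examples):

    # Before this change:
    #   FSDP(model, ignored_parameters=list(model.transformer.parameters()))
    # After:
    FSDP(model, ignored_states=list(model.transformer.parameters()))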