add ignored_states to FSDP/fully_shard (#102056)
Add 'ignored_states', which accepts either a list of ignored parameters or a list of nn modules, to the FSDP model wrapper and the fully_shard composable API. Using 'ignored_states' over 'ignored_modules' is recommended moving forward.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102056
Approved by: https://github.com/awgu
This commit is contained in:
parent 023bc30b17
commit 956bd03808
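A usage sketch of the new argument (not part of the commit; ToyModel and its frozen submodule are made up for illustration, and a default process group is assumed to be initialized already, e.g. by torchrun):

import torch.nn as nn
from torch.distributed._composable import fully_shard
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


class ToyModel(nn.Module):
    # Hypothetical model; `frozen` stands in for any submodule to exclude from sharding.
    def __init__(self):
        super().__init__()
        self.frozen = nn.Linear(8, 8)
        self.trunk = nn.Linear(8, 8)


model = ToyModel().cuda()

# `ignored_states` accepts a list of modules ...
wrapped = FSDP(model, ignored_states=[model.frozen])
# ... or, equivalently, a list of parameters:
#     wrapped = FSDP(model, ignored_states=list(model.frozen.parameters()))

# The composable API takes the same keyword:
#     fully_shard(model, ignored_states=[model.frozen])

The tests in the diff below exercise both forms through FSDP and fully_shard.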
@@ -5,7 +5,9 @@ import sys
 import torch
 import torch.nn as nn
 from torch import distributed as dist
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
+from torch.distributed._composable import fully_shard
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     CUDAInitMode,
@@ -101,12 +103,29 @@ class TestFSDPIgnoredModules(FSDPTest):
         """Tests that ignored modules' parameters are not flattened for a
         transformer model with shared parameters."""
         self.run_subtests(
-            {"use_orig_params": [False, True], "ignore_modules": [True, False]},
+            {
+                "use_orig_params": [False, True],
+                "ignore_modules": [True, False],
+                "composable": [False],
+            },
+            self._test_ignored_modules_transformer,
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_ignored_modules_transformer_composable(self):
+        """Tests that ignored modules' parameters are not flattened for a
+        transformer model with shared parameters."""
+        self.run_subtests(
+            {
+                "use_orig_params": [True],
+                "ignore_modules": [True, False],
+                "composable": [True],
+            },
             self._test_ignored_modules_transformer,
         )
 
     def _test_ignored_modules_transformer(
-        self, use_orig_params: bool, ignore_modules: bool
+        self, use_orig_params: bool, ignore_modules: bool, composable: bool
     ):
         # Initialize an FSDP-wrapped transformer model that has FSDP ignore
         # the `nn.Transformer` module's parameters
@@ -117,18 +136,34 @@ class TestFSDPIgnoredModules(FSDPTest):
             deterministic=True,
         )
         if ignore_modules:
-            wrapped_model = FSDP(
-                model,
-                self.process_group,
-                ignored_modules=[model.transformer],
-                use_orig_params=use_orig_params,
+            wrapped_model = (
+                FSDP(
+                    model,
+                    self.process_group,
+                    ignored_modules=[model.transformer],
+                    use_orig_params=use_orig_params,
+                )
+                if not composable
+                else fully_shard(
+                    model,
+                    process_group=self.process_group,
+                    ignored_modules=[model.transformer],
+                )
             )
         else:
-            wrapped_model = FSDP(
-                model,
-                self.process_group,
-                ignored_parameters=list(model.transformer.parameters()),
-                use_orig_params=use_orig_params,
+            wrapped_model = (
+                FSDP(
+                    model,
+                    self.process_group,
+                    ignored_states=list(model.transformer.parameters()),
+                    use_orig_params=use_orig_params,
+                )
+                if not composable
+                else fully_shard(
+                    model,
+                    process_group=self.process_group,
+                    ignored_states=list(model.transformer.parameters()),
+                )
             )
         # Check that the wrapped model's flattened parameter does not include
         # the ignored transformer module's parameters
@@ -144,9 +179,13 @@ class TestFSDPIgnoredModules(FSDPTest):
         )
         nonignored_numel = total_numel - ignored_numel
         with FSDP.summon_full_params(wrapped_model):
-            flat_param = wrapped_model.params[0]
+            flat_param = (
+                wrapped_model.params[0]
+                if not composable
+                else _get_module_fsdp_state(wrapped_model).params[0]
+            )
             flat_param_numel = flat_param.numel()
-            if use_orig_params:
+            if composable or use_orig_params:
                 # Subtract the numel contributed from alignment padding
                 padding_numel = sum(
                     numel
@@ -166,25 +205,61 @@ class TestFSDPIgnoredModules(FSDPTest):
         """Tests that passing a module with nested FSDP modules does not
         error and still ignores non-FSDP modules' parameters."""
         self.run_subtests(
-            {"use_orig_params": [False, True], "ignore_modules": [True, False]},
+            {
+                "use_orig_params": [False, True],
+                "ignore_modules": [True, False],
+                "composable": [False],
+            },
             self._test_ignored_modules_nested,
         )
 
-    def _test_ignored_modules_nested(self, use_orig_params: bool, ignore_modules: bool):
+    @skip_if_lt_x_gpu(2)
+    def test_ignored_modules_nested_composable(self):
+        """Tests that passing a module with nested FSDP modules does not
+        error and still ignores non-FSDP modules' parameters."""
+        self.run_subtests(
+            {
+                "use_orig_params": [True],
+                "ignore_modules": [True, False],
+                "composable": [True],
+            },
+            self._test_ignored_modules_nested,
+        )
+
+    def _test_ignored_modules_nested(
+        self, use_orig_params: bool, ignore_modules: bool, composable: bool
+    ):
         # Initialize an FSDP-wrapped nested model that first wraps the nested
         # sequential's second linear layer (`layer1[1]`) and then wraps the
         # overall model while ignoring the nested sequential (`layer1`)
         model = Model().cuda()
-        model.layer1[1] = FSDP(model.layer1[1], use_orig_params=use_orig_params)
+        model.layer1[1] = (
+            FSDP(model.layer1[1], use_orig_params=use_orig_params)
+            if not composable
+            else fully_shard(model.layer1[1])
+        )
         if ignore_modules:
-            wrapped_model = FSDP(
-                model, ignored_modules=[model.layer1], use_orig_params=use_orig_params
+            wrapped_model = (
+                FSDP(
+                    model,
+                    ignored_modules=[model.layer1],
+                    use_orig_params=use_orig_params,
+                )
+                if not composable
+                else fully_shard(model, ignored_modules=[model.layer1])
             )
         else:
-            wrapped_model = FSDP(
-                model,
-                ignored_parameters=list(model.layer1.parameters()),
-                use_orig_params=use_orig_params,
+            wrapped_model = (
+                FSDP(
+                    model,
+                    ignored_states=[model.layer1],
+                    use_orig_params=use_orig_params,
+                )
+                if not composable
+                else fully_shard(
+                    model,
+                    ignored_states=[model.layer1],
+                )
             )
         # Check that the wrapped model's flattened parameter does not include
         # the ignored nested sequential's parameters
@@ -193,9 +268,13 @@ class TestFSDPIgnoredModules(FSDPTest):
         ignored_numel = sum(p.numel() for p in nonwrapped_model.layer1.parameters())
         nonignored_numel = total_numel - ignored_numel
         with FSDP.summon_full_params(wrapped_model):
-            flat_param = wrapped_model.params[0]
+            flat_param = (
+                wrapped_model.params[0]
+                if not composable
+                else _get_module_fsdp_state(wrapped_model).params[0]
+            )
             flat_param_numel = flat_param.numel()
-            if use_orig_params:
+            if composable or use_orig_params:
                 # Subtract the numel contributed from alignment padding
                 padding_numel = sum(
                     numel
@@ -212,24 +291,29 @@ class TestFSDPIgnoredModules(FSDPTest):
         self._train_model(wrapped_model, optim, 3)
 
     @skip_if_lt_x_gpu(2)
-    def test_ignored_modules_invalid(self):
+    @parametrize("composable", [True, False])
+    def test_ignored_modules_invalid(self, composable):
         """Tests that passing an FSDP module as an ignored module or the
         top-level module itself errors."""
         model = Model().cuda()
-        model.layer1 = FSDP(model.layer1)
+        wrap_cls = FSDP if composable else fully_shard
+        model.layer1 = wrap_cls(model.layer1)
         # Passing an FSDP module as an ignored module should error
         with self.assertRaises(
             ValueError,
             msg="`ignored_modules` should not include FSDP modules",
         ):
-            FSDP(model, ignored_modules=[model.layer1])
+            wrap_cls(model, ignored_modules=[model.layer1])
         with self.assertWarnsRegex(
             expected_warning=UserWarning,
             expected_regex="Trying to ignore the top-level module passed into "
            "the FSDP constructor itself will result in all parameters being "
            "ignored",
         ):
-            FSDP(model, ignored_modules=[model])
+            # `fully_shard` does not allow to wrap the same model twice, so create
+            # a new local model here.
+            new_model = Model().cuda()
+            wrap_cls(new_model, ignored_modules=[new_model])
 
     @skip_if_lt_x_gpu(2)
     def test_diff_ignored_modules_across_ranks(self):
@@ -247,16 +331,21 @@ class TestFSDPIgnoredModules(FSDPTest):
             {
                 "pass_ignored_modules_to_root": [False, True],
                 "ignore_modules": [True, False],
+                "composable": [True, False],
             },
             self._test_diff_ignored_modules_across_ranks,
         )
 
     def _test_diff_ignored_modules_across_ranks(
-        self, pass_ignored_modules_to_root: bool, ignore_modules: bool
+        self,
+        pass_ignored_modules_to_root: bool,
+        ignore_modules: bool,
+        composable: bool,
     ):
         # To exercise different `FlatParameter` enumerations across ranks,
         # we wrap `layer3` with FSDP, where `layer3` is registered as a module
         # after `layer1`, which has the variable number of ignored modules
+        wrap_cls = FSDP if composable else fully_shard
         model = ModelWithIgnoredModules(num_ignored=self.rank + 1).cuda()
         layer1_ignored_modules = [
             m for m in model.layer1.modules() if isinstance(m, IgnoredModule)
@@ -265,13 +354,13 @@ class TestFSDPIgnoredModules(FSDPTest):
             {"ignored_modules": layer1_ignored_modules}
             if ignore_modules
             else {
-                "ignored_parameters": {
+                "ignored_states": {
                     p for m in layer1_ignored_modules for p in m.parameters()
                 }
             }
         )
-        model.layer1 = FSDP(model.layer1, **ignore_kwargs)
-        model.layer3 = FSDP(model.layer3)
+        model.layer1 = wrap_cls(model.layer1, **ignore_kwargs)
+        model.layer3 = wrap_cls(model.layer3)
         model_ignored_modules = (
             [m for m in model.modules() if isinstance(m, IgnoredModule)]
             if pass_ignored_modules_to_root
@@ -281,18 +370,21 @@ class TestFSDPIgnoredModules(FSDPTest):
             {"ignored_modules": model_ignored_modules}
             if ignore_modules
             else {
-                "ignored_parameters": {
+                "ignored_states": {
                     p for m in model_ignored_modules for p in m.parameters()
                 }
             }
         )
-        wrapped_model = FSDP(model, **ignore_kwargs_top)
+        wrapped_model = wrap_cls(model, **ignore_kwargs_top)
         optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
         self._train_model(wrapped_model, optim, 3)
 
     @skip_if_lt_x_gpu(2)
     @parametrize("ignore_modules", [True, False])
-    def test_ignored_modules_not_under_wrapped_root(self, ignore_modules: bool):
+    @parametrize("composable", [True, False])
+    def test_ignored_modules_not_under_wrapped_root(
+        self, ignore_modules: bool, composable: bool
+    ):
         model = Model().cuda()
         ignored_modules = list(model.layer1.children())[1:]
 
@@ -300,19 +392,17 @@ class TestFSDPIgnoredModules(FSDPTest):
             {"ignored_modules": ignored_modules}
             if ignore_modules
             else {
-                "ignored_parameters": {
-                    p for m in ignored_modules for p in m.parameters()
-                }
+                "ignored_states": {p for m in ignored_modules for p in m.parameters()}
             }
         )
 
-        model.layer1 = FSDP(
+        wrap_cls = FSDP if composable else fully_shard
+
+        model.layer1 = wrap_cls(
             model.layer1,
-            # sharding_strategy shouldn't matter here.
-            sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
             **ignore_kwargs,
         )
-        model.layer3 = FSDP(
+        model.layer3 = wrap_cls(
             model.layer3,
             # the ignored modules/parameters contains submodule under model.layer1, which
             # is out of the local root model.layer3.

@@ -47,6 +47,9 @@ def fully_shard(
     param_init_fn: Optional[Callable[[nn.Module], None]] = None,
     sync_module_states: bool = False,
     forward_prefetch: bool = False,
+    ignored_states: Union[
+        Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
+    ] = None,
 ) -> nn.Module:
     """
     Applies ``FullyShardedDataParallel`` (FSDP) semantics to ``module``.
@@ -56,7 +59,7 @@ def fully_shard(
     if policy is not None and not isinstance(policy, _FSDPPolicy):
         raise ValueError(f"Expects an `_FSDPPolicy` but got {policy}")
     state = fully_shard.state(module)
-    state = _init_ignored_module_states(state, module, ignored_modules)
+    state = _init_ignored_module_states(state, module, ignored_modules, ignored_states)
     state = _init_device_handle(state, module, state._ignored_params, device_id)
     state = _init_process_group_state(
         state, process_group, ShardingStrategy.FULL_SHARD, policy
@@ -96,7 +99,7 @@ def fully_shard(
     _register_root_pre_forward_hook(state, module)  # prepend last
     for submodule in module.modules():
         if (
-            submodule not in state._ignored_modules
+            submodule in state._fully_sharded_module_to_handles
             and _get_module_state(submodule) is None
         ):
             _insert_module_state(submodule, state)

@@ -245,12 +245,21 @@ def _init_ignored_module_states(
     state: _FSDPState,
     module: nn.Module,
     ignored_modules: Optional[Iterable[torch.nn.Module]],
-    ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
+    ignored_states: Union[
+        Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
+    ] = None,
 ) -> _FSDPState:
     assert (
-        ignored_modules is None or ignored_parameters is None
-    ), "Can not pass `ignored_modules` and `ignored_parameters` at the same time. \
-        Please either pass `ignored_modules` or `ignored_parameters`."
+        ignored_modules is None or ignored_states is None
+    ), "Can not pass `ignored_modules` and `ignored_states` at the same time. \
+        Please either pass `ignored_modules` or `ignored_states`."
+    ignored_parameters = None
+    if ignored_states:
+        ignored_states_set = set(ignored_states)
+        if isinstance(next(iter(ignored_states), None), torch.nn.Parameter):
+            ignored_parameters = ignored_states_set
+        else:
+            ignored_modules = ignored_states_set
     state._ignored_modules = _get_ignored_modules(module, ignored_modules)
     state._ignored_params = _get_ignored_params(
         module,
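
Since the new branch in `_init_ignored_module_states` keys off the first element of `ignored_states`, a single call is expected to pass a homogeneous iterable (all parameters or all modules), and `ignored_modules` cannot be combined with it. A standalone sketch of that dispatch rule, using a hypothetical helper name to mirror the logic above:

import torch.nn as nn


def _split_ignored_states(ignored_states):
    # Mirrors the branch added above: the type of the first element decides how
    # the whole iterable is interpreted (parameters vs. modules).
    ignored_parameters = None
    ignored_modules = None
    if ignored_states:
        ignored_states_set = set(ignored_states)
        if isinstance(next(iter(ignored_states), None), nn.Parameter):
            ignored_parameters = ignored_states_set
        else:
            ignored_modules = ignored_states_set
    return ignored_parameters, ignored_modules


layer = nn.Linear(4, 4)

params, modules = _split_ignored_states(list(layer.parameters()))
assert params is not None and modules is None  # treated as ignored parameters

params, modules = _split_ignored_states([layer])
assert params is None and modules == {layer}  # treated as ignored modules
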
@@ -659,7 +668,7 @@ def _get_ignored_modules(
     for module in ignored_root_modules:
         if not isinstance(module, torch.nn.Module):
             raise TypeError(msg_prefix + f"but got an iterable with {type(module)}")
-        if isinstance(module, fsdp_file.FullyShardedDataParallel):
+        if _get_module_fsdp_state(module):
             # TODO: We may relax this by taking the FSDP instance's wrapped
             # module to provide more flexibility to the user.
             raise ValueError("`ignored_modules` should not include FSDP modules")
@@ -716,12 +725,12 @@ def _get_ignored_params(
         }
         all_ignored_params.update(params_in_ignored_parameters)
 
-    # Include nested FSDP modules' ignored parameters
+    # Always include nested FSDP modules' ignored parameters
     for submodule in root_module.modules():
         optional_fsdp_state = _get_module_fsdp_state(submodule)
         if optional_fsdp_state is not None:
             assert hasattr(optional_fsdp_state, "_ignored_params")
             all_ignored_params.update(optional_fsdp_state._ignored_params)
 
     return all_ignored_params
 

@@ -385,11 +385,13 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         forward_prefetch: bool = False,
         limit_all_gathers: bool = False,
         use_orig_params: bool = False,
-        ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
+        ignored_states: Union[
+            Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
+        ] = None,
     ):
         torch._C._log_api_usage_once("torch.distributed.fsdp")
         super().__init__()
-        _init_ignored_module_states(self, module, ignored_modules, ignored_parameters)
+        _init_ignored_module_states(self, module, ignored_modules, ignored_states)
         _init_device_handle(self, module, self._ignored_params, device_id)
 
         # Add module annotations for Dynamo support (see function for details)