[FSDP][Reland] Include duplicate parameters and modules when calling named_parameters and named_modules (#99448)

By default, `named_parameters` and `named_modules` remove duplicated parameters and modules. FSDP, however, needs to know which parameters are shared, so it must call these APIs with `remove_duplicate=False`. Without `remove_duplicate=False`, FSDP cannot discover shared weights in some cases (e.g., when the shared weights live in the same module or when modules themselves are shared). The previous PR was reverted because some modules override the signature of `named_parameters()`; this new PR adds a workaround for that case.

Differential Revision: [D45065973](https://our.internmc.facebook.com/intern/diff/D45065973/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99448
Approved by: https://github.com/zhaojuanmao
This commit is contained in:
parent 0eb59ad093
commit 3de7fd461a
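For context (not part of the diff), a minimal standalone sketch of why `remove_duplicate=False` matters; the module and attribute names below are made up for illustration. With the default `named_parameters()`, a parameter registered under two attribute names is reported only once, so the aliasing is invisible; with `remove_duplicate=False`, both names show up and the sharing can be detected.

import torch
import torch.nn as nn

class SharedParamModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4))
        self.weight_alias = self.weight  # the same Parameter registered under a second name

m = SharedParamModule()
# Default behavior deduplicates by identity: only one name is reported.
print([name for name, _ in m.named_parameters()])                        # ['weight']
# With remove_duplicate=False, both names are visible, so sharing can be discovered.
print([name for name, _ in m.named_parameters(remove_duplicate=False)])  # ['weight', 'weight_alias']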
@@ -1103,6 +1103,33 @@ class TestFSDPStateDict(FSDPTest):
             else:
                 self.assertEqual(v, state_dict[k])

+    @skip_if_lt_x_gpu(2)
+    def test_shared_module_and_shared_parameter(self):
+        class TestDummyModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                torch.manual_seed(0)
+                self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
+                self.net2 = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
+                self.net3 = self.net2
+                self.random_parameter = nn.Parameter(torch.Tensor(10))
+                self.shared_parameter = self.random_parameter
+
+            def forward(self, x):
+                return self.net3(self.net2(self.net1(x)))
+
+            def get_input(self):
+                return torch.rand(8, 8, device="cuda")
+
+        model = FSDP(TestDummyModel().cuda())
+        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
+            state_dict = model.state_dict()
+        self.assertEqual(
+            state_dict["random_parameter"], state_dict["shared_parameter"]
+        )
+        self.assertEqual(state_dict["net2.0.bias"], state_dict["net3.0.bias"])
+        self.assertEqual(state_dict["net2.0.weight"], state_dict["net3.0.weight"])
+

 instantiate_parametrized_tests(TestFSDPStateDict)
@@ -6,6 +6,7 @@ import traceback
 import warnings
 from enum import auto, Enum
 from typing import (
+    Any,
     Callable,
     Dict,
     Generator,
@@ -14,6 +15,7 @@ from typing import (
     no_type_check,
     Optional,
     Set,
+    Tuple,
 )

 import torch
@@ -184,6 +186,25 @@ def _is_fsdp_flattened(tensor: torch.Tensor) -> bool:
     return getattr(tensor, FSDP_FLATTENED, False)


+def _named_parameters_with_duplicates(
+    module: nn.Module, **kwargs: Any
+) -> List[Tuple[str, nn.Parameter]]:
+    """
+    This API is required as some modules overwrite `named_parameters()` but do not support
+    `remove_duplicate`.
+    """
+    assert (
+        "remove_duplicate" not in kwargs
+    ), "_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument."
+    kwargs["remove_duplicate"] = False
+    try:
+        ret = list(module.named_parameters(**kwargs))
+    except AssertionError as e:
+        kwargs.pop("remove_duplicate")
+        ret = list(module.named_parameters(**kwargs))
+    return ret
+
+
 def _get_param_to_fqns(
     model: torch.nn.Module,
     dedup_shared_params: bool = True,
@@ -205,7 +226,9 @@ def _get_param_to_fqns(
     """

     def module_fn(module, prefix, tree_level, param_to_fqns):
-        for param_name, param in module.named_parameters(recurse=False):
+        for param_name, param in _named_parameters_with_duplicates(
+            module, recurse=False
+        ):
             local_fqns = (
                 param._fqns
                 if type(param) is flat_param_file.FlatParameter
@@ -247,7 +270,7 @@ def _get_param_to_fqns(
         model,
         module_fn,
         return_fn,
-        [key for key, _ in model.named_parameters()],
+        [key for key, _ in _named_parameters_with_duplicates(model)],
         param_to_unflat_param_names,
     )
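For reference, a rough usage sketch (not part of the diff) of the fallback path in `_named_parameters_with_duplicates` above, assuming a PyTorch build that includes this helper. `LegacyOverride` is a made-up module standing in for code that overrides `named_parameters()` without supporting `remove_duplicate`; its override raises an AssertionError, so the helper retries without the argument instead of failing.

import torch
import torch.nn as nn
from torch.distributed.fsdp._common_utils import _named_parameters_with_duplicates

class LegacyOverride(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4))

    def named_parameters(self, prefix="", recurse=True, remove_duplicate=True):
        # Hypothetical override that only supports the default deduplicating behavior.
        assert remove_duplicate, "remove_duplicate=False is not supported here"
        yield from super().named_parameters(prefix, recurse)

# The helper first tries remove_duplicate=False; the AssertionError from the override
# triggers the fallback call without the argument, so the parameters are still returned.
print([name for name, _ in _named_parameters_with_duplicates(LegacyOverride())])  # ['weight']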
@@ -29,6 +29,7 @@ from torch.distributed.fsdp._common_utils import (
     _FSDPState,
     _get_module_fsdp_state,
     _is_fsdp_flattened,
+    _named_parameters_with_duplicates,
     clean_tensor_name,
     TrainingState,
 )
@@ -571,7 +572,8 @@ def _get_state_names_for_states(
     param_names: List[str] = []
     buffer_names: List[str] = []
     param_to_param_name = {
-        param: param_name for param_name, param in module.named_parameters()
+        param: param_name
+        for param_name, param in _named_parameters_with_duplicates(module)
     }
     buffer_to_buffer_name = {
         buffer: buffer_name for buffer_name, buffer in module.named_buffers()
@@ -992,7 +994,7 @@ def _check_orig_params_flattened(
     ``fsdp_module``. This should be called as a sanity check after flattening
     the wrapped module's parameters.
     """
-    for param_name, param in fsdp_module.named_parameters():
+    for param_name, param in _named_parameters_with_duplicates(fsdp_module):
         if param not in ignored_params and not _is_fsdp_flattened(param):
             raise RuntimeError(
                 f"Found an unflattened parameter: {param_name}; "
@@ -28,6 +28,7 @@ from torch.distributed.fsdp._common_utils import (
     _get_module_fsdp_state_if_fully_sharded_module,
     _get_param_to_fqns,
     _module_handles,
+    _named_parameters_with_duplicates,
     clean_tensor_name,
 )
 from torch.distributed.fsdp._fsdp_extensions import _ext_chunk_tensor
@@ -977,7 +978,9 @@ def _get_param_id_to_param_from_optim_input(

 def _get_flat_param_to_fqn(model: torch.nn.Module) -> Dict[nn.Parameter, str]:
     def module_fn(module, prefix, tree_level, flat_param_to_fqn):
-        for param_name, param in module.named_parameters(recurse=False):
+        for param_name, param in _named_parameters_with_duplicates(
+            module, recurse=False
+        ):
             if type(param) is not FlatParameter:
                 continue
             fqn = clean_tensor_name(prefix + param_name)
@@ -991,7 +994,7 @@ def _get_flat_param_to_fqn(model: torch.nn.Module) -> Dict[nn.Parameter, str]:
         model,
         module_fn,
         return_fn,
-        [fqn for fqn, _ in model.named_parameters()],
+        [fqn for fqn, _ in _named_parameters_with_duplicates(model)],
         flat_param_to_fqn_ret,
     )
@@ -1015,7 +1018,7 @@ def _get_param_key_to_param(
             param_to_fqns is not None and flat_param_to_fqn is not None
         ), "The optimizer is a NamedOptimizer, `param_to_fqns` must not be None."
         assert model is not None
-        for key, _ in model.named_parameters():
+        for key, _ in _named_parameters_with_duplicates(model):
             clean_fqn_to_curr_fqn[clean_tensor_name(key)] = key

     param_key_to_param: Dict[Union[str, int], nn.Parameter] = {}
@@ -1444,7 +1447,7 @@ def _get_fqn_to_fsdp_param_info(model: nn.Module) -> Dict[str, FSDPParamInfo]:
         model,
         module_fn,
         return_fn,
-        [fqn for fqn, _ in model.named_parameters()],
+        [fqn for fqn, _ in _named_parameters_with_duplicates(model)],
         fqn_to_param_info,
     )
@@ -8,6 +8,7 @@ from itertools import accumulate, chain
 from typing import (
     Any,
     Callable,
+    cast,
     Dict,
     Generator,
     Iterator,
@@ -28,6 +29,7 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.distributed._tensor import DTensor
 from torch.distributed.fsdp._common_utils import (
+    _named_parameters_with_duplicates,
     _set_fsdp_flattened,
     HandleTrainingState,
 )
@@ -523,8 +525,10 @@ class FlatParamHandle:
         param_extensions: List[Any] = []
         is_padding_mask: List[bool] = []
         total_numel = total_numel_without_padding = 0
-        for submodule_name, submodule in module.named_modules():
-            for param_name, param in submodule.named_parameters(recurse=False):
+        for submodule_name, submodule in module.named_modules(remove_duplicate=False):
+            for param_name, param in _named_parameters_with_duplicates(
+                submodule, recurse=False
+            ):
                 if param not in params_set:
                     continue
                 if param in shared_param_memo:  # shared reference
@@ -553,7 +557,8 @@ class FlatParamHandle:
                     is_padding_mask.append(True)
                     numels.append(numel_to_pad)
                     total_numel += numel_to_pad
-                param, extension = _ext_pre_flatten_transform(param)
+                transform_t, extension = _ext_pre_flatten_transform(param)
+                param = cast(nn.Parameter, transform_t)
                 param_extensions.append(extension)
                 shared_param_memo[param] = (submodule, submodule_name, param_name)
                 params_to_flatten.append(param)