[FSDP][Reland] Include duplicate parameters and modules when calling named_parameters and named_modules (#99448)

By default, `named_parameters` and `named_modules` remove duplicated parameters and modules from their output. However, FSDP needs to know which parameters are shared, so it must call these APIs with `remove_duplicate=False`. Without `remove_duplicate=False`, FSDP cannot discover shared weights in some cases (e.g., when the shared weights live in the same module or when entire modules are shared).
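
For illustration only (not part of this PR; the `Toy` module below is a hypothetical example), a minimal sketch of the difference:

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight_a = nn.Parameter(torch.zeros(4))
        # Register the same Parameter object under a second name (a shared weight).
        self.weight_b = self.weight_a

toy = Toy()

# Default (remove_duplicate=True): the shared tensor is reported only once,
# so the aliasing between weight_a and weight_b is invisible.
print([name for name, _ in toy.named_parameters()])
# ['weight_a']

# remove_duplicate=False: every name is reported, so the sharing is visible.
print([name for name, _ in toy.named_parameters(remove_duplicate=False)])
# ['weight_a', 'weight_b']
```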

The previous PR was reverted because some modules override `named_parameters()` with a signature that does not accept `remove_duplicate`. This PR adds a workaround for that case: try `remove_duplicate=False` first and fall back to the original call if the module rejects it.
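
A rough sketch of the failure mode and the fallback, assuming a PyTorch build that includes the `_named_parameters_with_duplicates` helper added below; `LegacyModule` is hypothetical, not a real offender:

```python
import torch.nn as nn
from torch.distributed.fsdp._common_utils import _named_parameters_with_duplicates

class LegacyModule(nn.Module):
    """Hypothetical module whose overridden named_parameters() rejects extra kwargs."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

    def named_parameters(self, prefix: str = "", recurse: bool = True, **kwargs):
        # An override written before `remove_duplicate` existed might assert like this.
        assert not kwargs, f"unexpected kwargs: {kwargs}"
        return super().named_parameters(prefix=prefix, recurse=recurse)

# The helper first tries remove_duplicate=False; the AssertionError above triggers
# the fallback, which retries without the keyword and still returns the parameters.
print([name for name, _ in _named_parameters_with_duplicates(LegacyModule())])
# ['linear.weight', 'linear.bias']
```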

Differential Revision: [D45065973](https://our.internmc.facebook.com/intern/diff/D45065973/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99448
Approved by: https://github.com/zhaojuanmao
Chien-Chin Huang 2023-04-24 13:42:12 -07:00 committed by PyTorch MergeBot
parent 0eb59ad093
commit 3de7fd461a
5 changed files with 71 additions and 11 deletions


@@ -1103,6 +1103,33 @@ class TestFSDPStateDict(FSDPTest):
                 else:
                     self.assertEqual(v, state_dict[k])
+
+    @skip_if_lt_x_gpu(2)
+    def test_shared_module_and_shared_parameter(self):
+        class TestDummyModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                torch.manual_seed(0)
+                self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
+                self.net2 = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
+                self.net3 = self.net2
+                self.random_parameter = nn.Parameter(torch.Tensor(10))
+                self.shared_parameter = self.random_parameter
+
+            def forward(self, x):
+                return self.net3(self.net2(self.net1(x)))
+
+            def get_input(self):
+                return torch.rand(8, 8, device="cuda")
+
+        model = FSDP(TestDummyModel().cuda())
+        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
+            state_dict = model.state_dict()
+            self.assertEqual(
+                state_dict["random_parameter"], state_dict["shared_parameter"]
+            )
+            self.assertEqual(state_dict["net2.0.bias"], state_dict["net3.0.bias"])
+            self.assertEqual(state_dict["net2.0.weight"], state_dict["net3.0.weight"])


 instantiate_parametrized_tests(TestFSDPStateDict)


@@ -6,6 +6,7 @@ import traceback
 import warnings
 from enum import auto, Enum
 from typing import (
+    Any,
     Callable,
     Dict,
     Generator,
@@ -14,6 +15,7 @@ from typing import (
     no_type_check,
     Optional,
     Set,
+    Tuple,
 )

 import torch
@@ -184,6 +186,25 @@ def _is_fsdp_flattened(tensor: torch.Tensor) -> bool:
     return getattr(tensor, FSDP_FLATTENED, False)


+def _named_parameters_with_duplicates(
+    module: nn.Module, **kwargs: Any
+) -> List[Tuple[str, nn.Parameter]]:
+    """
+    This API is required as some modules overwrite `named_parameters()` but do not support
+    `remove_duplicate`.
+    """
+    assert (
+        "remove_duplicate" not in kwargs
+    ), "_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument."
+    kwargs["remove_duplicate"] = False
+    try:
+        ret = list(module.named_parameters(**kwargs))
+    except AssertionError as e:
+        kwargs.pop("remove_duplicate")
+        ret = list(module.named_parameters(**kwargs))
+    return ret
+
+
 def _get_param_to_fqns(
     model: torch.nn.Module,
     dedup_shared_params: bool = True,
@@ -205,7 +226,9 @@ def _get_param_to_fqns(
     """

     def module_fn(module, prefix, tree_level, param_to_fqns):
-        for param_name, param in module.named_parameters(recurse=False):
+        for param_name, param in _named_parameters_with_duplicates(
+            module, recurse=False
+        ):
             local_fqns = (
                 param._fqns
                 if type(param) is flat_param_file.FlatParameter
@@ -247,7 +270,7 @@ def _get_param_to_fqns(
         model,
         module_fn,
         return_fn,
-        [key for key, _ in model.named_parameters()],
+        [key for key, _ in _named_parameters_with_duplicates(model)],
         param_to_unflat_param_names,
     )


@@ -29,6 +29,7 @@ from torch.distributed.fsdp._common_utils import (
     _FSDPState,
     _get_module_fsdp_state,
     _is_fsdp_flattened,
+    _named_parameters_with_duplicates,
     clean_tensor_name,
     TrainingState,
 )
@@ -571,7 +572,8 @@ def _get_state_names_for_states(
     param_names: List[str] = []
     buffer_names: List[str] = []
     param_to_param_name = {
-        param: param_name for param_name, param in module.named_parameters()
+        param: param_name
+        for param_name, param in _named_parameters_with_duplicates(module)
     }
     buffer_to_buffer_name = {
         buffer: buffer_name for buffer_name, buffer in module.named_buffers()
@@ -992,7 +994,7 @@ def _check_orig_params_flattened(
     ``fsdp_module``. This should be called as a sanity check after flattening
     the wrapped module's parameters.
     """
-    for param_name, param in fsdp_module.named_parameters():
+    for param_name, param in _named_parameters_with_duplicates(fsdp_module):
         if param not in ignored_params and not _is_fsdp_flattened(param):
             raise RuntimeError(
                 f"Found an unflattened parameter: {param_name}; "


@@ -28,6 +28,7 @@ from torch.distributed.fsdp._common_utils import (
     _get_module_fsdp_state_if_fully_sharded_module,
     _get_param_to_fqns,
     _module_handles,
+    _named_parameters_with_duplicates,
     clean_tensor_name,
 )
 from torch.distributed.fsdp._fsdp_extensions import _ext_chunk_tensor
@@ -977,7 +978,9 @@ def _get_param_id_to_param_from_optim_input(

 def _get_flat_param_to_fqn(model: torch.nn.Module) -> Dict[nn.Parameter, str]:
     def module_fn(module, prefix, tree_level, flat_param_to_fqn):
-        for param_name, param in module.named_parameters(recurse=False):
+        for param_name, param in _named_parameters_with_duplicates(
+            module, recurse=False
+        ):
             if type(param) is not FlatParameter:
                 continue
             fqn = clean_tensor_name(prefix + param_name)
@@ -991,7 +994,7 @@ def _get_flat_param_to_fqn(model: torch.nn.Module) -> Dict[nn.Parameter, str]:
         model,
         module_fn,
         return_fn,
-        [fqn for fqn, _ in model.named_parameters()],
+        [fqn for fqn, _ in _named_parameters_with_duplicates(model)],
         flat_param_to_fqn_ret,
     )
@@ -1015,7 +1018,7 @@ def _get_param_key_to_param(
             param_to_fqns is not None and flat_param_to_fqn is not None
         ), "The optimizer is a NamedOptimizer, `param_to_fqns` must not be None."
         assert model is not None
-        for key, _ in model.named_parameters():
+        for key, _ in _named_parameters_with_duplicates(model):
            clean_fqn_to_curr_fqn[clean_tensor_name(key)] = key

     param_key_to_param: Dict[Union[str, int], nn.Parameter] = {}
@@ -1444,7 +1447,7 @@ def _get_fqn_to_fsdp_param_info(model: nn.Module) -> Dict[str, FSDPParamInfo]:
         model,
         module_fn,
         return_fn,
-        [fqn for fqn, _ in model.named_parameters()],
+        [fqn for fqn, _ in _named_parameters_with_duplicates(model)],
         fqn_to_param_info,
     )


@@ -8,6 +8,7 @@ from itertools import accumulate, chain
 from typing import (
     Any,
     Callable,
+    cast,
     Dict,
     Generator,
     Iterator,
@@ -28,6 +29,7 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.distributed._tensor import DTensor
 from torch.distributed.fsdp._common_utils import (
+    _named_parameters_with_duplicates,
     _set_fsdp_flattened,
     HandleTrainingState,
 )
@@ -523,8 +525,10 @@ class FlatParamHandle:
         param_extensions: List[Any] = []
         is_padding_mask: List[bool] = []
         total_numel = total_numel_without_padding = 0
-        for submodule_name, submodule in module.named_modules():
-            for param_name, param in submodule.named_parameters(recurse=False):
+        for submodule_name, submodule in module.named_modules(remove_duplicate=False):
+            for param_name, param in _named_parameters_with_duplicates(
+                submodule, recurse=False
+            ):
                 if param not in params_set:
                     continue
                 if param in shared_param_memo:  # shared reference
@@ -553,7 +557,8 @@ class FlatParamHandle:
                             is_padding_mask.append(True)
                             numels.append(numel_to_pad)
                             total_numel += numel_to_pad
-                    param, extension = _ext_pre_flatten_transform(param)
+                    transform_t, extension = _ext_pre_flatten_transform(param)
+                    param = cast(nn.Parameter, transform_t)
                     param_extensions.append(extension)
                     shared_param_memo[param] = (submodule, submodule_name, param_name)
                     params_to_flatten.append(param)