Add fqn_modifiers support to load_state_dict, with a unit test (#146557)

In fusion models (e.g. a FusionEmbedding), users might rename state_dict keys via a state_dict hook.
The load_state_dict APIs here do not call model.state_dict(), so those hooks never run to change the keys, causing a mismatch between the FQNs and the state_dict keys.

This PR lets users declare how they change the state_dict key prefix through an attribute on the module (the attribute name is configurable; by default it is "_fqn_modifiers").
While loading the state_dict, the same prefix change is applied when computing each FQN, so keys are processed the same way as they would be through the state_dict hook.

For example, suppose there is a state_dict hook:

```
def _state_dict_hook(self, destination, prefix, keep_vars):
    """Remove "embedding" from the original embedding in the state_dict
    name. This keeps the original state dict name for the embedding
    from before fusing with the FusionEmbedding.

    [!Note] This update changes the order of the OrderedDict
    """
    key = prefix + "embedding.weight"
    new_key = prefix + "weight"
    destination[new_key] = destination[key]
    del destination[key]
```
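
With this hook in place, model.state_dict() reports the key "weight" while the module hierarchy still yields the FQN "embedding.weight". A minimal self-contained sketch of that mismatch (the Fusion class and sizes here are illustrative, not part of the PR):

```
import torch.nn as nn


class Fusion(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embedding = nn.Embedding(4, 4)
        self._register_state_dict_hook(Fusion._state_dict_hook)

    def _state_dict_hook(self, destination, prefix, keep_vars):
        # Same rename as in the hook above: drop the "embedding." prefix from the key.
        destination[prefix + "weight"] = destination.pop(prefix + "embedding.weight")


model = Fusion()
print(list(model.state_dict()))                        # ['weight']            <- key after the hook runs
print([name for name, _ in model.named_parameters()])  # ['embedding.weight']  <- FQN from the module tree
```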

In DSD after this PR, we skip "embedding." before "weight" if the "_fqn_modifiers" attribute is found on that module:
```
def _fqn_modifiers(self) -> Dict[str, str]:
    return {
        "weight": "embedding",
    }
```
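
A usage sketch of the resulting behavior (it reuses the FusionEmbeddingWithHook / FusionEmbeddingWithModifier helpers added to the test utilities in the diff below, and omits the distributed setup the unit test runs under): a state_dict saved from the hooked model loads into a model that only declares `_fqn_modifiers`, and the attribute name can also be spelled out via `StateDictOptions(dsd_fqn_modifiers=...)`.

```
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    set_model_state_dict,
)
from torch.testing._internal.distributed.common_state_dict import (
    FusionEmbeddingWithHook,
    FusionEmbeddingWithModifier,
)

source = FusionEmbeddingWithHook(4, 4, 4)
state_dict = get_model_state_dict(source)  # keys look like "weight": the hook renamed them

# The target declares the same rename via _fqn_modifiers, so DSD can match the keys again.
target = FusionEmbeddingWithModifier(4, 4, 4)
set_model_state_dict(target, state_dict)

# Equivalent, spelling out the (configurable) attribute name explicitly.
set_model_state_dict(
    target, state_dict, options=StateDictOptions(dsd_fqn_modifiers="_fqn_modifiers")
)
```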

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146557
Approved by: https://github.com/fegin
mori360 2025-02-18 22:54:41 +00:00 committed by PyTorch MergeBot
parent 7622e29a37
commit a21a123fd5
3 changed files with 97 additions and 7 deletions


@@ -47,7 +47,12 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
     MultiProcessTestCase,
     with_comms,
 )
-from torch.testing._internal.distributed.common_state_dict import VerifyStateDictMixin
+from torch.testing._internal.distributed.common_state_dict import (
+    FusionEmbedding,
+    FusionEmbeddingWithHook,
+    FusionEmbeddingWithModifier,
+    VerifyStateDictMixin,
+)
 from torch.utils._pytree import tree_all, tree_all_only
@@ -919,6 +924,20 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
             ),
         )
+
+    @with_comms
+    @skip_if_lt_x_gpu(2)
+    def test_state_dict_with_hook_on_keys(self) -> None:
+        with torch.device("meta"):
+            metamodel = FusionEmbedding(4, 4, 4)
+        with torch.device("cuda"):
+            gpumodel = FusionEmbeddingWithHook(4, 4, 4)
+        gpumodel_state_dict = get_model_state_dict(gpumodel)
+        with self.assertRaisesRegex(RuntimeError, "Missing key"):
+            set_model_state_dict(metamodel, gpumodel_state_dict)
+        with torch.device("meta"):
+            metamodel_modified = FusionEmbeddingWithModifier(4, 4, 4)
+        set_model_state_dict(metamodel_modified, gpumodel_state_dict)
 
 
 class TestNoComm(MultiProcessTestCase):
     def setUp(self) -> None:


@@ -134,6 +134,7 @@ class StateDictOptions:
     strict: bool = True
     broadcast_from_rank0: bool = False
     flatten_optimizer_state_dict: bool = False
+    dsd_fqn_modifiers: str = "_fqn_modifiers"
 
 
 @dataclass
@@ -155,6 +156,7 @@ class _StateDictInfo(StateDictOptions):
 def _get_fqns(
     model: nn.Module,
     name: str,
+    dsd_fqn_modifiers: str = "_fqn_modifiers",
     skip_ddp_prefix: bool = True,
     skip_compiler_prefix: bool = True,
 ) -> FQNS_T:
@@ -204,6 +206,14 @@ def _get_fqns(
             if not skip_compiler_prefix:
                 fqn_obj_names.append(curr_obj_name)
         else:
+            # In some modules, _fqn_modifiers would not show up in the state_dict keys;
+            # skip them in the fqn so that load_state_dict succeeds for those modules.
+            if hasattr(curr_obj, dsd_fqn_modifiers):
+                if removed_fqn := getattr(curr_obj, dsd_fqn_modifiers)().get(
+                    curr_obj_name
+                ):
+                    if hasattr(curr_obj, removed_fqn):
+                        curr_obj = getattr(curr_obj, removed_fqn)
             fqn_obj_names.append(curr_obj_name)
             if curr_obj_name == nn.modules.module._EXTRA_STATE_KEY_SUFFIX:
                 if i != len(obj_names) - 1:
@@ -218,7 +228,7 @@ class _EXTRA_STATE:
     pass
 
 
-def _iterate_valid_model_state(model):
+def _iterate_valid_model_state(model, dsd_fqn_modifiers="_fqn_modifiers"):
     visited_modules: set[nn.Module] = set()
 
     def recurse(module: nn.Module, curr_fqn: str) -> Generator:
@@ -228,7 +238,16 @@ def _iterate_valid_model_state(model):
         for name, submodule in module.named_children():
             if submodule in visited_modules:
                 continue
-            new_fqn = f"{curr_fqn}{name}"
+            # If users have state_dict hooks in their model, they can record the key changes
+            # under dsd_fqn_modifiers so the FQNs match what the state_dict hook produces.
+            if (
+                hasattr(module, dsd_fqn_modifiers)
+                and name in getattr(module, dsd_fqn_modifiers)().values()
+            ):
+                # Skip the modifier name here, and drop the trailing `.` that was just added.
+                new_fqn = curr_fqn[:-1]
+            else:
+                new_fqn = f"{curr_fqn}{name}"
             yield from recurse(submodule, new_fqn)
 
         for name, obj in chain(
@@ -527,10 +546,14 @@ def _load_model_state_dict(
         return _IncompatibleKeys({}, {})
 
     local_state_dict = {}
-    for key, value in _iterate_valid_model_state(model):
-        fqns = _get_fqns(model, key)
+    for key, value in _iterate_valid_model_state(model, info.dsd_fqn_modifiers):
+        fqns = _get_fqns(model, key, info.dsd_fqn_modifiers)
         fqns_with_prefix = _get_fqns(
-            model, key, skip_ddp_prefix=False, skip_compiler_prefix=False
+            model,
+            key,
+            info.dsd_fqn_modifiers,
+            skip_ddp_prefix=False,
+            skip_compiler_prefix=False,
         )
 
         for fqn, fqn_with_prefix in zip(fqns, fqns_with_prefix):


@@ -4,7 +4,7 @@
 import copy
 from itertools import chain
-from typing import Any
+from typing import Any, Dict
 
 import torch
 import torch.nn as nn
@@ -120,3 +120,51 @@ class VerifyStateDictMixin:
             optim_state_dict=new_dist_osd,
         )
         self.assertEqual(optim.state_dict(), new_optim.state_dict())
+
+
+class FusionEmbedding(nn.Module):
+    def __init__(self, vocab_size: int, fusion_vocab_size: int, embed_dim: int) -> None:
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_dim)
+        self.fusion_embedding = nn.Embedding(fusion_vocab_size, embed_dim)
+
+
+class FusionEmbeddingWithHook(nn.Module):
+    def __init__(self, vocab_size: int, fusion_vocab_size: int, embed_dim: int) -> None:
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_dim)
+        self.fusion_embedding = nn.Embedding(fusion_vocab_size, embed_dim)
+        self._register_state_dict_hook(FusionEmbeddingWithHook._state_dict_hook)
+        self._register_load_state_dict_pre_hook(
+            FusionEmbeddingWithHook._load_state_dict_hook, with_module=True
+        )
+
+    def _state_dict_hook(self, destination, prefix, keep_vars):
+        """Remove "embedding" from the original embedding in the state_dict
+        name. This keeps the original state dict name for the embedding
+        from before fusing with the FusionEmbedding.
+        """
+        key = prefix + "embedding.weight"
+        new_key = prefix + "weight"
+        destination[new_key] = destination[key]
+        del destination[key]
+
+    def _load_state_dict_hook(self, state_dict, prefix, *args, **kwargs):
+        """Apply an extra "embedding" prefix to the state_dict key to
+        account for the FusionEmbedding wrapping.
+        """
+        if state_dict:
+            key = prefix + "weight"
+            new_key = prefix + "embedding.weight"
+            state_dict[new_key] = state_dict[key]
+            del state_dict[key]
+
+
+class FusionEmbeddingWithModifier(FusionEmbeddingWithHook):
+    # _fqn_modifiers is a private method that acts as a contract with DSD. When users change the
+    # state_dict keys, they provide a mapping from the new key to the original key here, which is
+    # used to keep the state_dict keys and FQNs consistent.
+    def _fqn_modifiers(self) -> Dict[str, str]:
+        return {
+            "weight": "embedding",
+        }