Add fqn_modifier at loading_state_dict and unit test (#146557)
In fusion models, users may change state_dict keys via a state_dict hook.
The load_state_dict APIs here do not call model.state_dict(), so those hooks are never invoked to change the keys, causing a mismatch between the fully qualified names (FQNs) and the state_dict keys.
This PR lets users declare how they change the state_dict key prefixes via a method on the module (the lookup name is configurable; by default we look for "_fqn_modifiers").
While loading a state_dict, we apply the same prefix change when computing the FQN, so the keys are processed as if they had gone through the state_dict hook.
For example:
There's a state_dict_hook:
```
def _state_dict_hook(self, destination, prefix, keep_vars):
    """Remove "embedding" from the original embedding in the state_dict
    name. This keeps the original state dict name for the embedding
    from before fusing with the FusionEmbedding.

    [!Note] This update changes the order of the OrderedDict.
    """
    key = prefix + "embedding.weight"
    new_key = prefix + "weight"
    destination[new_key] = destination[key]
    del destination[key]
```
With this PR, distributed state dict (DSD) skips the "embedding." prefix before "weight" when it finds the "_fqn_modifiers" attribute on that module:
```
def _fqn_modifiers(self) -> Dict[str, str]:
    return {
        "weight": "embedding",
    }
```
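To see how the lookup works end to end, here is a minimal sketch, not the actual `_get_fqns` implementation; `resolve` and `ToyFusionEmbedding` are names invented for illustration. It walks a hook-renamed key through the module tree, descending into the attribute the hook hid whenever a key segment appears in `_fqn_modifiers`:
```
from typing import Dict

import torch.nn as nn


class ToyFusionEmbedding(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embedding = nn.Embedding(4, 4)

    def _fqn_modifiers(self) -> Dict[str, str]:
        # state_dict key -> attribute name the state_dict hook removed from it
        return {"weight": "embedding"}


def resolve(module: nn.Module, key: str):
    # Walk each dotted segment of the key; when a segment is listed in
    # _fqn_modifiers, first descend into the attribute the hook hid.
    obj = module
    for part in key.split("."):
        modifiers = getattr(obj, "_fqn_modifiers", None)
        if modifiers is not None and (hidden := modifiers().get(part)):
            obj = getattr(obj, hidden)
        obj = getattr(obj, part)
    return obj


model = ToyFusionEmbedding()
# The hook-renamed key "weight" resolves to the real parameter:
assert resolve(model, "weight") is model.embedding.weight
```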
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146557
Approved by: https://github.com/fegin
This commit is contained in: commit a21a123fd5 (parent 7622e29a37)
```
@@ -47,7 +47,12 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
     MultiProcessTestCase,
     with_comms,
 )
-from torch.testing._internal.distributed.common_state_dict import VerifyStateDictMixin
+from torch.testing._internal.distributed.common_state_dict import (
+    FusionEmbedding,
+    FusionEmbeddingWithHook,
+    FusionEmbeddingWithModifier,
+    VerifyStateDictMixin,
+)
 from torch.utils._pytree import tree_all, tree_all_only
```
```
@@ -919,6 +924,20 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
             ),
         )
 
+    @with_comms
+    @skip_if_lt_x_gpu(2)
+    def test_state_dict_with_hook_on_keys(self) -> None:
+        with torch.device("meta"):
+            metamodel = FusionEmbedding(4, 4, 4)
+        with torch.device("cuda"):
+            gpumodel = FusionEmbeddingWithHook(4, 4, 4)
+        gpumodel_state_dict = get_model_state_dict(gpumodel)
+        with self.assertRaisesRegex(RuntimeError, "Missing key"):
+            set_model_state_dict(metamodel, gpumodel_state_dict)
+        with torch.device("meta"):
+            metamodel_modified = FusionEmbeddingWithModifier(4, 4, 4)
+        set_model_state_dict(metamodel_modified, gpumodel_state_dict)
+
 
 class TestNoComm(MultiProcessTestCase):
     def setUp(self) -> None:
```
```
@@ -134,6 +134,7 @@ class StateDictOptions:
     strict: bool = True
     broadcast_from_rank0: bool = False
     flatten_optimizer_state_dict: bool = False
+    dsd_fqn_modifiers: str = "_fqn_modifiers"
 
 
 @dataclass
```
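As a usage sketch of the new option: a model that stores its key-change mapping under a custom method name can point DSD at it. `MyModel` and `_my_key_map` are hypothetical names invented here; the DSD APIs are the real ones this PR touches, and the round trip below assumes they accept a plain (non-sharded) module.
```
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    set_model_state_dict,
)


class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 4)

    # Hypothetical custom name for the key-change mapping (empty here, since
    # this toy model has no state_dict hook renaming keys).
    def _my_key_map(self):
        return {}


model = MyModel()
# Tell DSD to look up the mapping under "_my_key_map" instead of the default:
options = StateDictOptions(dsd_fqn_modifiers="_my_key_map")
set_model_state_dict(model, get_model_state_dict(model), options=options)
```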
```
@@ -155,6 +156,7 @@ class _StateDictInfo(StateDictOptions):
 def _get_fqns(
     model: nn.Module,
     name: str,
+    dsd_fqn_modifiers: str = "_fqn_modifiers",
     skip_ddp_prefix: bool = True,
     skip_compiler_prefix: bool = True,
 ) -> FQNS_T:
```
```
@@ -204,6 +206,14 @@ def _get_fqns(
             if not skip_compiler_prefix:
                 fqn_obj_names.append(curr_obj_name)
         else:
+            # In some modules, attribute names listed in _fqn_modifiers do not show
+            # up in the state_dict keys; skip them in the fqn so their state_dict loads.
+            if hasattr(curr_obj, dsd_fqn_modifiers):
+                if removed_fqn := getattr(curr_obj, dsd_fqn_modifiers)().get(
+                    curr_obj_name
+                ):
+                    if hasattr(curr_obj, removed_fqn):
+                        curr_obj = getattr(curr_obj, removed_fqn)
             fqn_obj_names.append(curr_obj_name)
             if curr_obj_name == nn.modules.module._EXTRA_STATE_KEY_SUFFIX:
                 if i != len(obj_names) - 1:
```
```
@@ -218,7 +228,7 @@ class _EXTRA_STATE:
     pass
 
 
-def _iterate_valid_model_state(model):
+def _iterate_valid_model_state(model, dsd_fqn_modifiers="_fqn_modifiers"):
     visited_modules: set[nn.Module] = set()
 
     def recurse(module: nn.Module, curr_fqn: str) -> Generator:
```
```
@@ -228,7 +238,16 @@ def _iterate_valid_model_state(model):
         for name, submodule in module.named_children():
             if submodule in visited_modules:
                 continue
-            new_fqn = f"{curr_fqn}{name}"
+            # If users have state_dict hooks in their model, they can record the
+            # state_dict key changes in dsd_fqn_modifiers to mirror the hooks' effect.
+            if (
+                hasattr(module, dsd_fqn_modifiers)
+                and name in getattr(module, dsd_fqn_modifiers)().values()
+            ):
+                # skip this name in the fqn, so drop the trailing `.` just added
+                new_fqn = curr_fqn[:-1]
+            else:
+                new_fqn = f"{curr_fqn}{name}"
             yield from recurse(submodule, new_fqn)
 
         for name, obj in chain(
```
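For intuition on the `curr_fqn[:-1]` line: `recurse` appends a trailing `.` to each non-empty prefix before visiting children, so dropping a skipped child's name amounts to trimming that dot. A toy trace (the "decoder" placement is an assumed example):
```
# Suppose the fusion module sits at "decoder" and its child "embedding" is
# listed in _fqn_modifiers().values():
curr_fqn = "decoder."    # prefix with the trailing "." recurse appended
new_fqn = curr_fqn[:-1]  # "decoder" -> the child inherits the parent's prefix
# Recursing re-appends the ".", so the parameter is yielded as "decoder.weight"
# rather than "decoder.embedding.weight", matching the hook-renamed key.
```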
```
@@ -527,10 +546,14 @@ def _load_model_state_dict(
         return _IncompatibleKeys({}, {})
 
     local_state_dict = {}
-    for key, value in _iterate_valid_model_state(model):
-        fqns = _get_fqns(model, key)
+    for key, value in _iterate_valid_model_state(model, info.dsd_fqn_modifiers):
+        fqns = _get_fqns(model, key, info.dsd_fqn_modifiers)
         fqns_with_prefix = _get_fqns(
-            model, key, skip_ddp_prefix=False, skip_compiler_prefix=False
+            model,
+            key,
+            info.dsd_fqn_modifiers,
+            skip_ddp_prefix=False,
+            skip_compiler_prefix=False,
         )
 
         for fqn, fqn_with_prefix in zip(fqns, fqns_with_prefix):
```
```
@@ -4,7 +4,7 @@
 
 import copy
 from itertools import chain
-from typing import Any
+from typing import Any, Dict
 
 import torch
 import torch.nn as nn
```
```
@@ -120,3 +120,51 @@ class VerifyStateDictMixin:
             optim_state_dict=new_dist_osd,
         )
         self.assertEqual(optim.state_dict(), new_optim.state_dict())
+
+
+class FusionEmbedding(nn.Module):
+    def __init__(self, vocab_size: int, fusion_vocab_size: int, embed_dim: int) -> None:
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_dim)
+        self.fusion_embedding = nn.Embedding(fusion_vocab_size, embed_dim)
+
+
+class FusionEmbeddingWithHook(nn.Module):
+    def __init__(self, vocab_size: int, fusion_vocab_size: int, embed_dim: int) -> None:
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_dim)
+        self.fusion_embedding = nn.Embedding(fusion_vocab_size, embed_dim)
+        self._register_state_dict_hook(FusionEmbeddingWithHook._state_dict_hook)
+        self._register_load_state_dict_pre_hook(
+            FusionEmbeddingWithHook._load_state_dict_hook, with_module=True
+        )
+
+    def _state_dict_hook(self, destination, prefix, keep_vars):
+        """Remove "embedding" from the original embedding in the state_dict
+        name. This keeps the original state dict name for the embedding
+        from before fusing with the FusionEmbedding.
+        """
+        key = prefix + "embedding.weight"
+        new_key = prefix + "weight"
+        destination[new_key] = destination[key]
+        del destination[key]
+
+    def _load_state_dict_hook(self, state_dict, prefix, *args, **kwargs):
+        """Apply an extra "embedding" prefix to the state_dict key to
+        account for the FusionEmbedding wrapping.
+        """
+        if state_dict:
+            key = prefix + "weight"
+            new_key = prefix + "embedding.weight"
+            state_dict[new_key] = state_dict[key]
+            del state_dict[key]
+
+
+class FusionEmbeddingWithModifier(FusionEmbeddingWithHook):
+    # _fqn_modifiers is a private method that serves as a contract with DSD.
+    # When users change state_dict keys, they provide a mapping from the new key
+    # to the removed attribute name so state_dict keys and fqns stay consistent.
+    def _fqn_modifiers(self) -> Dict[str, str]:
+        return {
+            "weight": "embedding",
+        }
```
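A quick sanity check of these fixtures, runnable once this PR's helpers are importable from the internal testing module:
```
from torch.testing._internal.distributed.common_state_dict import (
    FusionEmbeddingWithHook,
    FusionEmbeddingWithModifier,
)

model = FusionEmbeddingWithHook(4, 4, 4)
# The state_dict hook renames "embedding.weight" to "weight":
assert set(model.state_dict().keys()) == {"weight", "fusion_embedding.weight"}

modified = FusionEmbeddingWithModifier(4, 4, 4)
# The modifier maps the renamed key back to the attribute the hook dropped:
assert modified._fqn_modifiers() == {"weight": "embedding"}
```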