[FSDP][Reland] Implement local_state_dict and load_local_state_dict

1. Implement the framework that allows the user to choose among `state_dict`, `local_state_dict`, and `sharded_state_dict`.
2. Implement ShardedTensor-compatible `local_state_dict()` and `load_local_state_dict()`.

ghstack-source-id: 149625958
Differential Revision: [D34383925](https://our.internmc.facebook.com/intern/diff/D34383925/)

[ghstack-poisoned]
This commit is contained in:
parent 4bb27ae7d3
commit 782ee6c7e7
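Before the diff itself, a brief usage sketch of the API surface this commit introduces, modeled on the new test file included below. The helper name `save_and_reload_local` and its argument are illustrative only, and the sketch assumes an already-initialized process group with the module's parameters on GPU, as in that test:

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType


def save_and_reload_local(module: nn.Module) -> nn.Module:
    """Round-trip a module through the local (flattened, sharded) state_dict.

    Assumes torch.distributed is already initialized and the module's
    parameters live on GPU, mirroring the new test file below.
    """
    model = FSDP(module)

    # New dedicated helpers added by this commit.
    local_sd = model.local_state_dict()
    model.load_local_state_dict(local_sd)

    # Equivalent route through the new state_dict_type context manager, which
    # switches what the standard state_dict()/load_state_dict() calls handle.
    with model.state_dict_type(StateDictType.LOCAL_STATE_DICT):
        sd = model.state_dict()
    with model.state_dict_type(StateDictType.LOCAL_STATE_DICT):
        model.load_state_dict(sd)
    return model
```

Both routes exercise the same dispatch: the module's `_state_dict_type` selects which post-save and pre-load hooks run.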
@@ -198,7 +198,7 @@ class TestFlattenParams(TestCase):
                expected,
                msg=f"{flat_p.shard_metadata()}, {expected}",
            )
            self.assertEqual(flat_p._num_padded, kwargs["num_padded"])
            self.assertEqual(flat_p.num_padded, kwargs["num_padded"])

        _test(
            kwargs={"start": -1, "end": -1, "num_padded": 0},
test/distributed/fsdp/test_fsdp_state_dict.py (new file, 138 lines)

@@ -0,0 +1,138 @@
# Owner(s): ["oncall: distributed"]

import sys
from typing import Any, Dict

import torch
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.nn import Linear, Module
from torch.nn.parallel import DistributedDataParallel
from torch.optim import SGD
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    FSDPTest,
    get_full_params,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

INNER_SHAPE = [4, 4]
OUTER_SHAPE = [4, 5]


class Model(Module):
    def __init__(self, wrap_fsdp):
        super().__init__()
        self.inner = Linear(*INNER_SHAPE)
        if wrap_fsdp:
            self.inner = FSDP(self.inner)
        self.outer = Linear(*OUTER_SHAPE)

    def forward(self, x):
        # Forward twice.
        i = self.inner(x)
        j = self.inner(x)
        return self.outer(i + j)


class TestFSDPStateDict(FSDPTest):
    @property
    def world_size(self):
        return 2

    def _initialize_model(self, wrap_fsdp: bool):
        # keep everything deterministic for input data
        torch.manual_seed(0)

        model = Model(wrap_fsdp).cuda()
        if wrap_fsdp:
            model = FSDP(model)
        else:
            model = DistributedDataParallel(model, device_ids=[self.rank])
        return model

    @staticmethod
    def _state_dict(model: Module, state_dict_type: str):
        return getattr(model, state_dict_type)()

    @staticmethod
    def _load_state_dict(
        model: Module, state_dict_type: str, state_dict: Dict[str, Any]
    ):
        getattr(model, f"load_{state_dict_type}")(state_dict)

    def _dist_train(
        self, wrap_fsdp: bool, state_dict_type: str = "", with_context: bool = False
    ):
        # TODO: Move this test to common_fsdp.
        model = self._initialize_model(wrap_fsdp)
        optim = SGD(model.parameters(), lr=0.1)

        in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda"))
        for _ in range(3):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        if wrap_fsdp:
            blank_model = FSDP(Model(True).cuda())
            if with_context:
                state_dict_type = {
                    "full_state_dict": StateDictType.FULL_STATE_DICT,
                    "local_state_dict": StateDictType.LOCAL_STATE_DICT,
                    "sharded_state_dict": StateDictType.SHARDED_STATE_DICT,
                }[state_dict_type]
                with model.state_dict_type(state_dict_type):
                    state_dict = model.state_dict()
                with blank_model.state_dict_type(state_dict_type):
                    blank_model.load_state_dict(state_dict)
            else:
                state_dict = self._state_dict(model, state_dict_type)
                self._load_state_dict(blank_model, state_dict_type, state_dict)
            get_full_params(blank_model)
            model = blank_model

        return list(model.parameters())

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", ["local_state_dict"])
    def test_state_dict_save_load_flow(self, state_dict_type):
        fsdp_params = self._dist_train(wrap_fsdp=True, state_dict_type=state_dict_type)
        fsdp_params_using_context = self._dist_train(
            wrap_fsdp=True, state_dict_type=state_dict_type, with_context=True
        )
        ddp_params = self._dist_train(wrap_fsdp=False)
        self.assertEqual(ddp_params, fsdp_params)
        self.assertEqual(ddp_params, fsdp_params_using_context)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", ["local_state_dict"])
    def test_fsdp_state_dict_keys(self, state_dict_type):
        state_dict = self._state_dict(self._initialize_model(True), state_dict_type)
        if state_dict_type == "local_state_dict":
            self.assertEqual(set(["flat_param", "inner.flat_param"]), state_dict.keys())


instantiate_parametrized_tests(TestFSDPStateDict)

if __name__ == "__main__":
    run_tests()
@@ -1,8 +1,8 @@
# Owner(s): ["oncall: distributed"]
import itertools
from copy import deepcopy
import math
import sys
from copy import deepcopy

import torch
import torch.nn as nn
@@ -35,11 +35,10 @@ if TEST_WITH_DEV_DBG_ASAN:
    )
    sys.exit(0)


def _run_test_summon_full_param_writeback(cls, writeback, cpu_offload, modify_outer):
    model = FSDP(
        nn.Sequential(
            FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False)
        )
        nn.Sequential(FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False))
    ).cuda(cls.rank)

    # set the value
@@ -64,6 +63,7 @@ def _run_test_summon_full_param_writeback(cls, writeback, cpu_offload, modify_ou
    else:
        cls.assertEqual(p.cpu()[0], cls.rank + 2)


class TestSummonFullParamsNoShard(FSDPTest):
    @property
    def world_size(self):
@@ -84,6 +84,7 @@ class TestSummonFullParamsNoShard(FSDPTest):
            modify_outer,
        )


class TestSummonFullParams(FSDPTest):
    @property
    def world_size(self):
@@ -105,10 +106,7 @@ class TestSummonFullParams(FSDPTest):
    @parametrize("modify_outer", [True, False])
    def test_summon_full_param_writeback(self, writeback, cpu_offload, modify_outer):
        return _run_test_summon_full_param_writeback(
            self,
            writeback,
            cpu_offload,
            modify_outer
            self, writeback, cpu_offload, modify_outer
        )

    @skip_if_lt_x_gpu(2)
@@ -1,3 +1,4 @@
from .flatten_params_wrapper import FlatParameter
from .fully_sharded_data_parallel import FullyShardedDataParallel
from .fully_sharded_data_parallel import CPUOffload
from .fully_sharded_data_parallel import StateDictType
@@ -18,14 +18,60 @@ from typing import (
    Optional,
    Sequence,
    Tuple,
    TYPE_CHECKING,
    Union,
)

import torch
import torch.nn as nn
from torch import Tensor

from .utils import _replace_by_prefix

if TYPE_CHECKING:
    from collections import OrderedDict  # noqa: F401

ParamOffset = Tuple[int, int]
SharedParamInfo = Tuple[str, str, nn.Module, str, nn.Module, str]
FLAT_PARAM = "flat_param"
FPW_MODULE = "_fpw_module"


def _post_state_dict_hook(
    module: nn.Module, state_dict: "OrderedDict[str, Tensor]", prefix: str, *args: Any
) -> "OrderedDict[str, Tensor]":
    """
    _post_state_dict_hook() is called after the state_dict() is executed
    and before returning the state_dict to the user.
    This API post-processes the keys of the state_dict to remove the
    FlattenParamsWrapper internal prefix.
    """
    # Move everything from FPW_MODULE up one level.
    _replace_by_prefix(state_dict, prefix + f"{FPW_MODULE}.", prefix)
    return state_dict


def _pre_load_state_dict_hook(
    state_dict: Union[Dict[str, Tensor], "OrderedDict[str, Tensor]"],
    prefix: str,
    *args: Any,
) -> None:
    """
    _pre_load_state_dict_hook() is called before _load_from_state_dict() is
    called.
    This API pre-processes the keys of the state_dict to add the
    FlattenParamsWrapper internal prefix.
    """
    # Push everything down to FPW_MODULE level.
    _replace_by_prefix(state_dict, prefix, prefix + f"{FPW_MODULE}.")
    # The flat_param_* keys actually need to move one level up.
    flat_param_key = prefix + f"{FPW_MODULE}.{FLAT_PARAM}"
    for k in list(state_dict.keys()):
        if k.startswith(flat_param_key):
            last_part = k.split(".")[-1]
            assert last_part.startswith(
                FLAT_PARAM
            ), f"Expected key to contain flat_param, but key name is {k}"
            _replace_by_prefix(state_dict, k, prefix + last_part)


class ParamInfo(NamedTuple):
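For intuition about the two hooks above, here is a small self-contained sketch of the key rewriting that `_post_state_dict_hook` performs. The key names are made up; only the handling of the `_fpw_module.` prefix mirrors the hook:

```python
# Standalone illustration (not the FSDP code itself) of lifting keys out of
# the FlattenParamsWrapper-internal "_fpw_module." namespace.
import torch

state_dict = {
    "layer1._fpw_module.flat_param": torch.zeros(4),        # hypothetical keys
    "layer1._fpw_module.inner.flat_param": torch.zeros(2),
}
prefix = "layer1."
internal = prefix + "_fpw_module."
for key in list(state_dict.keys()):
    if key.startswith(internal):
        # Move the entry up one level, dropping the wrapper-internal prefix.
        state_dict[prefix + key[len(internal):]] = state_dict.pop(key)

assert set(state_dict) == {"layer1.flat_param", "layer1.inner.flat_param"}
```

The pre-load hook does the inverse: it pushes keys back under `_fpw_module.` so that `load_state_dict` finds them where the wrapped module expects them.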
@@ -98,10 +144,13 @@ class FlatParameter(nn.Parameter):
    def __init__(self, params: Sequence[nn.Parameter], requires_grad: bool = True):
        self._is_sharded = False
        self._param_numels = [p.numel() for p in params]
        assert self.numel() <= sum(self._param_numels), (
        # The total element numbers. This is equal to the summation of the
        # ``numel()`` of all the parameters.
        self.full_numel = sum(self._param_numels)
        assert self.numel() <= self.full_numel, (
            "Parameter numbers mismatched. "
            f"The number of elements in FlatParameter: {self.numel()} vs. "
            f"the number of elements in original parameters: {sum(self._param_numels)}."
            f"the number of elements in original parameters: {self.full_numel}."
        )
        # The shapes of each individual parameter.
        self._param_shapes = [p.size() for p in params]
@@ -124,7 +173,7 @@ class FlatParameter(nn.Parameter):
            (0, numel) for numel in self._param_numels
        ]
        # The number of padding elements.
        self._num_padded = 0
        self.num_padded = 0

    def shard_by_offsets(self, start: int, end: int, num_padded: int) -> None:
        assert self._is_sharded
@@ -133,8 +182,8 @@ class FlatParameter(nn.Parameter):
                f"Shard the flatten parameter with an invalid offset pair {(start, end)}."
            )
        _shard_size = end - start + 1
        self._num_padded = num_padded
        if self._num_padded > _shard_size:
        self.num_padded = num_padded
        if self.num_padded > _shard_size:
            raise ValueError("The number of padding is larger than the shard size.")
        self._sharded_param_offsets.clear()
@@ -163,13 +212,13 @@ class FlatParameter(nn.Parameter):
    ) -> Iterator[Tensor]:
        """Return a generator of views that map to the original parameters."""
        # Note, self.data could be sharded, so its numel is <= to the sum.
        assert self.data.numel() <= sum(
            self._param_numels
        ), f"Incorrect internal state {self.data.numel()} vs. {sum(self._param_numels)}"
        assert (
            self.data.numel() <= self.full_numel
        ), f"Incorrect internal state {self.data.numel()} vs. {self.full_numel}"
        data = external_data if external_data is not None else self
        if data.numel() != sum(self._param_numels):
        if data.numel() != self.full_numel:
            raise ValueError(
                f"Incorrect numel of supplied data: got {data.numel()} but expected {sum(self._param_numels)}"
                f"Incorrect numel of supplied data: got {data.numel()} but expected {self.full_numel}"
            )
        return (
            t.view(s)
@@ -252,6 +301,15 @@ class FlattenParamsWrapper(nn.Module):
        self._orig_flat_param: List[Optional[FlatParameter]] = [None]
        self._flatten_params()

        # Sanity check for the string constants.
        assert getattr(self, FPW_MODULE) is self._fpw_module
        assert getattr(self, FLAT_PARAM) is self.flat_param

        # Register hook to be called after state_dict() to remove the
        # "_fpw_module." prefix and before load_state_dict() to add it back.
        self._register_state_dict_hook(_post_state_dict_hook)
        self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook)

    @property
    def module(self) -> Any:
        """Support _fsdp_wrapped_module.module in case we are imitating DDP, which has .module
@@ -9,9 +9,10 @@ from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    List,
    Optional,
    Generator,
    NamedTuple,
    Set,
    Tuple,
    Union,
@@ -24,17 +25,28 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.distributed import ProcessGroup
from torch.distributed._sharded_tensor import (
    init_from_local_shards,
    Shard,
    ShardedTensor,
)
from torch.distributed.distributed_c10d import _get_default_group
from torch.nn.parameter import Parameter

from .flatten_params_wrapper import FlatParameter, FlattenParamsWrapper
from .utils import _apply_to_tensors
from .flatten_params_wrapper import FlatParameter, FlattenParamsWrapper, FLAT_PARAM
from .utils import (
    _apply_to_tensors,
    _replace_by_prefix,
)
from .wrap import _recursive_wrap

if TYPE_CHECKING:
    from collections import OrderedDict  # noqa: F401


FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"


@dataclass
class CPUOffload:
    """
@@ -98,6 +110,31 @@ class TrainingState_(Enum):
    SUMMON_FULL_PARAMS = auto()


class StateDictType(Enum):
    """
    This enum indicates which type of ``state_dict`` the FSDP module is
    currently processing (returning or loading).
    The default value should be FULL_STATE_DICT to comply with the PyTorch convention.
    .. note::
        FSDP currently supports three types of ``state_dict``:
            1. ``state_dict/load_state_dict``: this pair of APIs returns and loads
               the non-sharded, unflattened parameters. The semantics are the
               same as using DDP.
            2. ``local_state_dict/load_local_state_dict``: this pair of APIs returns
               and loads local sharded, flattened parameters. The values returned
               by ``local_state_dict`` can be directly used by FSDP and are only
               meaningful to FSDP (because parameters are flattened).
            3. ``sharded_state_dict/load_sharded_state_dict``: this pair of APIs
               returns and loads sharded, unflattened parameters. The ``state_dict``
               returned by ``sharded_state_dict`` can be used by all other parallel
               schemes (resharding may be required).
    """

    FULL_STATE_DICT = auto()
    LOCAL_STATE_DICT = auto()
    SHARDED_STATE_DICT = auto()


class FullyShardedDataParallel(nn.Module):
    """
    A wrapper for sharding Module parameters across data parallel workers. This
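As a concrete illustration of the second pair, the new test above (`test_fsdp_state_dict_keys`) pins down the key set that ``local_state_dict`` produces for its nested ``Model``. A minimal sketch of that expectation follows; the contrast with unflattened keys is an illustration, not part of the test:

```python
# Key set asserted by test_fsdp_state_dict_keys for FSDP(Model(wrap_fsdp=True)):
# one flattened "flat_param" entry per FSDP instance, instead of the usual
# per-parameter entries such as "inner.weight" or "outer.bias".
expected_local_keys = {"flat_param", "inner.flat_param"}
```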
@@ -244,6 +281,7 @@ class FullyShardedDataParallel(nn.Module):
        self._fsdp_wrapped_module: FlattenParamsWrapper = FlattenParamsWrapper(
            module, param_list=params
        )
        assert getattr(self, FSDP_WRAPPED_MODULE) is self._fsdp_wrapped_module
        del module  # free original module in case it helps garbage collection
        if self._fsdp_wrapped_module.flat_param is not None:
            self.params = [self._fsdp_wrapped_module.flat_param]
@@ -268,6 +306,29 @@ class FullyShardedDataParallel(nn.Module):
        # Enum to indicate if we're in the forward/backward pass, idle, etc.
        self.training_state = TrainingState_.IDLE

        self._state_dict_type = StateDictType.FULL_STATE_DICT

        # FSDP currently provides three different state_dicts. The actual
        # state_dict that will be saved/loaded is decided by
        # self._state_dict_type. And the main logic of each state_dict is
        # implemented in the hook. Therefore, for each hook (post-save and
        # pre-load), there is a dispatcher dictionary to dispatch the execution
        # flow to the correct implementation.
        self._register_state_dict_hook(self._post_state_dict_hook)
        self._post_state_dict_hook_fn = {
            StateDictType.FULL_STATE_DICT: self._full_post_state_dict_hook,
            StateDictType.LOCAL_STATE_DICT: self._local_post_state_dict_hook,
            StateDictType.SHARDED_STATE_DICT: self._sharded_post_state_dict_hook,
        }
        self._register_load_state_dict_pre_hook(
            self._pre_load_state_dict_hook, with_module=True
        )
        self._pre_load_state_dict_hook_fn = {
            StateDictType.FULL_STATE_DICT: self._full_pre_load_state_dict_hook,
            StateDictType.LOCAL_STATE_DICT: self._local_pre_load_state_dict_hook,
            StateDictType.SHARDED_STATE_DICT: self._sharded_pre_load_state_dict_hook,
        }

        # Flag to guard against preparing gradients multiple times per backward pass.
        self._pre_backward_hook_has_run = False
        # Used for prefetching all gather full params in post backward hook
@@ -679,6 +740,227 @@ class FullyShardedDataParallel(nn.Module):
        else:
            return False

    @contextlib.contextmanager
    def state_dict_type(self, state_dict_type: StateDictType) -> Generator:
        """
        A context manager to set the state_dict_type of this FSDP module and
        its descendant FSDP modules.
        .. note:: This API should be called for only the root FSDP module.
        .. note:: The default state_dict_type is StateDictType.FULL_STATE_DICT.

        Args:
            state_dict_type (StateDictType): the desired state_dict_type to set.
        """
        self._lazy_init()
        if not self._is_root:
            raise RuntimeError(
                f"state_dict_type context manager can only be called from the root FSDP module. {self._is_root}"
            )
        prev_state_dict_type = self._state_dict_type
        for module in self.modules():
            if isinstance(module, FullyShardedDataParallel):
                if module._state_dict_type != prev_state_dict_type:
                    raise RuntimeError(
                        "All FSDP modules should have the same state_dict_type."
                    )
                module._state_dict_type = state_dict_type
        try:
            yield
        finally:
            for module in self.modules():
                if isinstance(module, FullyShardedDataParallel):
                    module._state_dict_type = prev_state_dict_type

    def _full_post_state_dict_hook(
        self,
        state_dict: "OrderedDict[str, torch.Tensor]",
        prefix: str,
    ) -> "OrderedDict[str, torch.Tensor]":
        return state_dict

    def _local_post_state_dict_hook(
        self,
        state_dict: "OrderedDict[str, torch.Tensor]",
        prefix: str,
    ) -> "OrderedDict[str, torch.Tensor]":
        """
        This hook creates a ShardedTensor from the local flat_param and replaces
        the state_dict[f"{prefix}{FLAT_PARAM}"] with the ShardedTensor. No copy
        will happen. The underlying storage is the same.
        """
        _replace_by_prefix(state_dict, f"{prefix}{FSDP_WRAPPED_MODULE}.", prefix)
        # state_dict[f"{prefix}{FLAT_PARAM}"] exists and has the same tensor
        # value as the flat_param but it is a pure Tensor because
        # nn.Module.state_dict() will detach the parameter. Therefore, we need
        # to get flat_param from the FlattenParamsWrapper to get the metadata.
        flat_param = getattr(self.module, FLAT_PARAM, None)
        assert (
            flat_param is not None
        ), "flat_param cannot be None when doing local_state_dict."

        # Construct a ShardedTensor from the flat_param.
        full_numel = flat_param.full_numel
        shard_offset = flat_param.numel() * self.rank
        valid_data_size = flat_param.numel() - flat_param.num_padded
        if valid_data_size > 0 and flat_param.num_padded > 0:
            flat_param = flat_param.narrow(0, 0, valid_data_size)
        local_shards = [
            Shard.from_tensor_and_offsets(flat_param, [shard_offset], self.rank)
        ]
        state_dict[f"{prefix}{FLAT_PARAM}"] = init_from_local_shards(
            local_shards, full_numel, process_group=self.process_group
        )  # type: ignore[assignment]

        return state_dict

    def _sharded_post_state_dict_hook(
        self,
        state_dict: "OrderedDict[str, torch.Tensor]",
        prefix: str,
    ) -> "OrderedDict[str, torch.Tensor]":
        raise NotImplementedError("Will be implemented in the next PRs.")

    @staticmethod
    def _post_state_dict_hook(
        module: nn.Module,
        state_dict: "OrderedDict[str, torch.Tensor]",
        prefix: str,
        *args: Any,
    ) -> "OrderedDict[str, torch.Tensor]":
        """
        _post_state_dict_hook() is called after the state_dict() of this
        FSDP module is executed. ``self._state_dict_type`` is used to decide
        what postprocessing will be done.
        """
        self = cast(FullyShardedDataParallel, module)
        return self._post_state_dict_hook_fn[self._state_dict_type](state_dict, prefix)

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """
        The entry point of all three FSDP state_dict APIs.
        ``self._state_dict_type`` decides which code path to execute.

        .. warning:: This needs to be called on all ranks, since synchronization
            primitives may be used.
        """
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        if self._state_dict_type == StateDictType.FULL_STATE_DICT:
            return super().state_dict(destination, prefix, keep_vars)
        elif self._state_dict_type == StateDictType.LOCAL_STATE_DICT:
            assert getattr(self.module, FLAT_PARAM, None) is not None
            assert isinstance(self.module.flat_param, FlatParameter)
            return super().state_dict(destination, prefix, keep_vars)
        elif self._state_dict_type == StateDictType.SHARDED_STATE_DICT:
            raise NotImplementedError("Will be implemented in the next PRs.")
        else:
            raise ValueError(f"Unknown StateDictType {self._state_dict_type}.")

    def local_state_dict(self, *args: Any, **kwargs: Any) -> Any:
        """
        Returns the local state of the module. Parameters are flattened and
        sharded, so the resulting state_dict can only be loaded after the module
        has been wrapped with FSDP.
        """
        with self.state_dict_type(StateDictType.LOCAL_STATE_DICT):
            return self.state_dict(*args, **kwargs)

    def _full_pre_load_state_dict_hook(
        self,
        state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
        prefix: str,
    ) -> None:
        return

    def _local_pre_load_state_dict_hook(
        self,
        state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
        prefix: str,
    ) -> None:
        """
        This hook finds the local flat_param for this FSDP module from the
        state_dict. The flat_param should be a ShardedTensor. This hook converts
        the ShardedTensor to a tensor. No copy happens unless padding is required.
        """
        _replace_by_prefix(state_dict, prefix, f"{prefix}{FSDP_WRAPPED_MODULE}.")
        key = f"{prefix}{FSDP_WRAPPED_MODULE}.{FLAT_PARAM}"
        load_tensor = state_dict[key]
        assert isinstance(
            load_tensor, ShardedTensor
        ), "Tensors in local_state_dict should be ShardedTensor."

        # Convert the ShardedTensor to a Tensor.
        shards = load_tensor.local_shards()
        assert len(shards), "load_local_state_dict assumes one shard per ShardedTensor."
        load_tensor = cast(torch.Tensor, shards[0].tensor)

        # Get the metadata of the flat_param to decide whether to pad the loaded
        # tensor.
        flat_param = self.module.flat_param
        assert flat_param is not None
        if flat_param.num_padded not in (0, flat_param.numel()):
            assert load_tensor.numel() < flat_param.numel(), (
                f"Local shard size = {flat_param.numel()} and the tensor in "
                f"the state_dict is {load_tensor.numel()}."
            )
            load_tensor = F.pad(load_tensor, [0, flat_param.num_padded])
        state_dict[key] = load_tensor

    def _sharded_pre_load_state_dict_hook(
        self,
        state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
        prefix: str,
    ) -> None:
        raise NotImplementedError("Will be implemented in the next PRs.")

    @staticmethod
    def _pre_load_state_dict_hook(
        module: nn.Module,
        state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
        prefix: str,
        *args: Any,
    ) -> None:
        """
        ``_pre_load_state_dict_hook`` is called before ``self._load_from_state_dict()``
        is called. ``self._state_dict_type`` is used to decide what preprocessing
        will be done.
        """
        self = cast(FullyShardedDataParallel, module)
        self._pre_load_state_dict_hook_fn[self._state_dict_type](state_dict, prefix)

    def load_state_dict(
        self,
        state_dict: "OrderedDict[str, torch.Tensor]",
        strict: bool = True,
    ) -> NamedTuple:
        """
        The entry point of all three FSDP load_state_dict APIs.
        ``self._state_dict_type`` decides which code path to execute.

        .. warning:: This needs to be called on all ranks, since synchronization
            primitives may be used.
        """
        torch.cuda.synchronize()
        if self._state_dict_type == StateDictType.FULL_STATE_DICT:
            return super().load_state_dict(state_dict, strict)
        elif self._state_dict_type == StateDictType.LOCAL_STATE_DICT:
            return super().load_state_dict(state_dict, strict)
        elif self._state_dict_type == StateDictType.SHARDED_STATE_DICT:
            raise NotImplementedError("Will be implemented in the next PRs.")
        else:
            raise ValueError(f"Unknown StateDictType {self._state_dict_type}.")

    def load_local_state_dict(
        self,
        state_dict: "OrderedDict[str, torch.Tensor]",
        strict: bool = True,
    ) -> NamedTuple:
        """
        Load states from a flattened, sharded state dictionary.
        """
        with self.state_dict_type(StateDictType.LOCAL_STATE_DICT):
            return self.load_state_dict(state_dict, strict)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        self._lazy_init()
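To make the padding logic in `_local_pre_load_state_dict_hook` concrete, here is a small runnable sketch of the same arithmetic with made-up sizes; it mirrors the `F.pad` call above but is not the FSDP code itself:

```python
import torch
import torch.nn.functional as F

# Hypothetical local shard: flat_param holds 6 elements on this rank,
# of which the last 2 are padding added when the parameter was sharded.
flat_param_numel = 6
num_padded = 2

# The saved ShardedTensor only carries the 4 valid elements, so the tensor
# pulled out of the local state_dict is shorter than the local flat_param.
load_tensor = torch.arange(flat_param_numel - num_padded, dtype=torch.float32)

# Mirror of the hook's padding step: pad the valid data back up to the local
# flat_param size before load_state_dict copies it in element-wise.
if num_padded not in (0, flat_param_numel):
    load_tensor = F.pad(load_tensor, [0, num_padded])

assert load_tensor.numel() == flat_param_numel
```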
@@ -1110,6 +1392,7 @@ class FullyShardedDataParallel(nn.Module):
        """
        Gather all shards of params.
        """
        self._lazy_init()

        def update_p_data(output_tensor: torch.Tensor) -> None:
            """
@@ -1246,8 +1529,7 @@ class FullyShardedDataParallel(nn.Module):
        until the eventual sync.
        """
        self._lazy_init()
        assert self._is_root, \
            "`no_sync()` on inner FSDP instances is not supported"
        assert self._is_root, "`no_sync()` on inner FSDP instances is not supported"
        self._assert_state(TrainingState_.IDLE)
        old_flags = []
        for m in self.modules():
@@ -1258,9 +1540,10 @@ class FullyShardedDataParallel(nn.Module):
            yield
        finally:
            for m, old_flag in old_flags:
                assert not m._require_backward_grad_sync, \
                    "`_require_backward_grad_sync` was incorrectly set to " \
                assert not m._require_backward_grad_sync, (
                    "`_require_backward_grad_sync` was incorrectly set to "
                    "`True` while in the `no_sync()` context manager"
                )
                m._require_backward_grad_sync = old_flag
@@ -1,7 +1,9 @@
from typing import Dict, List, Tuple, Union, Any, Callable, Set
from typing import Dict, List, Tuple, Union, Any, Callable, Set, TYPE_CHECKING

import torch

if TYPE_CHECKING:
    from collections import OrderedDict  # noqa: F401

"""Useful functions to deal with tensor types with other python container types."""
@@ -22,3 +24,27 @@ def _apply_to_tensors(
        return x

    return apply(container)


def _replace_by_prefix(
    state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
    old_prefix: str,
    new_prefix: str,
) -> None:
    """
    Replace all keys that match a given old_prefix with a new_prefix (in-place).

    Usage::

        state_dict = {"layer.xyz": torch.tensor(1)}
        _replace_by_prefix(state_dict, "layer.", "module.layer.")
        assert state_dict == {"module.layer.xyz": torch.tensor(1)}
    """
    if old_prefix == new_prefix:
        raise ValueError("old_prefix and new_prefix must be distinct")
    for key in list(state_dict.keys()):
        if not key.startswith(old_prefix):
            continue
        new_key = new_prefix + key[len(old_prefix) :]
        state_dict[new_key] = state_dict[key]
        del state_dict[key]