[FSDP][Reland] Implement local_state_dict and load_local_state_dict

1. Implement the framework that allows the user to choose among `state_dict`, `local_state_dict`, and `sharded_state_dict`.
2. Implement ShardedTensor-compatible `local_state_dict()` and `load_local_state_dict()`.
ghstack-source-id: 149625958

Differential Revision: [D34383925](https://our.internmc.facebook.com/intern/diff/D34383925/)

[ghstack-poisoned]
This commit is contained in:
Rohan Varma 2022-02-23 07:57:34 -08:00
parent 4bb27ae7d3
commit 782ee6c7e7
7 changed files with 531 additions and 27 deletions

View File

@ -198,7 +198,7 @@ class TestFlattenParams(TestCase):
expected,
msg=f"{flat_p.shard_metadata()}, {expected}",
)
self.assertEqual(flat_p._num_padded, kwargs["num_padded"])
self.assertEqual(flat_p.num_padded, kwargs["num_padded"])
_test(
kwargs={"start": -1, "end": -1, "num_padded": 0},

View File

@ -0,0 +1,138 @@
# Owner(s): ["oncall: distributed"]
import sys
from typing import Any, Dict
import torch
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.nn import Linear, Module
from torch.nn.parallel import DistributedDataParallel
from torch.optim import SGD
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
FSDPTest,
get_full_params,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TEST_WITH_DEV_DBG_ASAN,
)
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
file=sys.stderr,
)
sys.exit(0)
INNER_SHAPE = [4, 4]
OUTER_SHAPE = [4, 5]
class Model(Module):
def __init__(self, wrap_fsdp):
super().__init__()
self.inner = Linear(*INNER_SHAPE)
if wrap_fsdp:
self.inner = FSDP(self.inner)
self.outer = Linear(*OUTER_SHAPE)
def forward(self, x):
# Forward twice.
i = self.inner(x)
j = self.inner(x)
return self.outer(i + j)
class TestFSDPStateDict(FSDPTest):
@property
def world_size(self):
return 2
def _initialize_model(self, wrap_fsdp: bool):
# keep everything deterministic for input data
torch.manual_seed(0)
model = Model(wrap_fsdp).cuda()
if wrap_fsdp:
model = FSDP(model)
else:
model = DistributedDataParallel(model, device_ids=[self.rank])
return model
@staticmethod
def _state_dict(model: Module, state_dict_type: str):
return getattr(model, state_dict_type)()
@staticmethod
def _load_state_dict(
model: Module, state_dict_type: str, state_dict: Dict[str, Any]
):
getattr(model, f"load_{state_dict_type}")(state_dict)
def _dist_train(
self, wrap_fsdp: bool, state_dict_type: str = "", with_context: bool = False
):
# TODO: Move this test to common_fsdp.
model = self._initialize_model(wrap_fsdp)
optim = SGD(model.parameters(), lr=0.1)
in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda"))
for _ in range(3):
out = model(in_data)
out.sum().backward()
optim.step()
optim.zero_grad()
if wrap_fsdp:
blank_model = FSDP(Model(True).cuda())
if with_context:
state_dict_type = {
"full_state_dict": StateDictType.FULL_STATE_DICT,
"local_state_dict": StateDictType.LOCAL_STATE_DICT,
"sharded_state_dict": StateDictType.SHARDED_STATE_DICT,
}[state_dict_type]
with model.state_dict_type(state_dict_type):
state_dict = model.state_dict()
with blank_model.state_dict_type(state_dict_type):
blank_model.load_state_dict(state_dict)
else:
state_dict = self._state_dict(model, state_dict_type)
self._load_state_dict(blank_model, state_dict_type, state_dict)
get_full_params(blank_model)
model = blank_model
return list(model.parameters())
@skip_if_lt_x_gpu(2)
@parametrize("state_dict_type", ["local_state_dict"])
def test_state_dict_save_load_flow(self, state_dict_type):
fsdp_params = self._dist_train(wrap_fsdp=True, state_dict_type=state_dict_type)
fsdp_params_using_context = self._dist_train(
wrap_fsdp=True, state_dict_type=state_dict_type, with_context=True
)
ddp_params = self._dist_train(wrap_fsdp=False)
self.assertEqual(ddp_params, fsdp_params)
self.assertEqual(ddp_params, fsdp_params_using_context)
@skip_if_lt_x_gpu(2)
@parametrize("state_dict_type", ["local_state_dict"])
def test_fsdp_state_dict_keys(self, state_dict_type):
state_dict = self._state_dict(self._initialize_model(True), state_dict_type)
if state_dict_type == "local_state_dict":
self.assertEqual(set(["flat_param", "inner.flat_param"]), state_dict.keys())
instantiate_parametrized_tests(TestFSDPStateDict)
if __name__ == "__main__":
run_tests()

View File

@ -1,8 +1,8 @@
# Owner(s): ["oncall: distributed"]
import itertools
from copy import deepcopy
import math
import sys
from copy import deepcopy
import torch
import torch.nn as nn
@ -35,11 +35,10 @@ if TEST_WITH_DEV_DBG_ASAN:
)
sys.exit(0)
def _run_test_summon_full_param_writeback(cls, writeback, cpu_offload, modify_outer):
model = FSDP(
nn.Sequential(
FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False)
)
nn.Sequential(FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False))
).cuda(cls.rank)
# set the value
@ -64,6 +63,7 @@ def _run_test_summon_full_param_writeback(cls, writeback, cpu_offload, modify_ou
else:
cls.assertEqual(p.cpu()[0], cls.rank + 2)
class TestSummonFullParamsNoShard(FSDPTest):
@property
def world_size(self):
@ -84,6 +84,7 @@ class TestSummonFullParamsNoShard(FSDPTest):
modify_outer,
)
class TestSummonFullParams(FSDPTest):
@property
def world_size(self):
@ -105,10 +106,7 @@ class TestSummonFullParams(FSDPTest):
@parametrize("modify_outer", [True, False])
def test_summon_full_param_writeback(self, writeback, cpu_offload, modify_outer):
return _run_test_summon_full_param_writeback(
self,
writeback,
cpu_offload,
modify_outer
self, writeback, cpu_offload, modify_outer
)
@skip_if_lt_x_gpu(2)

View File

@ -1,3 +1,4 @@
from .flatten_params_wrapper import FlatParameter
from .fully_sharded_data_parallel import FullyShardedDataParallel
from .fully_sharded_data_parallel import CPUOffload
from .fully_sharded_data_parallel import StateDictType

View File

@ -18,14 +18,60 @@ from typing import (
Optional,
Sequence,
Tuple,
TYPE_CHECKING,
Union,
)
import torch
import torch.nn as nn
from torch import Tensor
from .utils import _replace_by_prefix
if TYPE_CHECKING:
from collections import OrderedDict # noqa: F401
ParamOffset = Tuple[int, int]
SharedParamInfo = Tuple[str, str, nn.Module, str, nn.Module, str]
FLAT_PARAM = "flat_param"
FPW_MODULE = "_fpw_module"
def _post_state_dict_hook(
module: nn.Module, state_dict: "OrderedDict[str, Tensor]", prefix: str, *args: Any
) -> "OrderedDict[str, Tensor]":
"""
_post_state_dict_hook() is called after the state_dict() is executed
and before returning the state_dict to the users.
This API post-processes the keys of the state_dict to remove the
FlattenParamsWrapper internal prefix.
"""
# Move everything from FPW_MODULE up one level.
_replace_by_prefix(state_dict, prefix + f"{FPW_MODULE}.", prefix)
return state_dict
def _pre_load_state_dict_hook(
state_dict: Union[Dict[str, Tensor], "OrderedDict[str, Tensor]"],
prefix: str,
*args: Any,
) -> None:
"""
_pre_load_state_dict_hook() is called before _load_from_state_dict() is
executed. This API pre-processes the keys of the state_dict to add the
FlattenParamsWrapper internal prefix.
"""
# Push everything down to FPW_MODULE level.
_replace_by_prefix(state_dict, prefix, prefix + f"{FPW_MODULE}.")
# The flat_param_* keys actually need to move one level up.
flat_param_key = prefix + f"{FPW_MODULE}.{FLAT_PARAM}"
for k in list(state_dict.keys()):
if k.startswith(flat_param_key):
last_part = k.split(".")[-1]
assert last_part.startswith(
FLAT_PARAM
), f"Expected key to contain flat_param, but key name is {k}"
_replace_by_prefix(state_dict, k, prefix + last_part)
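# Illustrative sketch (not part of the changed file): the key rewriting that the
# two hooks above perform, shown on a plain dict with an inline stand-in for the
# ``_replace_by_prefix`` helper. The module name "layer1" is hypothetical.
def _sketch_prefix_rewrite() -> None:
    def rename(sd: dict, old: str, new: str) -> None:
        for k in list(sd.keys()):
            if k.startswith(old):
                sd[new + k[len(old):]] = sd.pop(k)

    # "weight" lives on the inner _fpw_module; "flat_param" lives on the wrapper.
    sd = {"layer1._fpw_module.weight": 0, "layer1.flat_param": 0}
    # _post_state_dict_hook: hide the internal "_fpw_module." level from the user.
    rename(sd, "layer1._fpw_module.", "layer1.")
    assert set(sd) == {"layer1.weight", "layer1.flat_param"}
    # _pre_load_state_dict_hook: push keys back under "_fpw_module.", then lift
    # flat_param back up because it belongs to the wrapper, not the inner module.
    rename(sd, "layer1.", "layer1._fpw_module.")
    rename(sd, "layer1._fpw_module.flat_param", "layer1.flat_param")
    assert set(sd) == {"layer1._fpw_module.weight", "layer1.flat_param"}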
class ParamInfo(NamedTuple):
@ -98,10 +144,13 @@ class FlatParameter(nn.Parameter):
def __init__(self, params: Sequence[nn.Parameter], requires_grad: bool = True):
self._is_sharded = False
self._param_numels = [p.numel() for p in params]
assert self.numel() <= sum(self._param_numels), (
# The total element numbers. This is equal to the summation of the
# ``numel()`` of all the parameters.
self.full_numel = sum(self._param_numels)
assert self.numel() <= self.full_numel, (
"Parameter numbers mismatched. "
f"The number of elements in FlatParameter: {self.numel()} vs. "
f"the number of elements in original parameters: {sum(self._param_numels)}."
f"the number of elements in original parameters: {self.full_numel}."
)
# The shapes of each individual parameter.
self._param_shapes = [p.size() for p in params]
@ -124,7 +173,7 @@ class FlatParameter(nn.Parameter):
(0, numel) for numel in self._param_numels
]
# The number of padding elements.
self._num_padded = 0
self.num_padded = 0
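# Worked example (hypothetical sizes, not part of the changed file): flattening
# two parameters with 5 and 8 elements gives full_numel = 13. With a world size
# of 2 the flat parameter is padded to 14, so each rank holds a 7-element shard:
# rank 0 covers elements [0, 7) with num_padded = 0, and rank 1 covers elements
# [7, 13) plus one padding element, i.e. num_padded = 1.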
def shard_by_offsets(self, start: int, end: int, num_padded: int) -> None:
assert self._is_sharded
@ -133,8 +182,8 @@ class FlatParameter(nn.Parameter):
f"Shard the flatten parameter with an invalid offset pair {(start, end)}."
)
_shard_size = end - start + 1
self._num_padded = num_padded
if self._num_padded > _shard_size:
self.num_padded = num_padded
if self.num_padded > _shard_size:
raise ValueError("The number of padding is larger than the shard size.")
self._sharded_param_offsets.clear()
@ -163,13 +212,13 @@ class FlatParameter(nn.Parameter):
) -> Iterator[Tensor]:
"""Return a generator of views that map to the original parameters."""
# Note, self.data could be sharded, so its numel is <= to the sum.
assert self.data.numel() <= sum(
self._param_numels
), f"Incorrect internal state {self.data.numel()} vs. {sum(self._param_numels)}"
assert (
self.data.numel() <= self.full_numel
), f"Incorrect internal state {self.data.numel()} vs. {self.full_numel}"
data = external_data if external_data is not None else self
if data.numel() != sum(self._param_numels):
if data.numel() != self.full_numel:
raise ValueError(
f"Incorrect numel of supplied data: got {data.numel()} but expected {sum(self._param_numels)}"
f"Incorrect numel of supplied data: got {data.numel()} but expected {self.full_numel}"
)
return (
t.view(s)
@ -252,6 +301,15 @@ class FlattenParamsWrapper(nn.Module):
self._orig_flat_param: List[Optional[FlatParameter]] = [None]
self._flatten_params()
# Sanity check for the string constants.
assert getattr(self, FPW_MODULE) is self._fpw_module
assert getattr(self, FLAT_PARAM) is self.flat_param
# Register hook to be called after state_dict() to remove the
# "_fpw_module." prefix and before load_state_dict() to add it back.
self._register_state_dict_hook(_post_state_dict_hook)
self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook)
@property
def module(self) -> Any:
"""Support _fsdp_wrapped_module.module in case we are immitating DDP, which has .module

View File

@ -9,9 +9,10 @@ from typing import (
Any,
Callable,
Dict,
Generator,
List,
Optional,
Generator,
NamedTuple,
Set,
Tuple,
Union,
@ -24,17 +25,28 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.distributed import ProcessGroup
from torch.distributed._sharded_tensor import (
init_from_local_shards,
Shard,
ShardedTensor,
)
from torch.distributed.distributed_c10d import _get_default_group
from torch.nn.parameter import Parameter
from .flatten_params_wrapper import FlatParameter, FlattenParamsWrapper
from .utils import _apply_to_tensors
from .flatten_params_wrapper import FlatParameter, FlattenParamsWrapper, FLAT_PARAM
from .utils import (
_apply_to_tensors,
_replace_by_prefix,
)
from .wrap import _recursive_wrap
if TYPE_CHECKING:
from collections import OrderedDict # noqa: F401
FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"
@dataclass
class CPUOffload:
"""
@ -98,6 +110,31 @@ class TrainingState_(Enum):
SUMMON_FULL_PARAMS = auto()
class StateDictType(Enum):
"""
This enum indicates which type of ``state_dict`` the FSDP module is
currently processing (returning or loading).
The default value is FULL_STATE_DICT to comply with the PyTorch convention.
.. note::
FSDP currently supports three types of ``state_dict``:
1. ``state_dict``/``load_state_dict``: this pair of APIs return and load
the non-sharded, unflattened parameters. The semantics are the
same as using DDP.
2. ``local_state_dict``/``load_local_state_dict``: this pair of APIs return
and load local sharded, flattened parameters. The values returned
by ``local_state_dict`` can be directly used by FSDP and are only
meaningful to FSDP (because parameters are flattened).
3. ``sharded_state_dict``/``load_sharded_state_dict``: this pair of APIs
return and load sharded, unflattened parameters. The ``state_dict``
returned by ``sharded_state_dict`` can be used by all other parallel
schemes (resharding may be required).
"""
FULL_STATE_DICT = auto()
LOCAL_STATE_DICT = auto()
SHARDED_STATE_DICT = auto()
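# Illustrative sketch (not part of the changed file): choosing among the three
# state_dict flavors from user code, mirroring the new test_fsdp_state_dict test.
# Assumes an initialized process group and a CUDA device; ``build_model`` is a
# hypothetical factory returning an nn.Module.
def _sketch_state_dict_flavors(build_model: Callable[[], nn.Module]) -> None:
    model = FullyShardedDataParallel(build_model().cuda())
    model.state_dict()                     # default: FULL_STATE_DICT, unflattened keys
    local_sd = model.local_state_dict()    # LOCAL_STATE_DICT: flattened, sharded per rank
    model.load_local_state_dict(local_sd)  # only loadable by an FSDP-wrapped model
    # The sharded_state_dict flavor is declared above, but its hooks still raise
    # NotImplementedError in this PR.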
class FullyShardedDataParallel(nn.Module):
"""
A wrapper for sharding Module parameters across data parallel workers. This
@ -244,6 +281,7 @@ class FullyShardedDataParallel(nn.Module):
self._fsdp_wrapped_module: FlattenParamsWrapper = FlattenParamsWrapper(
module, param_list=params
)
assert getattr(self, FSDP_WRAPPED_MODULE) is self._fsdp_wrapped_module
del module # free original module in case it helps garbage collection
if self._fsdp_wrapped_module.flat_param is not None:
self.params = [self._fsdp_wrapped_module.flat_param]
@ -268,6 +306,29 @@ class FullyShardedDataParallel(nn.Module):
# Enum to indicate if we're in the forward/backward pass, idle, etc.
self.training_state = TrainingState_.IDLE
self._state_dict_type = StateDictType.FULL_STATE_DICT
# FSDP currently provides three different state_dicts. The actual
# state_dict that will be saved/loaded is decided by
# self._state_dict_type, and the main logic of each state_dict is
# implemented in its hook. Therefore, for each hook (post-save and
# pre-load), there is a dispatcher dictionary that routes execution
# to the correct implementation.
self._register_state_dict_hook(self._post_state_dict_hook)
self._post_state_dict_hook_fn = {
StateDictType.FULL_STATE_DICT: self._full_post_state_dict_hook,
StateDictType.LOCAL_STATE_DICT: self._local_post_state_dict_hook,
StateDictType.SHARDED_STATE_DICT: self._sharded_post_state_dict_hook,
}
self._register_load_state_dict_pre_hook(
self._pre_load_state_dict_hook, with_module=True
)
self._pre_load_state_dict_hook_fn = {
StateDictType.FULL_STATE_DICT: self._full_pre_load_state_dict_hook,
StateDictType.LOCAL_STATE_DICT: self._local_pre_load_state_dict_hook,
StateDictType.SHARDED_STATE_DICT: self._sharded_pre_load_state_dict_hook,
}
# Flag to guard against preparing gradients multiple times per backward pass.
self._pre_backward_hook_has_run = False
# Used for prefetching all gather full params in post backward hook
@ -679,6 +740,227 @@ class FullyShardedDataParallel(nn.Module):
else:
return False
@contextlib.contextmanager
def state_dict_type(self, state_dict_type: StateDictType) -> Generator:
"""
A context manager to set the state_dict_type of this FSDP module and
its descendant FSDP modules.
.. note:: This API should only be called on the root FSDP module.
.. note:: The default state_dict_type is StateDictType.FULL_STATE_DICT.
Args:
state_dict_type (StateDictType): the desired state_dict_type to set.
"""
self._lazy_init()
if not self._is_root:
raise RuntimeError(
f"state_dict_type context manager can only be called from the root FSDP module. {self._is_root}"
)
prev_state_dict_type = self._state_dict_type
for module in self.modules():
if isinstance(module, FullyShardedDataParallel):
if module._state_dict_type != prev_state_dict_type:
raise RuntimeError(
"All FSDP module should the same state_dict_type."
)
module._state_dict_type = state_dict_type
try:
yield
finally:
for module in self.modules():
if isinstance(module, FullyShardedDataParallel):
module._state_dict_type = prev_state_dict_type
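# Illustrative sketch (not part of the changed file): the context-manager form
# exercised by the new test; it must be entered on the root FSDP module and the
# previous state_dict_type is restored on exit. ``src`` and ``dst`` are
# hypothetical root FSDP modules with the same wrapping structure.
#
#   with src.state_dict_type(StateDictType.LOCAL_STATE_DICT):
#       local_sd = src.state_dict()
#   with dst.state_dict_type(StateDictType.LOCAL_STATE_DICT):
#       dst.load_state_dict(local_sd)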
def _full_post_state_dict_hook(
self,
state_dict: "OrderedDict[str, torch.Tensor]",
prefix: str,
) -> "OrderedDict[str, torch.Tensor]":
return state_dict
def _local_post_state_dict_hook(
self,
state_dict: "OrderedDict[str, torch.Tensor]",
prefix: str,
) -> "OrderedDict[str, torch.Tensor]":
"""
This hook creates a ShardedTensor from the local flat_param and replaces
state_dict[f"{prefix}{FLAT_PARAM}"] with the ShardedTensor. No copy is
made; the underlying storage is the same.
"""
_replace_by_prefix(state_dict, f"{prefix}{FSDP_WRAPPED_MODULE}.", prefix)
# state_dict[f"{prefix}{FLAT_PARAM}"] exists and has the same tensor
# value as the flat_param but it is a pure Tensor because
# nn.Module.state_dict() will detach the parameter. Therefore, we need
# to get flat_param from the FlattenParamsWrapper to get the metadata.
flat_param = getattr(self.module, FLAT_PARAM, None)
assert (
flat_param is not None
), "flat_param cannot be None when doing local_state_dict."
# Construct a ShardedTensor from the flat_param.
full_numel = flat_param.full_numel
shard_offset = flat_param.numel() * self.rank
valid_data_size = flat_param.numel() - flat_param.num_padded
if valid_data_size > 0 and flat_param.num_padded > 0:
flat_param = flat_param.narrow(0, 0, valid_data_size)
local_shards = [
Shard.from_tensor_and_offsets(flat_param, [shard_offset], self.rank)
]
state_dict[f"{prefix}{FLAT_PARAM}"] = init_from_local_shards(
local_shards, full_numel, process_group=self.process_group
) # type: ignore[assignment]
return state_dict
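# Illustrative sketch (hypothetical numbers, not part of the changed file): for a
# flat_param with full_numel = 13 sharded across 2 ranks, each rank holds a
# padded 7-element shard, so:
#   rank 0: shard_offset = 7 * 0 = 0, valid_data_size = 7 - 0 = 7, shard covers [0, 7)
#   rank 1: shard_offset = 7 * 1 = 7, valid_data_size = 7 - 1 = 6, shard covers [7, 13)
# init_from_local_shards() then stitches the per-rank shards into a single
# 13-element ShardedTensor without copying the local data.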
def _sharded_post_state_dict_hook(
self,
state_dict: "OrderedDict[str, torch.Tensor]",
prefix: str,
) -> "OrderedDict[str, torch.Tensor]":
raise NotImplementedError("Will be implemented in the next PRs.")
@staticmethod
def _post_state_dict_hook(
module: nn.Module,
state_dict: "OrderedDict[str, torch.Tensor]",
prefix: str,
*args: Any,
) -> "OrderedDict[str, torch.Tensor]":
"""
_post_state_dict_hook() is called after the state_dict() of this
FSDP module is executed. ``self._state_dict_type`` is used to decide
what postprocessing will be done.
"""
self = cast(FullyShardedDataParallel, module)
return self._post_state_dict_hook_fn[self._state_dict_type](state_dict, prefix)
def state_dict(self, destination=None, prefix="", keep_vars=False):
"""
The entry point of all three FSDP state_dict APIs.
``self._state_dict_type`` decides which code path to execute.
.. warning:: This needs to be called on all ranks, since synchronization
primitives may be used.
"""
if torch.cuda.is_available():
torch.cuda.synchronize()
if self._state_dict_type == StateDictType.FULL_STATE_DICT:
return super().state_dict(destination, prefix, keep_vars)
elif self._state_dict_type == StateDictType.LOCAL_STATE_DICT:
assert getattr(self.module, FLAT_PARAM, None) is not None
assert isinstance(self.module.flat_param, FlatParameter)
return super().state_dict(destination, prefix, keep_vars)
elif self._state_dict_type == StateDictType.SHARDED_STATE_DICT:
raise NotImplementedError("Will be implemented in the next PRs.")
else:
raise ValueError(f"Unknown StateDictType {self._state_dict_type}.")
def local_state_dict(self, *args: Any, **kwargs: Any) -> Any:
"""
Returns the local state of the module. Parameters are flattened and
sharded, so the resulting state_dict can only be loaded after the module
has been wrapped with FSDP.
"""
with self.state_dict_type(StateDictType.LOCAL_STATE_DICT):
return self.state_dict(*args, **kwargs)
def _full_pre_load_state_dict_hook(
self,
state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
prefix: str,
) -> None:
return
def _local_pre_load_state_dict_hook(
self,
state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
prefix: str,
) -> None:
"""
This hook finds the local flat_param for this FSDP module from the
state_dict. The flat_param should be a ShardedTensor. This hook converts
the ShardedTensor to a tensor. No copy happens unless padding is required.
"""
_replace_by_prefix(state_dict, prefix, f"{prefix}{FSDP_WRAPPED_MODULE}.")
key = f"{prefix}{FSDP_WRAPPED_MODULE}.{FLAT_PARAM}"
load_tensor = state_dict[key]
assert isinstance(
load_tensor, ShardedTensor
), "Tensors in local_state_dict should be ShardedTensor."
# Convert the ShardedTensor to a Tensor.
shards = load_tensor.local_shards()
assert len(shards), "load_local_state_dict assume one shard per ShardedTensor."
load_tensor = cast(torch.Tensor, shards[0].tensor)
# Get the metadata of the flat_param to decide whether to pad the loaded
# tensor.
flat_param = self.module.flat_param
assert flat_param is not None
if flat_param.num_padded not in (0, flat_param.numel()):
assert load_tensor.numel() < flat_param.numel(), (
f"Local shard size = {flat_param.numel()} and the tensor in "
f"the state_dict is {load_tensor.numel()}."
)
load_tensor = F.pad(load_tensor, [0, flat_param.num_padded])
state_dict[key] = load_tensor
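# Illustrative note (not part of the changed file): F.pad with a pad of
# [0, num_padded] on a 1-D tensor appends ``num_padded`` zeros on the right,
# restoring the padded local shard size FSDP expects, e.g. with num_padded = 2:
#   F.pad(torch.tensor([1., 2., 3.]), [0, 2])  ->  tensor([1., 2., 3., 0., 0.])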
def _sharded_pre_load_state_dict_hook(
self,
state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
prefix: str,
) -> None:
raise NotImplementedError("Will be implemented in the next PRs.")
@staticmethod
def _pre_load_state_dict_hook(
module: nn.Module,
state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
prefix: str,
*args: Any,
) -> None:
"""
``_pre_load_state_dict_hook()`` is called before ``self._load_from_state_dict()``
is called. ``self._state_dict_type`` is used to decide what preprocessing
will be done.
"""
self = cast(FullyShardedDataParallel, module)
self._pre_load_state_dict_hook_fn[self._state_dict_type](state_dict, prefix)
def load_state_dict(
self,
state_dict: "OrderedDict[str, torch.Tensor]",
strict: bool = True,
) -> NamedTuple:
"""
The entry point of all three FSDP load_state_dict APIs.
``self._state_dict_type`` decides which code path to execute.
.. warning:: This needs to be called on all ranks, since synchronization
primitives may be used.
"""
torch.cuda.synchronize()
if self._state_dict_type == StateDictType.FULL_STATE_DICT:
return super().load_state_dict(state_dict, strict)
elif self._state_dict_type == StateDictType.LOCAL_STATE_DICT:
return super().load_state_dict(state_dict, strict)
elif self._state_dict_type == StateDictType.SHARDED_STATE_DICT:
raise NotImplementedError("Will be implemented in the next PRs.")
else:
raise ValueError(f"Unknown StateDictType {self._state_dict_type}.")
def load_local_state_dict(
self,
state_dict: "OrderedDict[str, torch.Tensor]",
strict: bool = True,
) -> NamedTuple:
"""
Load states from a flattened, sharded state dictionary.
"""
with self.state_dict_type(StateDictType.LOCAL_STATE_DICT):
return self.load_state_dict(state_dict, strict)
def forward(self, *args: Any, **kwargs: Any) -> Any:
self._lazy_init()
@ -1110,6 +1392,7 @@ class FullyShardedDataParallel(nn.Module):
"""
Gather all shards of params.
"""
self._lazy_init()
def update_p_data(output_tensor: torch.Tensor) -> None:
"""
@ -1246,8 +1529,7 @@ class FullyShardedDataParallel(nn.Module):
until the eventual sync.
"""
self._lazy_init()
assert self._is_root, \
"`no_sync()` on inner FSDP instances is not supported"
assert self._is_root, "`no_sync()` on inner FSDP instances is not supported"
self._assert_state(TrainingState_.IDLE)
old_flags = []
for m in self.modules():
@ -1258,9 +1540,10 @@ class FullyShardedDataParallel(nn.Module):
yield
finally:
for m, old_flag in old_flags:
assert not m._require_backward_grad_sync, \
"`_require_backward_grad_sync` was incorrectly set to " \
assert not m._require_backward_grad_sync, (
"`_require_backward_grad_sync` was incorrectly set to "
"`True` while in the `no_sync()` context manager"
)
m._require_backward_grad_sync = old_flag

View File

@ -1,7 +1,9 @@
from typing import Dict, List, Tuple, Union, Any, Callable, Set
from typing import Dict, List, Tuple, Union, Any, Callable, Set, TYPE_CHECKING
import torch
if TYPE_CHECKING:
from collections import OrderedDict # noqa: F401
"""Useful functions to deal with tensor types with other python container types."""
@ -22,3 +24,27 @@ def _apply_to_tensors(
return x
return apply(container)
def _replace_by_prefix(
state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
old_prefix: str,
new_prefix: str,
) -> None:
"""
Replace all keys that match a given old_prefix with a new_prefix (in-place).
Usage::
state_dict = {"layer.xyz": torch.tensor(1)}
_replace_by_prefix(state_dict, "layer.", "module.layer.")
assert state_dict == {"module.layer.xyz": torch.tensor(1)}
"""
if old_prefix == new_prefix:
raise ValueError("old_prefix and new_prefix must be distinct")
for key in list(state_dict.keys()):
if not key.startswith(old_prefix):
continue
new_key = new_prefix + key[len(old_prefix) :]
state_dict[new_key] = state_dict[key]
del state_dict[key]