[Resubmit] helpers to torch.dist.utils (#95025)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95025
Approved by: https://github.com/fegin

parent 2aa806608b
commit c43e88665a
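In short, this commit moves the FSDP-internal helpers `_apply_to_tensors`, `_alloc_storage`, `_free_storage`, and `p_assert` (renamed to `_p_assert`) out of `torch.distributed.fsdp._utils` and into `torch.distributed.utils`, then updates every call site shown in the diff below. A minimal sketch of the import change a caller would make, assuming it previously reached into the FSDP-private module (illustrative only, not part of the diff):

    # Before this commit (FSDP-private location; `p_assert` had no leading underscore):
    #     from torch.distributed.fsdp._utils import _apply_to_tensors, p_assert
    # After this commit (shared distributed utils; note the `_p_assert` rename):
    from torch.distributed.utils import _apply_to_tensors, _p_assert

    _p_assert(1 + 1 == 2, "printed (and raised) only if the condition is False")
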
@@ -11,10 +11,9 @@ from typing import List
 import torch
 import torch.nn as nn
 from torch import distributed as dist
-from torch.distributed.fsdp._utils import _apply_to_tensors
 from torch.distributed.fsdp._wrap_utils import _get_fully_sharded_module_to_states
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
-from torch.distributed.utils import _replace_by_prefix
+from torch.distributed.utils import _apply_to_tensors, _replace_by_prefix
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,

@@ -26,11 +26,7 @@ from torch.distributed.fsdp._common_utils import (
     TrainingState,
 )
 from torch.distributed.fsdp._init_utils import HYBRID_SHARDING_STRATEGIES
-from torch.distributed.fsdp._utils import (
-    _apply_to_tensors,
-    _no_dispatch_record_stream,
-    p_assert,
-)
+from torch.distributed.fsdp._utils import _no_dispatch_record_stream
 from torch.distributed.fsdp.api import BackwardPrefetch
 from torch.distributed.fsdp.flat_param import (
     _HandlesKey,
@@ -39,7 +35,7 @@ from torch.distributed.fsdp.flat_param import (
     HandleShardingStrategy,
     HandleTrainingState,
 )
-from torch.distributed.utils import _to_kwargs
+from torch.distributed.utils import _apply_to_tensors, _p_assert, _to_kwargs

 RESHARD_AFTER_FORWARD_STRATEGIES = {
     HandleShardingStrategy.FULL_SHARD,
@@ -221,7 +217,7 @@ def _share_state_and_init_handle_attrs(
         attr_name_to_values[attr_name] = set()
     for fsdp_state in traversal_utils._get_fsdp_states(root_module):
         for attr_name in HOMOGENEOUS_ATTR_NAMES:
-            p_assert(
+            _p_assert(
                 hasattr(fsdp_state, attr_name),
                 f"FSDP state missing attribute {attr_name}",
             )
@@ -246,7 +242,7 @@ def _share_state_and_init_handle_attrs(
         # Relax the assert for non-root FSDP instances in case the nested
         # initialized module is wrapped again in FSDP later (e.g. after
         # training to run inference)
-        p_assert(
+        _p_assert(
             fsdp_state._is_root is None or not fsdp_state._is_root,
             "Non-root FSDP instance's `_is_root` should not have been "
             "set yet or should have been set to `False`",
@@ -344,7 +340,7 @@ def _reshard(
     """
     if not handles:
         return
-    p_assert(
+    _p_assert(
         len(handles) == len(free_unsharded_flat_params),
         "Expects both lists to have equal length but got "
         f"{len(handles)} and {len(free_unsharded_flat_params)}",
@@ -518,7 +514,7 @@ def _root_pre_forward(
     may not be the root. If not, then this method does not do anything.
     """
     _lazy_init(state, module)
-    p_assert(state._is_root is not None, "Expects a root FSDP to have been set")
+    _p_assert(state._is_root is not None, "Expects a root FSDP to have been set")
     if not state._is_root:
         return args, kwargs
     if state.forward_prefetch:
@@ -675,7 +671,7 @@ def _post_backward_hook(
     # the same `FlatParameter`, the post-backward hook may run multiple
     # times in one backward, in which case we permit the state to already
     # be in `BACKWARD_POST`.
-    p_assert(
+    _p_assert(
         handle._training_state
         in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.BACKWARD_POST),
         f"Expects `BACKWARD_PRE` or `BACKWARD_POST` state but got {handle._training_state}",
@@ -855,8 +851,8 @@ def _check_comm_hook(
     comm_hook: Any,
     comm_hook_state: Any,
 ) -> None:
-    p_assert(comm_hook is not None, "Communication hook should not be `None`")
-    p_assert(
+    _p_assert(comm_hook is not None, "Communication hook should not be `None`")
+    _p_assert(
         comm_hook_state is not None, "Communication hook state should not be `None`"
     )

@@ -865,13 +861,13 @@ def _check_grad_to_accumulate(
     new_sharded_grad: torch.Tensor,
     accumulated_grad: torch.Tensor,
 ) -> None:
-    p_assert(
+    _p_assert(
         accumulated_grad.shape == new_sharded_grad.shape,
         "Shape mismatch when accumulating gradients: "
         f"existing gradient shape={accumulated_grad.shape} "
         f"new gradient shape={new_sharded_grad.shape}",
     )
-    p_assert(
+    _p_assert(
         accumulated_grad.device == new_sharded_grad.device,
         "Device mismatch when accumulating gradients: "
         f"existing gradient device={accumulated_grad.device} "
@@ -895,7 +891,7 @@ def _post_backward_final_callback(
     This runs at the end of the entire backward pass and should only be called
     on the root FSDP instance.
     """
-    p_assert(
+    _p_assert(
         state._is_root,
         "The post-backward callback should only be called on the root FSDP instance",
     )
@@ -952,7 +948,7 @@ def _catch_all_reshard(
         if handles_to_reshard:
             _reshard(state, handles_to_reshard, free_unsharded_flat_params)
     except Exception as e:
-        p_assert(
+        _p_assert(
            False,
            f"Got exception in the catch-all reshard for {state}: {str(e)}",
            raise_assertion_error=False,
@@ -969,7 +965,7 @@ def _finalize_params(
         flat_param = handle.flat_param
         if flat_param.requires_grad:
             if hasattr(flat_param, "_post_backward_hook_state"):
-                p_assert(
+                _p_assert(
                     len(flat_param._post_backward_hook_state) == 2,
                     f"Invalid: ``_post_backward_hook_state``: {flat_param._post_backward_hook_state}",
                 )
@@ -982,7 +978,7 @@ def _finalize_params(
                 # sharded gradient from the last synchronized iteration
                 continue
             handle.prepare_gradient_for_optim()
-            p_assert(
+            _p_assert(
                 hasattr(flat_param, "_post_backward_called"),
                 "Expects `_post_backward_called` to be set on the `FlatParameter`",
             )
@@ -1029,7 +1025,7 @@ def _get_handles_to_prefetch(
         HandleTrainingState.BACKWARD_POST,
         HandleTrainingState.FORWARD,
     )
-    p_assert(
+    _p_assert(
         training_state in valid_training_states,
         f"Prefetching is only supported in {valid_training_states} but "
         f"currently in {training_state}",
@@ -1067,9 +1063,9 @@ def _get_training_state(
     handles_key: _HandlesKey,
 ) -> HandleTrainingState:
     """Returns the training state of the handles in ``handles_key``."""
-    p_assert(len(handles_key) > 0, "Expects a non-empty handles key")
+    _p_assert(len(handles_key) > 0, "Expects a non-empty handles key")
     training_states = {handle._training_state for handle in handles_key}
-    p_assert(
+    _p_assert(
         len(training_states) == 1,
         f"Expects uniform training state but got {training_states}",
     )
@@ -1233,7 +1229,7 @@ def _register_post_backward_hooks(
             continue
         # Get the `AccumulateGrad` object
        temp_flat_param = flat_param.expand_as(flat_param)
-        p_assert(
+        _p_assert(
             temp_flat_param.grad_fn is not None,
             "The `grad_fn` is needed to access the `AccumulateGrad` and "
             "register the post-backward hook",
@@ -1255,7 +1251,7 @@ def _register_post_backward_final_callback(
     backward pass. This should be called from the root FSDP instance at the
     beginning of the pre-backward.
     """
-    p_assert(
+    _p_assert(
         state._is_root,
         "Only the root FSDP instance should register the post-backward callback",
     )
@@ -1309,7 +1305,7 @@ def _get_buffers_and_dtypes_for_computation(
     is either ``None`` if buffer mixed precision is not enabled or the buffer
     low precision dtype otherwise.
     """
-    p_assert(state._is_root, "Expects the root to cast buffers")
+    _p_assert(state._is_root, "Expects the root to cast buffers")
     buffers: List[torch.Tensor] = []
     buffer_dtypes: List[Optional[torch.dtype]] = []
     if _is_composable(state):
@@ -1344,7 +1340,7 @@ def _get_buffer_dtypes(
     """
     buffer_dtypes: List[torch.dtype] = []
     for buffer_name in buffer_names:
-        p_assert(
+        _p_assert(
             buffer_name in state._buffer_name_to_orig_dtype,
             f"{buffer_name} is missing from pre-computed dict on rank "
             f"{state.rank}, which only has keys "
@@ -1364,7 +1360,7 @@ def _cast_buffers_to_dtype_and_device(
     to ``device``. If an element in ``buffer_dtypes`` is ``None``, then the
     corresponding buffer is only moved to ``device``.
     """
-    p_assert(
+    _p_assert(
         buffer_dtypes is None or len(buffers) == len(buffer_dtypes),
         f"Expects `buffers` and `buffer_dtypes` to have the same length if "
         f"`buffer_dtypes` is specified but got {len(buffers)} and "

@@ -21,7 +21,7 @@ from torch.distributed.fsdp._runtime_utils import (
     _unshard,
     _unshard_grads,
 )
-from ._utils import p_assert
+from torch.distributed.utils import _p_assert
 from .flat_param import FlatParamHandle

 FLAT_PARAM = "_flat_param"
@@ -336,7 +336,7 @@ def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None:
     Deregisters the original parameters; registers the ``FlatParameter``.
     """
     handles = _module_handles(state, module)
-    p_assert(
+    _p_assert(
         len(handles) <= 1,
         "Expects <=1 handle per FSDP instance; needs to be refactored "
         "for >1 handle (e.g. non-recursive wrapping)",
@@ -344,7 +344,7 @@ def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None:
     if not handles:
         return
     handle = handles[0]
-    p_assert(
+    _p_assert(
         handle._use_orig_params,
         f"Inconsistent `_use_orig_params` -- FSDP: {state._use_orig_params} "
         f"handle: {handle._use_orig_params}",

@@ -1,14 +1,7 @@
-import dataclasses
-import traceback
-from collections import OrderedDict
-from typing import Any, Callable, cast, Dict, List, Set, Tuple, Union
+from typing import cast

 import torch
 from torch.nn.modules.batchnorm import _BatchNorm
-from torch.nn.parallel.scatter_gather import (  # type: ignore[attr-defined]
-    _is_namedtuple,
-)
-from torch.nn.utils.rnn import PackedSequence
 from torch.utils._mode_utils import no_dispatch


@@ -22,102 +15,12 @@ def _override_batchnorm_mixed_precision(module):
         mod._wrap_overrides = {"mixed_precision": None}  # type: ignore[assignment]


-def _apply_to_tensors(
-    fn: Callable,
-    container: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence],
-) -> Any:
-    """Recursively apply to all tensor in different kinds of container types."""
-
-    def apply(
-        x: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence]
-    ) -> Any:
-        if torch.is_tensor(x):
-            return fn(x)
-        elif hasattr(x, "__dataclass_fields__"):
-            dc = dataclasses.replace(x)
-            for f in dataclasses.fields(dc):
-                name = f.name
-                setattr(dc, name, apply(getattr(dc, name)))
-            return dc
-        elif isinstance(x, OrderedDict):
-            od = x.__class__()
-            for key, value in x.items():
-                od[key] = apply(value)
-            return od
-        elif isinstance(x, PackedSequence):
-            apply(x.data)
-            return x
-        elif isinstance(x, dict):
-            return {key: apply(value) for key, value in x.items()}
-        elif _is_namedtuple(x):
-            res = (apply(el) for el in x)
-            return type(x)(*res)
-        elif isinstance(x, (list, tuple, set)):
-            return type(x)(apply(el) for el in x)
-        else:
-            return x
-
-    return apply(container)
-
-
-@torch.no_grad()
-def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> bool:
-    """
-    Allocate storage for ``tensor`` with the given size.
-
-    Returns:
-        bool: ``True`` if this method allocated storage and ``False`` if the
-        storage was already allocated.
-    """
-    already_allocated = tensor._typed_storage()._size() == size.numel()
-    if not already_allocated:
-        tensor_storage_size = tensor._typed_storage()._size()
-        p_assert(
-            tensor_storage_size == 0,
-            f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}",
-        )
-        tensor._typed_storage()._resize_(size.numel())
-    return not already_allocated
-
-
-@torch.no_grad()
-def _free_storage(tensor: torch.Tensor) -> bool:
-    """
-    Frees the underlying storage of ``tensor``.
-
-    Returns:
-        bool: ``True`` if the method freed the storage and ``False`` if the
-        storage was already freed.
-    """
-    already_freed = tensor._typed_storage()._size() == 0
-    if not already_freed:
-        p_assert(
-            tensor.storage_offset() == 0,
-            "Freeing a tensor's storage is unsafe when it is not the sole occupant\n"
-            f"storage offset: {tensor.storage_offset()}\n"
-            f"storage size: {tensor._typed_storage()._size()}\n"
-            f"tensor shape: {tensor.shape}",
-        )
-        tensor._typed_storage()._resize_(0)
-    return not already_freed
-
-
 def _same_storage(x: torch.Tensor, y: torch.Tensor) -> bool:
     """Returns if ``x`` and ``y`` share the same storage."""
     # NOTE: CPU and GPU tensors are ensured to have different data pointers.
     return x._typed_storage()._data_ptr() == y._typed_storage()._data_ptr()


-def p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:
-    """This is used as an alternate to ``assert`` when in the backward context
-    to print the error message ``s`` since otherwise, it is swallowed."""
-    if not cond:
-        print(s)
-        traceback.print_stack()
-        if raise_assertion_error:
-            raise AssertionError(s)
-
-
 def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.cuda.Stream) -> None:
     with no_dispatch():
         tensor.record_stream(cast(torch._C.Stream, stream))

@@ -27,15 +27,10 @@ from torch.distributed.fsdp._common_utils import (
     _set_fsdp_flattened,
     HandleTrainingState,
 )
+from torch.distributed.utils import _alloc_storage, _free_storage, _p_assert

 from ._fsdp_extensions import _ext_post_unflatten_transform, _ext_pre_flatten_transform
-from ._utils import (
-    _alloc_storage,
-    _free_storage,
-    _no_dispatch_record_stream,
-    _same_storage,
-    p_assert,
-)
+from ._utils import _no_dispatch_record_stream, _same_storage

 __all__ = [
     "FlatParameter",
@@ -558,7 +553,7 @@ class FlatParamHandle:
         if not self.uses_sharded_strategy:
             self._init_shard_metadata(0, 0, flat_param.numel() - 1)
         else:
-            p_assert(
+            _p_assert(
                 flat_param.storage_offset() == 0,
                 "The `FlatParameter` is not the sole occupant of its storage",
             )
@@ -600,8 +595,8 @@ class FlatParamHandle:
         """
         self.flat_param._sharded_size = self.flat_param.size()  # type: ignore[attr-defined]
         sharded_flat_param_numel = self.flat_param.numel()  # includes `numel_padded`
-        p_assert(start >= 0 and start <= end, f"start: {start} end: {end}")
-        p_assert(
+        _p_assert(start >= 0 and start <= end, f"start: {start} end: {end}")
+        _p_assert(
             numel_padded <= sharded_flat_param_numel,
             f"numel_padded: {numel_padded} "
             f"sharded_flat_param_numel: {sharded_flat_param_numel}",
@@ -792,7 +787,7 @@ class FlatParamHandle:
         self._orig_param_dtype = flat_param.dtype
         cpu_device = torch.device("cpu")
         if self._offload_params:
-            p_assert(
+            _p_assert(
                 flat_param.device == cpu_device,
                 f"Expects the `FlatParameter` to be on CPU when parameter CPU "
                 f"offloading is enabled, not {flat_param.device}",
@@ -957,7 +952,7 @@ class FlatParamHandle:
             # tensor as the all-gather destination to preserve the invariant
             # that `_full_param_padded` is in the low precision
             unsharded_flat_param = flat_param._full_prec_full_param_padded  # type: ignore[attr-defined]
-            p_assert(
+            _p_assert(
                 unsharded_flat_param.dtype != self._fwd_bwd_param_dtype,
                 f"Expects full precision but got {self._fwd_bwd_param_dtype}",
             )
@@ -974,13 +969,13 @@ class FlatParamHandle:
         ``padded_unsharded_flat_param``, and switches to using the all-gathered
         tensor.
         """
-        p_assert(
+        _p_assert(
             hasattr(self, "process_group") and hasattr(self, "world_size"),
             "Expects a process group and world size to have been set via `shard()`",
         )
         sharded_flat_param = self.flat_param.data
         expected_numel = sharded_flat_param.numel() * self.world_size
-        p_assert(
+        _p_assert(
             padded_unsharded_flat_param.numel() == expected_numel,
             f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
         )
@@ -1111,7 +1106,7 @@ class FlatParamHandle:
         clearing any existing sharded gradient in ``.grad`` to enable computing
         a new unsharded gradient.
         """
-        p_assert(
+        _p_assert(
             self._training_state
             in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.IDLE),
             "Expects to be in `BACKWARD_PRE` or `IDLE` (if prefetching)",
@@ -1123,7 +1118,7 @@ class FlatParamHandle:
         ):
             self._check_on_compute_device(self.flat_param)
             grad_offloaded = flat_param.grad.device != self.device
-            p_assert(
+            _p_assert(
                 not grad_offloaded or self._offload_params,
                 f"Expects the sharded gradient to be on {self.device} "
                 f"but got {flat_param.grad.device}",
@@ -1142,7 +1137,7 @@ class FlatParamHandle:
                 flat_param._saved_grad_shard = flat_param.grad.data  # type: ignore[attr-defined]
                 sharded_grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
             else:
-                p_assert(
+                _p_assert(
                     hasattr(flat_param, "_cpu_grad"),
                     "`_cpu_grad` should be defined if the gradient is on CPU",
                 )
@@ -1162,7 +1157,7 @@ class FlatParamHandle:
                     sharded_grad.data = sharded_grad.to(local_shard_dtype)
             else:
                 padded_unsharded_size = flat_param._padded_unsharded_size  # type: ignore[attr-defined]
-                p_assert(
+                _p_assert(
                     flat_param.grad.size() == padded_unsharded_size,
                     "Expects `.grad` to be the unsharded gradient in "
                     f"`no_sync()` with size {padded_unsharded_size} "
@@ -1203,7 +1198,7 @@ class FlatParamHandle:
                 flat_param.grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
                 cast_grad_to_param_dtype_if_needed(flat_param)
             else:
-                p_assert(
+                _p_assert(
                     not self.uses_sharded_strategy
                     or not flat_param._post_backward_called,  # type: ignore[attr-defined]
                     "All sharded parameters that received a gradient in the "
@@ -1229,7 +1224,7 @@ class FlatParamHandle:
         Postcondition: Same as the precondition.
         """
         self._check_sharded_strategy()
-        p_assert(
+        _p_assert(
             self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
             f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
         )
@@ -1242,7 +1237,7 @@ class FlatParamHandle:
         padded_storage_ptr = (
             self._get_padded_unsharded_flat_param()._typed_storage()._data_ptr()
         )
-        p_assert(
+        _p_assert(
             unpadded_storage_ptr == padded_storage_ptr,
             "Expects the unpadded parameter to be a view into the padded parameter",
         )
@@ -1251,7 +1246,7 @@ class FlatParamHandle:
         try:
             yield
         finally:
-            p_assert(
+            _p_assert(
                 self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
                 f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
             )
@@ -1314,7 +1309,7 @@ class FlatParamHandle:
         flat_param = self.flat_param
         if self._offload_params:
             device = flat_param._local_shard.device  # type: ignore[attr-defined]
-            p_assert(
+            _p_assert(
                 device == torch.device("cpu"),
                 f"Expects the local shard to be on CPU but got {device}",
             )
@@ -1357,7 +1352,7 @@ class FlatParamHandle:
         """
         if tensor is None:
             tensor = flat_param
-        p_assert(
+        _p_assert(
             tensor.numel() == flat_param._unpadded_unsharded_size.numel(),
             f"Expects {flat_param._unpadded_unsharded_size.numel()} numel but got "
             f"{tensor.numel()} numel",
@@ -1416,7 +1411,7 @@ class FlatParamHandle:
                 # hook fires (e.g. for reentrant AC)
                 assert self.flat_param._tensors is not None  # mypy
                 tensor = self.flat_param._tensors[i]
-                p_assert(
+                _p_assert(
                     tensor is not None,
                     "Expects `Tensor` to have been saved in forward",
                 )
@@ -1439,14 +1434,14 @@ class FlatParamHandle:
         ) in enumerate(self.flat_param._shared_param_infos):
             if hasattr(module, param_name):
                 delattr(module, param_name)
-            p_assert(
+            _p_assert(
                 hasattr(prim_module, prim_param_name),
                 f"Module {prim_module_name} is missing parameter {prim_param_name}",
             )
             prim_param: Union[Tensor, nn.Parameter] = getattr(
                 prim_module, prim_param_name
             )
-            p_assert(
+            _p_assert(
                 not as_params or isinstance(prim_param, nn.Parameter),
                 f"as_params={as_params} type(prim_param)={type(prim_param)}",
             )
@@ -1485,7 +1480,7 @@ class FlatParamHandle:
         for i, (view, (param_name, module, _)) in enumerate(
             zip(views, self.flat_param._param_infos)
         ):
-            p_assert(
+            _p_assert(
                 hasattr(module, param_name),
                 f"{self.flat_param._fqns[i]} is missing",
             )
@@ -1511,7 +1506,7 @@ class FlatParamHandle:
             prim_module,
             _,
         ) in enumerate(self.flat_param._shared_param_infos):
-            p_assert(
+            _p_assert(
                 hasattr(module, param_name),
                 f"{module_name + '.' + param_name if module_name else param_name} is missing",
             )  # did not save FQN info in `_shared_param_infos`
@@ -1793,7 +1788,7 @@ class FlatParamHandle:
             RuntimeError: If the ``src_tensor`` does not have the expected
                 shape.
         """
-        p_assert(
+        _p_assert(
             len(expected_shape) == 1,
             f"Expects a 1D expected shape but got {expected_shape}",
         )
@@ -1935,7 +1930,7 @@ class FlatParamHandle:
         else:
             # If in the forward, then there may be an accumulated gradient,
             # which will be in `.grad`
-            p_assert(
+            _p_assert(
                 flat_param.grad is None
                 or not self.uses_sharded_strategy
                 or self._training_state == HandleTrainingState.FORWARD,
@@ -1954,7 +1949,7 @@ class FlatParamHandle:
         """
         if not self._use_orig_params:
             return
-        p_assert(
+        _p_assert(
            self._training_state == HandleTrainingState.BACKWARD_POST,
            "Expects to only be called in the post-backward after gradient computation",
        )
@@ -1971,16 +1966,16 @@ class FlatParamHandle:
    # CHECKS & INVARIANTS #
    #######################
    def _check_sharded_strategy(self):
-        p_assert(self.uses_sharded_strategy, "Expects sharded strategy")
+        _p_assert(self.uses_sharded_strategy, "Expects sharded strategy")

    def _check_on_compute_device(self, tensor: Tensor):
-        p_assert(
+        _p_assert(
             tensor.device == self.device,
             f"Expects tensor to be on the compute device {self.device}",
         )

    def _check_on_cpu(self, tensor: Tensor):
-        p_assert(
+        _p_assert(
             tensor.device == torch.device("cpu"),
             f"Expects tensor to be on CPU but got {tensor.device}",
         )
@@ -1988,7 +1983,7 @@ class FlatParamHandle:
    @staticmethod
    def _check_storage_freed(tensor: Tensor):
        storage_size: int = tensor._typed_storage()._size()
-        p_assert(
+        _p_assert(
            storage_size == 0,
            f"Expects storage to be freed but got storage with size {storage_size}",
        )
@@ -1996,37 +1991,37 @@ class FlatParamHandle:
    @staticmethod
    def _check_storage_allocated(tensor: Tensor):
        storage_size: int = tensor._typed_storage()._size()
-        p_assert(storage_size > 0, "Expects storage to be allocated")
+        _p_assert(storage_size > 0, "Expects storage to be allocated")

    def _check_low_precision_shard(self):
-        p_assert(
+        _p_assert(
            self._uses_param_mixed_precision,
            "Not using low precision for parameters",
        )
-        p_assert(
+        _p_assert(
            getattr(self.flat_param, "_mp_shard", None) is not None,
            "Expects `_mp_shard` to exist",
        )
        device = self.flat_param._mp_shard.device  # type: ignore[attr-defined]
-        p_assert(
+        _p_assert(
            device == self.device,
            f"Expects the low precision shard to be on {self.device} but got {device}",
        )

    def _check_unsharded(self, tensor: Tensor):
        msg_prefix = "Expects tensor to be unsharded "
-        p_assert(tensor is not None, msg_prefix + "but got `None`")
+        _p_assert(tensor is not None, msg_prefix + "but got `None`")
        unsharded_size = self.flat_param._unpadded_unsharded_size
-        p_assert(
+        _p_assert(
            tensor.size() == unsharded_size,
            msg_prefix + f"with size {unsharded_size} but got {tensor.size()}",
        )

    def _check_sharded(self, tensor: Tensor):
        msg_prefix = "Expects tensor to be sharded "
-        p_assert(tensor is not None, msg_prefix + "but got `None`")
+        _p_assert(tensor is not None, msg_prefix + "but got `None`")
        sharded_size = self.flat_param._sharded_size  # type: ignore[attr-defined]
-        p_assert(
+        _p_assert(
            tensor.size() == sharded_size,
            msg_prefix + f"with size {sharded_size} but got {tensor.size()}",
        )

@@ -77,6 +77,7 @@ from torch.distributed.fsdp.api import (
     StateDictSettings,
     StateDictType,
 )
+from torch.distributed.utils import _p_assert

 from ._optim_utils import (
     _broadcast_pos_dim_tensor_states,
@@ -98,7 +99,6 @@ from ._unshard_param_utils import (
     _unshard_params,
     _unshard_params_recurse,
 )
-from ._utils import p_assert
 from .flat_param import FlatParameter
 from .wrap import _FSDPPolicy

@@ -740,7 +740,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
             self, self._handles, unshard_fn, self._fsdp_wrapped_module, args, kwargs
         )
         for handle in self._handles:
-            p_assert(
+            _p_assert(
                 handle.flat_param.device == self.compute_device,
                 "Expected `FlatParameter` to be on the compute device "
                 f"{self.compute_device} but got {handle.flat_param.device}",
@@ -830,7 +830,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         this refreshes the sharded views before exiting. This method shouuld
         only be called when using the original parameters.
         """
-        p_assert(
+        _p_assert(
             self._use_orig_params,
             "`_deregister_orig_params_ctx()` should only be called when "
             "`_use_orig_params=True`",

@@ -1,4 +1,6 @@
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, Callable, Union, Set, OrderedDict
+import dataclasses
+import traceback

 import torch
 import torch.distributed as dist
@@ -94,6 +96,92 @@ def _recursive_to(inputs, target_gpu, use_side_stream_for_tensor_copies):
         to_map = None  # type: ignore[assignment]
     return res

+
+def _p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:
+    """This is used as an alternate to ``assert`` when in the backward context
+    to print the error message ``s`` since otherwise, it is swallowed."""
+    if not cond:
+        print(s)
+        traceback.print_stack()
+        if raise_assertion_error:
+            raise AssertionError(s)
+
+
+def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> bool:
+    """
+    Allocate storage for ``tensor`` with the given size.
+
+    Returns:
+        bool: ``True`` if this method allocated storage and ``False`` if the
+        storage was already allocated.
+    """
+    with torch.no_grad():
+        already_allocated = tensor._typed_storage()._size() == size.numel()
+        if not already_allocated:
+            tensor_storage_size = tensor._typed_storage()._size()
+            _p_assert(
+                tensor_storage_size == 0,
+                f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}",
+            )
+            tensor._typed_storage()._resize_(size.numel())
+    return not already_allocated
+
+
+def _free_storage(tensor: torch.Tensor) -> bool:
+    """
+    Frees the underlying storage of ``tensor``.
+
+    Returns:
+        bool: ``True`` if the method freed the storage and ``False`` if the
+        storage was already freed.
+    """
+    with torch.no_grad():
+        already_freed = tensor._typed_storage()._size() == 0
+        if not already_freed:
+            _p_assert(
+                tensor.storage_offset() == 0,
+                "Freeing a tensor's storage is unsafe when it is not the sole occupant\n"
+                f"storage offset: {tensor.storage_offset()}\n"
+                f"storage size: {tensor._typed_storage()._size()}\n"
+                f"tensor shape: {tensor.shape}",
+            )
+            tensor._typed_storage()._resize_(0)
+    return not already_freed
+
+
+def _apply_to_tensors(
+    fn: Callable,
+    container: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence],
+) -> Any:
+    """Recursively apply to all tensor in different kinds of container types."""
+
+    def apply(
+        x: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence]
+    ) -> Any:
+        if torch.is_tensor(x):
+            return fn(x)
+        elif hasattr(x, "__dataclass_fields__"):
+            dc = dataclasses.replace(x)
+            for f in dataclasses.fields(dc):
+                name = f.name
+                setattr(dc, name, apply(getattr(dc, name)))
+            return dc
+        elif isinstance(x, OrderedDict):
+            od = x.__class__()
+            for key, value in x.items():
+                od[key] = apply(value)
+            return od
+        elif isinstance(x, PackedSequence):
+            apply(x.data)
+            return x
+        elif isinstance(x, dict):
+            return {key: apply(value) for key, value in x.items()}
+        elif _is_namedtuple(x):
+            res = (apply(el) for el in x)
+            return type(x)(*res)
+        elif isinstance(x, (list, tuple, set)):
+            return type(x)(apply(el) for el in x)
+        else:
+            return x
+
+    return apply(container)
+
+
 def _to_kwargs(inputs, kwargs, device_id, use_side_stream_for_tensor_copies):
     inputs = (
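
The relocated `_apply_to_tensors` recursively maps a function over every tensor inside nested containers (dicts, lists, tuples, sets, dataclasses, `PackedSequence`), as its body above shows. A short usage sketch under the post-commit import path (the container contents here are only an illustration):

    import torch
    from torch.distributed.utils import _apply_to_tensors

    batch = {"x": torch.ones(2), "meta": [torch.zeros(3), "not-a-tensor"]}
    # Cast every tensor in the nested container to float16; non-tensors pass through unchanged.
    halved = _apply_to_tensors(lambda t: t.to(torch.float16), batch)
    print(halved["x"].dtype)  # torch.float16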