mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Resubmit] helpers to torch.dist.utils (#95025)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95025 Approved by: https://github.com/fegin
This commit is contained in:
parent
2aa806608b
commit
c43e88665a
|
|
@ -11,10 +11,9 @@ from typing import List
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import distributed as dist
|
||||
from torch.distributed.fsdp._utils import _apply_to_tensors
|
||||
from torch.distributed.fsdp._wrap_utils import _get_fully_sharded_module_to_states
|
||||
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
|
||||
from torch.distributed.utils import _replace_by_prefix
|
||||
from torch.distributed.utils import _apply_to_tensors, _replace_by_prefix
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
|
|
|
|||
|
|
@ -26,11 +26,7 @@ from torch.distributed.fsdp._common_utils import (
|
|||
TrainingState,
|
||||
)
|
||||
from torch.distributed.fsdp._init_utils import HYBRID_SHARDING_STRATEGIES
|
||||
from torch.distributed.fsdp._utils import (
|
||||
_apply_to_tensors,
|
||||
_no_dispatch_record_stream,
|
||||
p_assert,
|
||||
)
|
||||
from torch.distributed.fsdp._utils import _no_dispatch_record_stream
|
||||
from torch.distributed.fsdp.api import BackwardPrefetch
|
||||
from torch.distributed.fsdp.flat_param import (
|
||||
_HandlesKey,
|
||||
|
|
@ -39,7 +35,7 @@ from torch.distributed.fsdp.flat_param import (
|
|||
HandleShardingStrategy,
|
||||
HandleTrainingState,
|
||||
)
|
||||
from torch.distributed.utils import _to_kwargs
|
||||
from torch.distributed.utils import _apply_to_tensors, _p_assert, _to_kwargs
|
||||
|
||||
RESHARD_AFTER_FORWARD_STRATEGIES = {
|
||||
HandleShardingStrategy.FULL_SHARD,
|
||||
|
|
@ -221,7 +217,7 @@ def _share_state_and_init_handle_attrs(
|
|||
attr_name_to_values[attr_name] = set()
|
||||
for fsdp_state in traversal_utils._get_fsdp_states(root_module):
|
||||
for attr_name in HOMOGENEOUS_ATTR_NAMES:
|
||||
p_assert(
|
||||
_p_assert(
|
||||
hasattr(fsdp_state, attr_name),
|
||||
f"FSDP state missing attribute {attr_name}",
|
||||
)
|
||||
|
|
@ -246,7 +242,7 @@ def _share_state_and_init_handle_attrs(
|
|||
# Relax the assert for non-root FSDP instances in case the nested
|
||||
# initialized module is wrapped again in FSDP later (e.g. after
|
||||
# training to run inference)
|
||||
p_assert(
|
||||
_p_assert(
|
||||
fsdp_state._is_root is None or not fsdp_state._is_root,
|
||||
"Non-root FSDP instance's `_is_root` should not have been "
|
||||
"set yet or should have been set to `False`",
|
||||
|
|
@ -344,7 +340,7 @@ def _reshard(
|
|||
"""
|
||||
if not handles:
|
||||
return
|
||||
p_assert(
|
||||
_p_assert(
|
||||
len(handles) == len(free_unsharded_flat_params),
|
||||
"Expects both lists to have equal length but got "
|
||||
f"{len(handles)} and {len(free_unsharded_flat_params)}",
|
||||
|
|
@ -518,7 +514,7 @@ def _root_pre_forward(
|
|||
may not be the root. If not, then this method does not do anything.
|
||||
"""
|
||||
_lazy_init(state, module)
|
||||
p_assert(state._is_root is not None, "Expects a root FSDP to have been set")
|
||||
_p_assert(state._is_root is not None, "Expects a root FSDP to have been set")
|
||||
if not state._is_root:
|
||||
return args, kwargs
|
||||
if state.forward_prefetch:
|
||||
|
|
@ -675,7 +671,7 @@ def _post_backward_hook(
|
|||
# the same `FlatParameter`, the post-backward hook may run multiple
|
||||
# times in one backward, in which case we permit the state to already
|
||||
# be in `BACKWARD_POST`.
|
||||
p_assert(
|
||||
_p_assert(
|
||||
handle._training_state
|
||||
in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.BACKWARD_POST),
|
||||
f"Expects `BACKWARD_PRE` or `BACKWARD_POST` state but got {handle._training_state}",
|
||||
|
|
@ -855,8 +851,8 @@ def _check_comm_hook(
|
|||
comm_hook: Any,
|
||||
comm_hook_state: Any,
|
||||
) -> None:
|
||||
p_assert(comm_hook is not None, "Communication hook should not be `None`")
|
||||
p_assert(
|
||||
_p_assert(comm_hook is not None, "Communication hook should not be `None`")
|
||||
_p_assert(
|
||||
comm_hook_state is not None, "Communication hook state should not be `None`"
|
||||
)
|
||||
|
||||
|
|
@ -865,13 +861,13 @@ def _check_grad_to_accumulate(
|
|||
new_sharded_grad: torch.Tensor,
|
||||
accumulated_grad: torch.Tensor,
|
||||
) -> None:
|
||||
p_assert(
|
||||
_p_assert(
|
||||
accumulated_grad.shape == new_sharded_grad.shape,
|
||||
"Shape mismatch when accumulating gradients: "
|
||||
f"existing gradient shape={accumulated_grad.shape} "
|
||||
f"new gradient shape={new_sharded_grad.shape}",
|
||||
)
|
||||
p_assert(
|
||||
_p_assert(
|
||||
accumulated_grad.device == new_sharded_grad.device,
|
||||
"Device mismatch when accumulating gradients: "
|
||||
f"existing gradient device={accumulated_grad.device} "
|
||||
|
|
@ -895,7 +891,7 @@ def _post_backward_final_callback(
|
|||
This runs at the end of the entire backward pass and should only be called
|
||||
on the root FSDP instance.
|
||||
"""
|
||||
p_assert(
|
||||
_p_assert(
|
||||
state._is_root,
|
||||
"The post-backward callback should only be called on the root FSDP instance",
|
||||
)
|
||||
|
|
@ -952,7 +948,7 @@ def _catch_all_reshard(
|
|||
if handles_to_reshard:
|
||||
_reshard(state, handles_to_reshard, free_unsharded_flat_params)
|
||||
except Exception as e:
|
||||
p_assert(
|
||||
_p_assert(
|
||||
False,
|
||||
f"Got exception in the catch-all reshard for {state}: {str(e)}",
|
||||
raise_assertion_error=False,
|
||||
|
|
@ -969,7 +965,7 @@ def _finalize_params(
|
|||
flat_param = handle.flat_param
|
||||
if flat_param.requires_grad:
|
||||
if hasattr(flat_param, "_post_backward_hook_state"):
|
||||
p_assert(
|
||||
_p_assert(
|
||||
len(flat_param._post_backward_hook_state) == 2,
|
||||
f"Invalid: ``_post_backward_hook_state``: {flat_param._post_backward_hook_state}",
|
||||
)
|
||||
|
|
@ -982,7 +978,7 @@ def _finalize_params(
|
|||
# sharded gradient from the last synchronized iteration
|
||||
continue
|
||||
handle.prepare_gradient_for_optim()
|
||||
p_assert(
|
||||
_p_assert(
|
||||
hasattr(flat_param, "_post_backward_called"),
|
||||
"Expects `_post_backward_called` to be set on the `FlatParameter`",
|
||||
)
|
||||
|
|
@ -1029,7 +1025,7 @@ def _get_handles_to_prefetch(
|
|||
HandleTrainingState.BACKWARD_POST,
|
||||
HandleTrainingState.FORWARD,
|
||||
)
|
||||
p_assert(
|
||||
_p_assert(
|
||||
training_state in valid_training_states,
|
||||
f"Prefetching is only supported in {valid_training_states} but "
|
||||
f"currently in {training_state}",
|
||||
|
|
@ -1067,9 +1063,9 @@ def _get_training_state(
|
|||
handles_key: _HandlesKey,
|
||||
) -> HandleTrainingState:
|
||||
"""Returns the training state of the handles in ``handles_key``."""
|
||||
p_assert(len(handles_key) > 0, "Expects a non-empty handles key")
|
||||
_p_assert(len(handles_key) > 0, "Expects a non-empty handles key")
|
||||
training_states = {handle._training_state for handle in handles_key}
|
||||
p_assert(
|
||||
_p_assert(
|
||||
len(training_states) == 1,
|
||||
f"Expects uniform training state but got {training_states}",
|
||||
)
|
||||
|
|
@ -1233,7 +1229,7 @@ def _register_post_backward_hooks(
|
|||
continue
|
||||
# Get the `AccumulateGrad` object
|
||||
temp_flat_param = flat_param.expand_as(flat_param)
|
||||
p_assert(
|
||||
_p_assert(
|
||||
temp_flat_param.grad_fn is not None,
|
||||
"The `grad_fn` is needed to access the `AccumulateGrad` and "
|
||||
"register the post-backward hook",
|
||||
|
|
@ -1255,7 +1251,7 @@ def _register_post_backward_final_callback(
|
|||
backward pass. This should be called from the root FSDP instance at the
|
||||
beginning of the pre-backward.
|
||||
"""
|
||||
p_assert(
|
||||
_p_assert(
|
||||
state._is_root,
|
||||
"Only the root FSDP instance should register the post-backward callback",
|
||||
)
|
||||
|
|
@ -1309,7 +1305,7 @@ def _get_buffers_and_dtypes_for_computation(
|
|||
is either ``None`` if buffer mixed precision is not enabled or the buffer
|
||||
low precision dtype otherwise.
|
||||
"""
|
||||
p_assert(state._is_root, "Expects the root to cast buffers")
|
||||
_p_assert(state._is_root, "Expects the root to cast buffers")
|
||||
buffers: List[torch.Tensor] = []
|
||||
buffer_dtypes: List[Optional[torch.dtype]] = []
|
||||
if _is_composable(state):
|
||||
|
|
@ -1344,7 +1340,7 @@ def _get_buffer_dtypes(
|
|||
"""
|
||||
buffer_dtypes: List[torch.dtype] = []
|
||||
for buffer_name in buffer_names:
|
||||
p_assert(
|
||||
_p_assert(
|
||||
buffer_name in state._buffer_name_to_orig_dtype,
|
||||
f"{buffer_name} is missing from pre-computed dict on rank "
|
||||
f"{state.rank}, which only has keys "
|
||||
|
|
@ -1364,7 +1360,7 @@ def _cast_buffers_to_dtype_and_device(
|
|||
to ``device``. If an element in ``buffer_dtypes`` is ``None``, then the
|
||||
corresponding buffer is only moved to ``device``.
|
||||
"""
|
||||
p_assert(
|
||||
_p_assert(
|
||||
buffer_dtypes is None or len(buffers) == len(buffer_dtypes),
|
||||
f"Expects `buffers` and `buffer_dtypes` to have the same length if "
|
||||
f"`buffer_dtypes` is specified but got {len(buffers)} and "
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from torch.distributed.fsdp._runtime_utils import (
|
|||
_unshard,
|
||||
_unshard_grads,
|
||||
)
|
||||
from ._utils import p_assert
|
||||
from torch.distributed.utils import _p_assert
|
||||
from .flat_param import FlatParamHandle
|
||||
|
||||
FLAT_PARAM = "_flat_param"
|
||||
|
|
@ -336,7 +336,7 @@ def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None:
|
|||
Deregisters the original parameters; registers the ``FlatParameter``.
|
||||
"""
|
||||
handles = _module_handles(state, module)
|
||||
p_assert(
|
||||
_p_assert(
|
||||
len(handles) <= 1,
|
||||
"Expects <=1 handle per FSDP instance; needs to be refactored "
|
||||
"for >1 handle (e.g. non-recursive wrapping)",
|
||||
|
|
@ -344,7 +344,7 @@ def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None:
|
|||
if not handles:
|
||||
return
|
||||
handle = handles[0]
|
||||
p_assert(
|
||||
_p_assert(
|
||||
handle._use_orig_params,
|
||||
f"Inconsistent `_use_orig_params` -- FSDP: {state._use_orig_params} "
|
||||
f"handle: {handle._use_orig_params}",
|
||||
|
|
|
|||
|
|
@ -1,14 +1,7 @@
|
|||
import dataclasses
|
||||
import traceback
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Callable, cast, Dict, List, Set, Tuple, Union
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from torch.nn.parallel.scatter_gather import ( # type: ignore[attr-defined]
|
||||
_is_namedtuple,
|
||||
)
|
||||
from torch.nn.utils.rnn import PackedSequence
|
||||
from torch.utils._mode_utils import no_dispatch
|
||||
|
||||
|
||||
|
|
@ -22,102 +15,12 @@ def _override_batchnorm_mixed_precision(module):
|
|||
mod._wrap_overrides = {"mixed_precision": None} # type: ignore[assignment]
|
||||
|
||||
|
||||
def _apply_to_tensors(
|
||||
fn: Callable,
|
||||
container: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence],
|
||||
) -> Any:
|
||||
"""Recursively apply to all tensor in different kinds of container types."""
|
||||
|
||||
def apply(
|
||||
x: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence]
|
||||
) -> Any:
|
||||
if torch.is_tensor(x):
|
||||
return fn(x)
|
||||
elif hasattr(x, "__dataclass_fields__"):
|
||||
dc = dataclasses.replace(x)
|
||||
for f in dataclasses.fields(dc):
|
||||
name = f.name
|
||||
setattr(dc, name, apply(getattr(dc, name)))
|
||||
return dc
|
||||
elif isinstance(x, OrderedDict):
|
||||
od = x.__class__()
|
||||
for key, value in x.items():
|
||||
od[key] = apply(value)
|
||||
return od
|
||||
elif isinstance(x, PackedSequence):
|
||||
apply(x.data)
|
||||
return x
|
||||
elif isinstance(x, dict):
|
||||
return {key: apply(value) for key, value in x.items()}
|
||||
elif _is_namedtuple(x):
|
||||
res = (apply(el) for el in x)
|
||||
return type(x)(*res)
|
||||
elif isinstance(x, (list, tuple, set)):
|
||||
return type(x)(apply(el) for el in x)
|
||||
else:
|
||||
return x
|
||||
|
||||
return apply(container)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> bool:
|
||||
"""
|
||||
Allocate storage for ``tensor`` with the given size.
|
||||
|
||||
Returns:
|
||||
bool: ``True`` if this method allocated storage and ``False`` if the
|
||||
storage was already allocated.
|
||||
"""
|
||||
already_allocated = tensor._typed_storage()._size() == size.numel()
|
||||
if not already_allocated:
|
||||
tensor_storage_size = tensor._typed_storage()._size()
|
||||
p_assert(
|
||||
tensor_storage_size == 0,
|
||||
f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}",
|
||||
)
|
||||
tensor._typed_storage()._resize_(size.numel())
|
||||
return not already_allocated
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def _free_storage(tensor: torch.Tensor) -> bool:
|
||||
"""
|
||||
Frees the underlying storage of ``tensor``.
|
||||
|
||||
Returns:
|
||||
bool: ``True`` if the method freed the storage and ``False`` if the
|
||||
storage was already freed.
|
||||
"""
|
||||
already_freed = tensor._typed_storage()._size() == 0
|
||||
if not already_freed:
|
||||
p_assert(
|
||||
tensor.storage_offset() == 0,
|
||||
"Freeing a tensor's storage is unsafe when it is not the sole occupant\n"
|
||||
f"storage offset: {tensor.storage_offset()}\n"
|
||||
f"storage size: {tensor._typed_storage()._size()}\n"
|
||||
f"tensor shape: {tensor.shape}",
|
||||
)
|
||||
tensor._typed_storage()._resize_(0)
|
||||
return not already_freed
|
||||
|
||||
|
||||
def _same_storage(x: torch.Tensor, y: torch.Tensor) -> bool:
|
||||
"""Returns if ``x`` and ``y`` share the same storage."""
|
||||
# NOTE: CPU and GPU tensors are ensured to have different data pointers.
|
||||
return x._typed_storage()._data_ptr() == y._typed_storage()._data_ptr()
|
||||
|
||||
|
||||
def p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:
|
||||
"""This is used as an alternate to ``assert`` when in the backward context
|
||||
to print the error message ``s`` since otherwise, it is swallowed."""
|
||||
if not cond:
|
||||
print(s)
|
||||
traceback.print_stack()
|
||||
if raise_assertion_error:
|
||||
raise AssertionError(s)
|
||||
|
||||
|
||||
def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.cuda.Stream) -> None:
|
||||
with no_dispatch():
|
||||
tensor.record_stream(cast(torch._C.Stream, stream))
|
||||
|
|
|
|||
|
|
@ -27,15 +27,10 @@ from torch.distributed.fsdp._common_utils import (
|
|||
_set_fsdp_flattened,
|
||||
HandleTrainingState,
|
||||
)
|
||||
from torch.distributed.utils import _alloc_storage, _free_storage, _p_assert
|
||||
|
||||
from ._fsdp_extensions import _ext_post_unflatten_transform, _ext_pre_flatten_transform
|
||||
from ._utils import (
|
||||
_alloc_storage,
|
||||
_free_storage,
|
||||
_no_dispatch_record_stream,
|
||||
_same_storage,
|
||||
p_assert,
|
||||
)
|
||||
from ._utils import _no_dispatch_record_stream, _same_storage
|
||||
|
||||
__all__ = [
|
||||
"FlatParameter",
|
||||
|
|
@ -558,7 +553,7 @@ class FlatParamHandle:
|
|||
if not self.uses_sharded_strategy:
|
||||
self._init_shard_metadata(0, 0, flat_param.numel() - 1)
|
||||
else:
|
||||
p_assert(
|
||||
_p_assert(
|
||||
flat_param.storage_offset() == 0,
|
||||
"The `FlatParameter` is not the sole occupant of its storage",
|
||||
)
|
||||
|
|
@ -600,8 +595,8 @@ class FlatParamHandle:
|
|||
"""
|
||||
self.flat_param._sharded_size = self.flat_param.size() # type: ignore[attr-defined]
|
||||
sharded_flat_param_numel = self.flat_param.numel() # includes `numel_padded`
|
||||
p_assert(start >= 0 and start <= end, f"start: {start} end: {end}")
|
||||
p_assert(
|
||||
_p_assert(start >= 0 and start <= end, f"start: {start} end: {end}")
|
||||
_p_assert(
|
||||
numel_padded <= sharded_flat_param_numel,
|
||||
f"numel_padded: {numel_padded} "
|
||||
f"sharded_flat_param_numel: {sharded_flat_param_numel}",
|
||||
|
|
@ -792,7 +787,7 @@ class FlatParamHandle:
|
|||
self._orig_param_dtype = flat_param.dtype
|
||||
cpu_device = torch.device("cpu")
|
||||
if self._offload_params:
|
||||
p_assert(
|
||||
_p_assert(
|
||||
flat_param.device == cpu_device,
|
||||
f"Expects the `FlatParameter` to be on CPU when parameter CPU "
|
||||
f"offloading is enabled, not {flat_param.device}",
|
||||
|
|
@ -957,7 +952,7 @@ class FlatParamHandle:
|
|||
# tensor as the all-gather destination to preserve the invariant
|
||||
# that `_full_param_padded` is in the low precision
|
||||
unsharded_flat_param = flat_param._full_prec_full_param_padded # type: ignore[attr-defined]
|
||||
p_assert(
|
||||
_p_assert(
|
||||
unsharded_flat_param.dtype != self._fwd_bwd_param_dtype,
|
||||
f"Expects full precision but got {self._fwd_bwd_param_dtype}",
|
||||
)
|
||||
|
|
@ -974,13 +969,13 @@ class FlatParamHandle:
|
|||
``padded_unsharded_flat_param``, and switches to using the all-gathered
|
||||
tensor.
|
||||
"""
|
||||
p_assert(
|
||||
_p_assert(
|
||||
hasattr(self, "process_group") and hasattr(self, "world_size"),
|
||||
"Expects a process group and world size to have been set via `shard()`",
|
||||
)
|
||||
sharded_flat_param = self.flat_param.data
|
||||
expected_numel = sharded_flat_param.numel() * self.world_size
|
||||
p_assert(
|
||||
_p_assert(
|
||||
padded_unsharded_flat_param.numel() == expected_numel,
|
||||
f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
|
||||
)
|
||||
|
|
@ -1111,7 +1106,7 @@ class FlatParamHandle:
|
|||
clearing any existing sharded gradient in ``.grad`` to enable computing
|
||||
a new unsharded gradient.
|
||||
"""
|
||||
p_assert(
|
||||
_p_assert(
|
||||
self._training_state
|
||||
in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.IDLE),
|
||||
"Expects to be in `BACKWARD_PRE` or `IDLE` (if prefetching)",
|
||||
|
|
@ -1123,7 +1118,7 @@ class FlatParamHandle:
|
|||
):
|
||||
self._check_on_compute_device(self.flat_param)
|
||||
grad_offloaded = flat_param.grad.device != self.device
|
||||
p_assert(
|
||||
_p_assert(
|
||||
not grad_offloaded or self._offload_params,
|
||||
f"Expects the sharded gradient to be on {self.device} "
|
||||
f"but got {flat_param.grad.device}",
|
||||
|
|
@ -1142,7 +1137,7 @@ class FlatParamHandle:
|
|||
flat_param._saved_grad_shard = flat_param.grad.data # type: ignore[attr-defined]
|
||||
sharded_grad = flat_param._saved_grad_shard # type: ignore[attr-defined]
|
||||
else:
|
||||
p_assert(
|
||||
_p_assert(
|
||||
hasattr(flat_param, "_cpu_grad"),
|
||||
"`_cpu_grad` should be defined if the gradient is on CPU",
|
||||
)
|
||||
|
|
@ -1162,7 +1157,7 @@ class FlatParamHandle:
|
|||
sharded_grad.data = sharded_grad.to(local_shard_dtype)
|
||||
else:
|
||||
padded_unsharded_size = flat_param._padded_unsharded_size # type: ignore[attr-defined]
|
||||
p_assert(
|
||||
_p_assert(
|
||||
flat_param.grad.size() == padded_unsharded_size,
|
||||
"Expects `.grad` to be the unsharded gradient in "
|
||||
f"`no_sync()` with size {padded_unsharded_size} "
|
||||
|
|
@ -1203,7 +1198,7 @@ class FlatParamHandle:
|
|||
flat_param.grad = flat_param._saved_grad_shard # type: ignore[attr-defined]
|
||||
cast_grad_to_param_dtype_if_needed(flat_param)
|
||||
else:
|
||||
p_assert(
|
||||
_p_assert(
|
||||
not self.uses_sharded_strategy
|
||||
or not flat_param._post_backward_called, # type: ignore[attr-defined]
|
||||
"All sharded parameters that received a gradient in the "
|
||||
|
|
@ -1229,7 +1224,7 @@ class FlatParamHandle:
|
|||
Postcondition: Same as the precondition.
|
||||
"""
|
||||
self._check_sharded_strategy()
|
||||
p_assert(
|
||||
_p_assert(
|
||||
self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
|
||||
f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
|
||||
)
|
||||
|
|
@ -1242,7 +1237,7 @@ class FlatParamHandle:
|
|||
padded_storage_ptr = (
|
||||
self._get_padded_unsharded_flat_param()._typed_storage()._data_ptr()
|
||||
)
|
||||
p_assert(
|
||||
_p_assert(
|
||||
unpadded_storage_ptr == padded_storage_ptr,
|
||||
"Expects the unpadded parameter to be a view into the padded parameter",
|
||||
)
|
||||
|
|
@ -1251,7 +1246,7 @@ class FlatParamHandle:
|
|||
try:
|
||||
yield
|
||||
finally:
|
||||
p_assert(
|
||||
_p_assert(
|
||||
self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
|
||||
f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
|
||||
)
|
||||
|
|
@ -1314,7 +1309,7 @@ class FlatParamHandle:
|
|||
flat_param = self.flat_param
|
||||
if self._offload_params:
|
||||
device = flat_param._local_shard.device # type: ignore[attr-defined]
|
||||
p_assert(
|
||||
_p_assert(
|
||||
device == torch.device("cpu"),
|
||||
f"Expects the local shard to be on CPU but got {device}",
|
||||
)
|
||||
|
|
@ -1357,7 +1352,7 @@ class FlatParamHandle:
|
|||
"""
|
||||
if tensor is None:
|
||||
tensor = flat_param
|
||||
p_assert(
|
||||
_p_assert(
|
||||
tensor.numel() == flat_param._unpadded_unsharded_size.numel(),
|
||||
f"Expects {flat_param._unpadded_unsharded_size.numel()} numel but got "
|
||||
f"{tensor.numel()} numel",
|
||||
|
|
@ -1416,7 +1411,7 @@ class FlatParamHandle:
|
|||
# hook fires (e.g. for reentrant AC)
|
||||
assert self.flat_param._tensors is not None # mypy
|
||||
tensor = self.flat_param._tensors[i]
|
||||
p_assert(
|
||||
_p_assert(
|
||||
tensor is not None,
|
||||
"Expects `Tensor` to have been saved in forward",
|
||||
)
|
||||
|
|
@ -1439,14 +1434,14 @@ class FlatParamHandle:
|
|||
) in enumerate(self.flat_param._shared_param_infos):
|
||||
if hasattr(module, param_name):
|
||||
delattr(module, param_name)
|
||||
p_assert(
|
||||
_p_assert(
|
||||
hasattr(prim_module, prim_param_name),
|
||||
f"Module {prim_module_name} is missing parameter {prim_param_name}",
|
||||
)
|
||||
prim_param: Union[Tensor, nn.Parameter] = getattr(
|
||||
prim_module, prim_param_name
|
||||
)
|
||||
p_assert(
|
||||
_p_assert(
|
||||
not as_params or isinstance(prim_param, nn.Parameter),
|
||||
f"as_params={as_params} type(prim_param)={type(prim_param)}",
|
||||
)
|
||||
|
|
@ -1485,7 +1480,7 @@ class FlatParamHandle:
|
|||
for i, (view, (param_name, module, _)) in enumerate(
|
||||
zip(views, self.flat_param._param_infos)
|
||||
):
|
||||
p_assert(
|
||||
_p_assert(
|
||||
hasattr(module, param_name),
|
||||
f"{self.flat_param._fqns[i]} is missing",
|
||||
)
|
||||
|
|
@ -1511,7 +1506,7 @@ class FlatParamHandle:
|
|||
prim_module,
|
||||
_,
|
||||
) in enumerate(self.flat_param._shared_param_infos):
|
||||
p_assert(
|
||||
_p_assert(
|
||||
hasattr(module, param_name),
|
||||
f"{module_name + '.' + param_name if module_name else param_name} is missing",
|
||||
) # did not save FQN info in `_shared_param_infos`
|
||||
|
|
@ -1793,7 +1788,7 @@ class FlatParamHandle:
|
|||
RuntimeError: If the ``src_tensor`` does not have the expected
|
||||
shape.
|
||||
"""
|
||||
p_assert(
|
||||
_p_assert(
|
||||
len(expected_shape) == 1,
|
||||
f"Expects a 1D expected shape but got {expected_shape}",
|
||||
)
|
||||
|
|
@ -1935,7 +1930,7 @@ class FlatParamHandle:
|
|||
else:
|
||||
# If in the forward, then there may be an accumulated gradient,
|
||||
# which will be in `.grad`
|
||||
p_assert(
|
||||
_p_assert(
|
||||
flat_param.grad is None
|
||||
or not self.uses_sharded_strategy
|
||||
or self._training_state == HandleTrainingState.FORWARD,
|
||||
|
|
@ -1954,7 +1949,7 @@ class FlatParamHandle:
|
|||
"""
|
||||
if not self._use_orig_params:
|
||||
return
|
||||
p_assert(
|
||||
_p_assert(
|
||||
self._training_state == HandleTrainingState.BACKWARD_POST,
|
||||
"Expects to only be called in the post-backward after gradient computation",
|
||||
)
|
||||
|
|
@ -1971,16 +1966,16 @@ class FlatParamHandle:
|
|||
# CHECKS & INVARIANTS #
|
||||
#######################
|
||||
def _check_sharded_strategy(self):
|
||||
p_assert(self.uses_sharded_strategy, "Expects sharded strategy")
|
||||
_p_assert(self.uses_sharded_strategy, "Expects sharded strategy")
|
||||
|
||||
def _check_on_compute_device(self, tensor: Tensor):
|
||||
p_assert(
|
||||
_p_assert(
|
||||
tensor.device == self.device,
|
||||
f"Expects tensor to be on the compute device {self.device}",
|
||||
)
|
||||
|
||||
def _check_on_cpu(self, tensor: Tensor):
|
||||
p_assert(
|
||||
_p_assert(
|
||||
tensor.device == torch.device("cpu"),
|
||||
f"Expects tensor to be on CPU but got {tensor.device}",
|
||||
)
|
||||
|
|
@ -1988,7 +1983,7 @@ class FlatParamHandle:
|
|||
@staticmethod
|
||||
def _check_storage_freed(tensor: Tensor):
|
||||
storage_size: int = tensor._typed_storage()._size()
|
||||
p_assert(
|
||||
_p_assert(
|
||||
storage_size == 0,
|
||||
f"Expects storage to be freed but got storage with size {storage_size}",
|
||||
)
|
||||
|
|
@ -1996,37 +1991,37 @@ class FlatParamHandle:
|
|||
@staticmethod
|
||||
def _check_storage_allocated(tensor: Tensor):
|
||||
storage_size: int = tensor._typed_storage()._size()
|
||||
p_assert(storage_size > 0, "Expects storage to be allocated")
|
||||
_p_assert(storage_size > 0, "Expects storage to be allocated")
|
||||
|
||||
def _check_low_precision_shard(self):
|
||||
p_assert(
|
||||
_p_assert(
|
||||
self._uses_param_mixed_precision,
|
||||
"Not using low precision for parameters",
|
||||
)
|
||||
p_assert(
|
||||
_p_assert(
|
||||
getattr(self.flat_param, "_mp_shard", None) is not None,
|
||||
"Expects `_mp_shard` to exist",
|
||||
)
|
||||
device = self.flat_param._mp_shard.device # type: ignore[attr-defined]
|
||||
p_assert(
|
||||
_p_assert(
|
||||
device == self.device,
|
||||
f"Expects the low precision shard to be on {self.device} but got {device}",
|
||||
)
|
||||
|
||||
def _check_unsharded(self, tensor: Tensor):
|
||||
msg_prefix = "Expects tensor to be unsharded "
|
||||
p_assert(tensor is not None, msg_prefix + "but got `None`")
|
||||
_p_assert(tensor is not None, msg_prefix + "but got `None`")
|
||||
unsharded_size = self.flat_param._unpadded_unsharded_size
|
||||
p_assert(
|
||||
_p_assert(
|
||||
tensor.size() == unsharded_size,
|
||||
msg_prefix + f"with size {unsharded_size} but got {tensor.size()}",
|
||||
)
|
||||
|
||||
def _check_sharded(self, tensor: Tensor):
|
||||
msg_prefix = "Expects tensor to be sharded "
|
||||
p_assert(tensor is not None, msg_prefix + "but got `None`")
|
||||
_p_assert(tensor is not None, msg_prefix + "but got `None`")
|
||||
sharded_size = self.flat_param._sharded_size # type: ignore[attr-defined]
|
||||
p_assert(
|
||||
_p_assert(
|
||||
tensor.size() == sharded_size,
|
||||
msg_prefix + f"with size {sharded_size} but got {tensor.size()}",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -77,6 +77,7 @@ from torch.distributed.fsdp.api import (
|
|||
StateDictSettings,
|
||||
StateDictType,
|
||||
)
|
||||
from torch.distributed.utils import _p_assert
|
||||
|
||||
from ._optim_utils import (
|
||||
_broadcast_pos_dim_tensor_states,
|
||||
|
|
@ -98,7 +99,6 @@ from ._unshard_param_utils import (
|
|||
_unshard_params,
|
||||
_unshard_params_recurse,
|
||||
)
|
||||
from ._utils import p_assert
|
||||
from .flat_param import FlatParameter
|
||||
from .wrap import _FSDPPolicy
|
||||
|
||||
|
|
@ -740,7 +740,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
|
|||
self, self._handles, unshard_fn, self._fsdp_wrapped_module, args, kwargs
|
||||
)
|
||||
for handle in self._handles:
|
||||
p_assert(
|
||||
_p_assert(
|
||||
handle.flat_param.device == self.compute_device,
|
||||
"Expected `FlatParameter` to be on the compute device "
|
||||
f"{self.compute_device} but got {handle.flat_param.device}",
|
||||
|
|
@ -830,7 +830,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
|
|||
this refreshes the sharded views before exiting. This method shouuld
|
||||
only be called when using the original parameters.
|
||||
"""
|
||||
p_assert(
|
||||
_p_assert(
|
||||
self._use_orig_params,
|
||||
"`_deregister_orig_params_ctx()` should only be called when "
|
||||
"`_use_orig_params=True`",
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any, Dict, List, Tuple, Callable, Union, Set, OrderedDict
|
||||
import dataclasses
|
||||
import traceback
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -94,6 +96,92 @@ def _recursive_to(inputs, target_gpu, use_side_stream_for_tensor_copies):
|
|||
to_map = None # type: ignore[assignment]
|
||||
return res
|
||||
|
||||
def _p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:
|
||||
"""This is used as an alternate to ``assert`` when in the backward context
|
||||
to print the error message ``s`` since otherwise, it is swallowed."""
|
||||
if not cond:
|
||||
print(s)
|
||||
traceback.print_stack()
|
||||
if raise_assertion_error:
|
||||
raise AssertionError(s)
|
||||
|
||||
def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> bool:
|
||||
"""
|
||||
Allocate storage for ``tensor`` with the given size.
|
||||
|
||||
Returns:
|
||||
bool: ``True`` if this method allocated storage and ``False`` if the
|
||||
storage was already allocated.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
already_allocated = tensor._typed_storage()._size() == size.numel()
|
||||
if not already_allocated:
|
||||
tensor_storage_size = tensor._typed_storage()._size()
|
||||
_p_assert(
|
||||
tensor_storage_size == 0,
|
||||
f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}",
|
||||
)
|
||||
tensor._typed_storage()._resize_(size.numel())
|
||||
return not already_allocated
|
||||
|
||||
|
||||
def _free_storage(tensor: torch.Tensor) -> bool:
|
||||
"""
|
||||
Frees the underlying storage of ``tensor``.
|
||||
|
||||
Returns:
|
||||
bool: ``True`` if the method freed the storage and ``False`` if the
|
||||
storage was already freed.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
already_freed = tensor._typed_storage()._size() == 0
|
||||
if not already_freed:
|
||||
_p_assert(
|
||||
tensor.storage_offset() == 0,
|
||||
"Freeing a tensor's storage is unsafe when it is not the sole occupant\n"
|
||||
f"storage offset: {tensor.storage_offset()}\n"
|
||||
f"storage size: {tensor._typed_storage()._size()}\n"
|
||||
f"tensor shape: {tensor.shape}",
|
||||
)
|
||||
tensor._typed_storage()._resize_(0)
|
||||
return not already_freed
|
||||
|
||||
def _apply_to_tensors(
|
||||
fn: Callable,
|
||||
container: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence],
|
||||
) -> Any:
|
||||
"""Recursively apply to all tensor in different kinds of container types."""
|
||||
|
||||
def apply(
|
||||
x: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence]
|
||||
) -> Any:
|
||||
if torch.is_tensor(x):
|
||||
return fn(x)
|
||||
elif hasattr(x, "__dataclass_fields__"):
|
||||
dc = dataclasses.replace(x)
|
||||
for f in dataclasses.fields(dc):
|
||||
name = f.name
|
||||
setattr(dc, name, apply(getattr(dc, name)))
|
||||
return dc
|
||||
elif isinstance(x, OrderedDict):
|
||||
od = x.__class__()
|
||||
for key, value in x.items():
|
||||
od[key] = apply(value)
|
||||
return od
|
||||
elif isinstance(x, PackedSequence):
|
||||
apply(x.data)
|
||||
return x
|
||||
elif isinstance(x, dict):
|
||||
return {key: apply(value) for key, value in x.items()}
|
||||
elif _is_namedtuple(x):
|
||||
res = (apply(el) for el in x)
|
||||
return type(x)(*res)
|
||||
elif isinstance(x, (list, tuple, set)):
|
||||
return type(x)(apply(el) for el in x)
|
||||
else:
|
||||
return x
|
||||
|
||||
return apply(container)
|
||||
|
||||
def _to_kwargs(inputs, kwargs, device_id, use_side_stream_for_tensor_copies):
|
||||
inputs = (
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user