[Resubmit] helpers to torch.dist.utils (#95025)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95025
Approved by: https://github.com/fegin
Authored by Rohan Varma on 2023-02-17 02:55:24 +00:00; committed by PyTorch MergeBot
parent 2aa806608b
commit c43e88665a
7 changed files with 159 additions and 178 deletions
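
In short: this change promotes four FSDP-internal helpers (`_apply_to_tensors`, `_alloc_storage`, `_free_storage`, and `p_assert`, renamed to `_p_assert`) from `torch.distributed.fsdp._utils` into the shared `torch.distributed.utils` module, and updates the FSDP call sites to import them from the new location. A hedged sketch of the new import surface (the grouping into one statement is illustrative and assumes a PyTorch build that already contains this change):

    # These helpers previously lived in torch.distributed.fsdp._utils;
    # `p_assert` is renamed to `_p_assert` as part of the move.
    from torch.distributed.utils import (
        _alloc_storage,
        _apply_to_tensors,
        _free_storage,
        _p_assert,
    )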

View File

@@ -11,10 +11,9 @@ from typing import List
 import torch
 import torch.nn as nn
 from torch import distributed as dist
-from torch.distributed.fsdp._utils import _apply_to_tensors
 from torch.distributed.fsdp._wrap_utils import _get_fully_sharded_module_to_states
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
-from torch.distributed.utils import _replace_by_prefix
+from torch.distributed.utils import _apply_to_tensors, _replace_by_prefix
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,

View File

@@ -26,11 +26,7 @@ from torch.distributed.fsdp._common_utils import (
     TrainingState,
 )
 from torch.distributed.fsdp._init_utils import HYBRID_SHARDING_STRATEGIES
-from torch.distributed.fsdp._utils import (
-    _apply_to_tensors,
-    _no_dispatch_record_stream,
-    p_assert,
-)
+from torch.distributed.fsdp._utils import _no_dispatch_record_stream
 from torch.distributed.fsdp.api import BackwardPrefetch
 from torch.distributed.fsdp.flat_param import (
     _HandlesKey,
@@ -39,7 +35,7 @@ from torch.distributed.fsdp.flat_param import (
     HandleShardingStrategy,
     HandleTrainingState,
 )
-from torch.distributed.utils import _to_kwargs
+from torch.distributed.utils import _apply_to_tensors, _p_assert, _to_kwargs
 
 RESHARD_AFTER_FORWARD_STRATEGIES = {
     HandleShardingStrategy.FULL_SHARD,
@@ -221,7 +217,7 @@ def _share_state_and_init_handle_attrs(
         attr_name_to_values[attr_name] = set()
     for fsdp_state in traversal_utils._get_fsdp_states(root_module):
         for attr_name in HOMOGENEOUS_ATTR_NAMES:
-            p_assert(
+            _p_assert(
                 hasattr(fsdp_state, attr_name),
                 f"FSDP state missing attribute {attr_name}",
             )
@@ -246,7 +242,7 @@ def _share_state_and_init_handle_attrs(
         # Relax the assert for non-root FSDP instances in case the nested
         # initialized module is wrapped again in FSDP later (e.g. after
         # training to run inference)
-        p_assert(
+        _p_assert(
             fsdp_state._is_root is None or not fsdp_state._is_root,
             "Non-root FSDP instance's `_is_root` should not have been "
             "set yet or should have been set to `False`",
@@ -344,7 +340,7 @@ def _reshard(
     """
     if not handles:
         return
-    p_assert(
+    _p_assert(
         len(handles) == len(free_unsharded_flat_params),
         "Expects both lists to have equal length but got "
         f"{len(handles)} and {len(free_unsharded_flat_params)}",
@@ -518,7 +514,7 @@ def _root_pre_forward(
     may not be the root. If not, then this method does not do anything.
    """
    _lazy_init(state, module)
-    p_assert(state._is_root is not None, "Expects a root FSDP to have been set")
+    _p_assert(state._is_root is not None, "Expects a root FSDP to have been set")
    if not state._is_root:
        return args, kwargs
    if state.forward_prefetch:
@@ -675,7 +671,7 @@ def _post_backward_hook(
         # the same `FlatParameter`, the post-backward hook may run multiple
         # times in one backward, in which case we permit the state to already
         # be in `BACKWARD_POST`.
-        p_assert(
+        _p_assert(
             handle._training_state
             in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.BACKWARD_POST),
             f"Expects `BACKWARD_PRE` or `BACKWARD_POST` state but got {handle._training_state}",
@@ -855,8 +851,8 @@ def _check_comm_hook(
     comm_hook: Any,
     comm_hook_state: Any,
 ) -> None:
-    p_assert(comm_hook is not None, "Communication hook should not be `None`")
-    p_assert(
+    _p_assert(comm_hook is not None, "Communication hook should not be `None`")
+    _p_assert(
         comm_hook_state is not None, "Communication hook state should not be `None`"
     )
 
@@ -865,13 +861,13 @@ def _check_grad_to_accumulate(
     new_sharded_grad: torch.Tensor,
     accumulated_grad: torch.Tensor,
 ) -> None:
-    p_assert(
+    _p_assert(
         accumulated_grad.shape == new_sharded_grad.shape,
         "Shape mismatch when accumulating gradients: "
         f"existing gradient shape={accumulated_grad.shape} "
         f"new gradient shape={new_sharded_grad.shape}",
     )
-    p_assert(
+    _p_assert(
         accumulated_grad.device == new_sharded_grad.device,
         "Device mismatch when accumulating gradients: "
         f"existing gradient device={accumulated_grad.device} "
@@ -895,7 +891,7 @@ def _post_backward_final_callback(
     This runs at the end of the entire backward pass and should only be called
     on the root FSDP instance.
     """
-    p_assert(
+    _p_assert(
         state._is_root,
         "The post-backward callback should only be called on the root FSDP instance",
     )
@@ -952,7 +948,7 @@ def _catch_all_reshard(
         if handles_to_reshard:
             _reshard(state, handles_to_reshard, free_unsharded_flat_params)
     except Exception as e:
-        p_assert(
+        _p_assert(
             False,
             f"Got exception in the catch-all reshard for {state}: {str(e)}",
             raise_assertion_error=False,
@@ -969,7 +965,7 @@ def _finalize_params(
         flat_param = handle.flat_param
         if flat_param.requires_grad:
             if hasattr(flat_param, "_post_backward_hook_state"):
-                p_assert(
+                _p_assert(
                     len(flat_param._post_backward_hook_state) == 2,
                     f"Invalid: ``_post_backward_hook_state``: {flat_param._post_backward_hook_state}",
                 )
@@ -982,7 +978,7 @@ def _finalize_params(
                 # sharded gradient from the last synchronized iteration
                 continue
             handle.prepare_gradient_for_optim()
-            p_assert(
+            _p_assert(
                 hasattr(flat_param, "_post_backward_called"),
                 "Expects `_post_backward_called` to be set on the `FlatParameter`",
             )
@@ -1029,7 +1025,7 @@ def _get_handles_to_prefetch(
         HandleTrainingState.BACKWARD_POST,
         HandleTrainingState.FORWARD,
     )
-    p_assert(
+    _p_assert(
         training_state in valid_training_states,
         f"Prefetching is only supported in {valid_training_states} but "
         f"currently in {training_state}",
@@ -1067,9 +1063,9 @@ def _get_training_state(
     handles_key: _HandlesKey,
 ) -> HandleTrainingState:
     """Returns the training state of the handles in ``handles_key``."""
-    p_assert(len(handles_key) > 0, "Expects a non-empty handles key")
+    _p_assert(len(handles_key) > 0, "Expects a non-empty handles key")
     training_states = {handle._training_state for handle in handles_key}
-    p_assert(
+    _p_assert(
         len(training_states) == 1,
         f"Expects uniform training state but got {training_states}",
     )
@@ -1233,7 +1229,7 @@ def _register_post_backward_hooks(
             continue
         # Get the `AccumulateGrad` object
         temp_flat_param = flat_param.expand_as(flat_param)
-        p_assert(
+        _p_assert(
             temp_flat_param.grad_fn is not None,
             "The `grad_fn` is needed to access the `AccumulateGrad` and "
             "register the post-backward hook",
@@ -1255,7 +1251,7 @@ def _register_post_backward_final_callback(
     backward pass. This should be called from the root FSDP instance at the
     beginning of the pre-backward.
     """
-    p_assert(
+    _p_assert(
         state._is_root,
         "Only the root FSDP instance should register the post-backward callback",
     )
@@ -1309,7 +1305,7 @@ def _get_buffers_and_dtypes_for_computation(
     is either ``None`` if buffer mixed precision is not enabled or the buffer
     low precision dtype otherwise.
     """
-    p_assert(state._is_root, "Expects the root to cast buffers")
+    _p_assert(state._is_root, "Expects the root to cast buffers")
     buffers: List[torch.Tensor] = []
     buffer_dtypes: List[Optional[torch.dtype]] = []
     if _is_composable(state):
@@ -1344,7 +1340,7 @@ def _get_buffer_dtypes(
     """
     buffer_dtypes: List[torch.dtype] = []
     for buffer_name in buffer_names:
-        p_assert(
+        _p_assert(
             buffer_name in state._buffer_name_to_orig_dtype,
             f"{buffer_name} is missing from pre-computed dict on rank "
             f"{state.rank}, which only has keys "
@@ -1364,7 +1360,7 @@ def _cast_buffers_to_dtype_and_device(
     to ``device``. If an element in ``buffer_dtypes`` is ``None``, then the
     corresponding buffer is only moved to ``device``.
     """
-    p_assert(
+    _p_assert(
         buffer_dtypes is None or len(buffers) == len(buffer_dtypes),
         f"Expects `buffers` and `buffer_dtypes` to have the same length if "
         f"`buffer_dtypes` is specified but got {len(buffers)} and "

View File

@@ -21,7 +21,7 @@ from torch.distributed.fsdp._runtime_utils import (
     _unshard,
     _unshard_grads,
 )
-from ._utils import p_assert
+from torch.distributed.utils import _p_assert
 from .flat_param import FlatParamHandle
 
 FLAT_PARAM = "_flat_param"
@@ -336,7 +336,7 @@ def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None:
     Deregisters the original parameters; registers the ``FlatParameter``.
     """
     handles = _module_handles(state, module)
-    p_assert(
+    _p_assert(
         len(handles) <= 1,
         "Expects <=1 handle per FSDP instance; needs to be refactored "
         "for >1 handle (e.g. non-recursive wrapping)",
@@ -344,7 +344,7 @@ def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None:
     if not handles:
         return
     handle = handles[0]
-    p_assert(
+    _p_assert(
         handle._use_orig_params,
         f"Inconsistent `_use_orig_params` -- FSDP: {state._use_orig_params} "
         f"handle: {handle._use_orig_params}",

View File

@@ -1,14 +1,7 @@
-import dataclasses
-import traceback
-from collections import OrderedDict
-from typing import Any, Callable, cast, Dict, List, Set, Tuple, Union
+from typing import cast
 
 import torch
 from torch.nn.modules.batchnorm import _BatchNorm
-from torch.nn.parallel.scatter_gather import (  # type: ignore[attr-defined]
-    _is_namedtuple,
-)
-from torch.nn.utils.rnn import PackedSequence
 from torch.utils._mode_utils import no_dispatch
 
 
@@ -22,102 +15,12 @@ def _override_batchnorm_mixed_precision(module):
         mod._wrap_overrides = {"mixed_precision": None}  # type: ignore[assignment]
 
 
-def _apply_to_tensors(
-    fn: Callable,
-    container: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence],
-) -> Any:
-    """Recursively apply to all tensor in different kinds of container types."""
-
-    def apply(
-        x: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence]
-    ) -> Any:
-        if torch.is_tensor(x):
-            return fn(x)
-        elif hasattr(x, "__dataclass_fields__"):
-            dc = dataclasses.replace(x)
-            for f in dataclasses.fields(dc):
-                name = f.name
-                setattr(dc, name, apply(getattr(dc, name)))
-            return dc
-        elif isinstance(x, OrderedDict):
-            od = x.__class__()
-            for key, value in x.items():
-                od[key] = apply(value)
-            return od
-        elif isinstance(x, PackedSequence):
-            apply(x.data)
-            return x
-        elif isinstance(x, dict):
-            return {key: apply(value) for key, value in x.items()}
-        elif _is_namedtuple(x):
-            res = (apply(el) for el in x)
-            return type(x)(*res)
-        elif isinstance(x, (list, tuple, set)):
-            return type(x)(apply(el) for el in x)
-        else:
-            return x
-
-    return apply(container)
-
-
-@torch.no_grad()
-def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> bool:
-    """
-    Allocate storage for ``tensor`` with the given size.
-
-    Returns:
-        bool: ``True`` if this method allocated storage and ``False`` if the
-        storage was already allocated.
-    """
-    already_allocated = tensor._typed_storage()._size() == size.numel()
-    if not already_allocated:
-        tensor_storage_size = tensor._typed_storage()._size()
-        p_assert(
-            tensor_storage_size == 0,
-            f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}",
-        )
-        tensor._typed_storage()._resize_(size.numel())
-    return not already_allocated
-
-
-@torch.no_grad()
-def _free_storage(tensor: torch.Tensor) -> bool:
-    """
-    Frees the underlying storage of ``tensor``.
-
-    Returns:
-        bool: ``True`` if the method freed the storage and ``False`` if the
-        storage was already freed.
-    """
-    already_freed = tensor._typed_storage()._size() == 0
-    if not already_freed:
-        p_assert(
-            tensor.storage_offset() == 0,
-            "Freeing a tensor's storage is unsafe when it is not the sole occupant\n"
-            f"storage offset: {tensor.storage_offset()}\n"
-            f"storage size: {tensor._typed_storage()._size()}\n"
-            f"tensor shape: {tensor.shape}",
-        )
-        tensor._typed_storage()._resize_(0)
-    return not already_freed
-
-
 def _same_storage(x: torch.Tensor, y: torch.Tensor) -> bool:
     """Returns if ``x`` and ``y`` share the same storage."""
     # NOTE: CPU and GPU tensors are ensured to have different data pointers.
     return x._typed_storage()._data_ptr() == y._typed_storage()._data_ptr()
 
 
-def p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:
-    """This is used as an alternate to ``assert`` when in the backward context
-    to print the error message ``s`` since otherwise, it is swallowed."""
-    if not cond:
-        print(s)
-        traceback.print_stack()
-        if raise_assertion_error:
-            raise AssertionError(s)
-
-
 def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.cuda.Stream) -> None:
     with no_dispatch():
         tensor.record_stream(cast(torch._C.Stream, stream))
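
The helpers removed above reappear verbatim in `torch.distributed.utils` later in this diff; only `_same_storage` and `_no_dispatch_record_stream` stay behind in `_utils`. For reference, a hedged usage sketch of `_apply_to_tensors` imported from its new home (assumes a build that contains this change and has `torch.distributed` available):

    import torch
    from torch.distributed.utils import _apply_to_tensors

    # Applies `fn` to every tensor nested inside dicts, lists, tuples, sets,
    # OrderedDicts, dataclasses, and PackedSequence objects; non-tensor leaves
    # pass through unchanged.
    batch = {"inputs": torch.ones(2), "extras": [torch.zeros(3), "a label"]}
    half = _apply_to_tensors(lambda t: t.half(), batch)
    # half["inputs"].dtype == torch.float16, and the string is returned untouched.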

View File

@@ -27,15 +27,10 @@ from torch.distributed.fsdp._common_utils import (
     _set_fsdp_flattened,
     HandleTrainingState,
 )
+from torch.distributed.utils import _alloc_storage, _free_storage, _p_assert
 
 from ._fsdp_extensions import _ext_post_unflatten_transform, _ext_pre_flatten_transform
-from ._utils import (
-    _alloc_storage,
-    _free_storage,
-    _no_dispatch_record_stream,
-    _same_storage,
-    p_assert,
-)
+from ._utils import _no_dispatch_record_stream, _same_storage
 
 
 __all__ = [
     "FlatParameter",
@@ -558,7 +553,7 @@
         if not self.uses_sharded_strategy:
             self._init_shard_metadata(0, 0, flat_param.numel() - 1)
         else:
-            p_assert(
+            _p_assert(
                 flat_param.storage_offset() == 0,
                 "The `FlatParameter` is not the sole occupant of its storage",
             )
@@ -600,8 +595,8 @@
         """
         self.flat_param._sharded_size = self.flat_param.size()  # type: ignore[attr-defined]
         sharded_flat_param_numel = self.flat_param.numel()  # includes `numel_padded`
-        p_assert(start >= 0 and start <= end, f"start: {start} end: {end}")
-        p_assert(
+        _p_assert(start >= 0 and start <= end, f"start: {start} end: {end}")
+        _p_assert(
             numel_padded <= sharded_flat_param_numel,
             f"numel_padded: {numel_padded} "
             f"sharded_flat_param_numel: {sharded_flat_param_numel}",
@@ -792,7 +787,7 @@
         self._orig_param_dtype = flat_param.dtype
         cpu_device = torch.device("cpu")
         if self._offload_params:
-            p_assert(
+            _p_assert(
                 flat_param.device == cpu_device,
                 f"Expects the `FlatParameter` to be on CPU when parameter CPU "
                 f"offloading is enabled, not {flat_param.device}",
@@ -957,7 +952,7 @@
             # tensor as the all-gather destination to preserve the invariant
             # that `_full_param_padded` is in the low precision
             unsharded_flat_param = flat_param._full_prec_full_param_padded  # type: ignore[attr-defined]
-            p_assert(
+            _p_assert(
                 unsharded_flat_param.dtype != self._fwd_bwd_param_dtype,
                 f"Expects full precision but got {self._fwd_bwd_param_dtype}",
             )
@@ -974,13 +969,13 @@
         ``padded_unsharded_flat_param``, and switches to using the all-gathered
         tensor.
         """
-        p_assert(
+        _p_assert(
             hasattr(self, "process_group") and hasattr(self, "world_size"),
             "Expects a process group and world size to have been set via `shard()`",
         )
         sharded_flat_param = self.flat_param.data
         expected_numel = sharded_flat_param.numel() * self.world_size
-        p_assert(
+        _p_assert(
             padded_unsharded_flat_param.numel() == expected_numel,
             f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
         )
@@ -1111,7 +1106,7 @@
         clearing any existing sharded gradient in ``.grad`` to enable computing
         a new unsharded gradient.
         """
-        p_assert(
+        _p_assert(
             self._training_state
             in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.IDLE),
             "Expects to be in `BACKWARD_PRE` or `IDLE` (if prefetching)",
@@ -1123,7 +1118,7 @@
         ):
             self._check_on_compute_device(self.flat_param)
             grad_offloaded = flat_param.grad.device != self.device
-            p_assert(
+            _p_assert(
                 not grad_offloaded or self._offload_params,
                 f"Expects the sharded gradient to be on {self.device} "
                 f"but got {flat_param.grad.device}",
@@ -1142,7 +1137,7 @@
                     flat_param._saved_grad_shard = flat_param.grad.data  # type: ignore[attr-defined]
                 sharded_grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
             else:
-                p_assert(
+                _p_assert(
                     hasattr(flat_param, "_cpu_grad"),
                     "`_cpu_grad` should be defined if the gradient is on CPU",
                 )
@@ -1162,7 +1157,7 @@
                     sharded_grad.data = sharded_grad.to(local_shard_dtype)
         else:
             padded_unsharded_size = flat_param._padded_unsharded_size  # type: ignore[attr-defined]
-            p_assert(
+            _p_assert(
                 flat_param.grad.size() == padded_unsharded_size,
                 "Expects `.grad` to be the unsharded gradient in "
                 f"`no_sync()` with size {padded_unsharded_size} "
@@ -1203,7 +1198,7 @@
             flat_param.grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
             cast_grad_to_param_dtype_if_needed(flat_param)
         else:
-            p_assert(
+            _p_assert(
                 not self.uses_sharded_strategy
                 or not flat_param._post_backward_called,  # type: ignore[attr-defined]
                 "All sharded parameters that received a gradient in the "
@@ -1229,7 +1224,7 @@
         Postcondition: Same as the precondition.
         """
         self._check_sharded_strategy()
-        p_assert(
+        _p_assert(
             self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
             f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
         )
@@ -1242,7 +1237,7 @@
             padded_storage_ptr = (
                 self._get_padded_unsharded_flat_param()._typed_storage()._data_ptr()
             )
-            p_assert(
+            _p_assert(
                 unpadded_storage_ptr == padded_storage_ptr,
                 "Expects the unpadded parameter to be a view into the padded parameter",
            )
@@ -1251,7 +1246,7 @@
         try:
             yield
         finally:
-            p_assert(
+            _p_assert(
                 self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
                 f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
             )
@@ -1314,7 +1309,7 @@
         flat_param = self.flat_param
         if self._offload_params:
             device = flat_param._local_shard.device  # type: ignore[attr-defined]
-            p_assert(
+            _p_assert(
                 device == torch.device("cpu"),
                 f"Expects the local shard to be on CPU but got {device}",
             )
@@ -1357,7 +1352,7 @@
         """
         if tensor is None:
             tensor = flat_param
-        p_assert(
+        _p_assert(
            tensor.numel() == flat_param._unpadded_unsharded_size.numel(),
            f"Expects {flat_param._unpadded_unsharded_size.numel()} numel but got "
            f"{tensor.numel()} numel",
@@ -1416,7 +1411,7 @@
                 # hook fires (e.g. for reentrant AC)
                 assert self.flat_param._tensors is not None  # mypy
                 tensor = self.flat_param._tensors[i]
-                p_assert(
+                _p_assert(
                     tensor is not None,
                     "Expects `Tensor` to have been saved in forward",
                 )
@@ -1439,14 +1434,14 @@
         ) in enumerate(self.flat_param._shared_param_infos):
             if hasattr(module, param_name):
                 delattr(module, param_name)
-            p_assert(
+            _p_assert(
                 hasattr(prim_module, prim_param_name),
                 f"Module {prim_module_name} is missing parameter {prim_param_name}",
             )
             prim_param: Union[Tensor, nn.Parameter] = getattr(
                 prim_module, prim_param_name
             )
-            p_assert(
+            _p_assert(
                 not as_params or isinstance(prim_param, nn.Parameter),
                 f"as_params={as_params} type(prim_param)={type(prim_param)}",
             )
@@ -1485,7 +1480,7 @@
         for i, (view, (param_name, module, _)) in enumerate(
             zip(views, self.flat_param._param_infos)
         ):
-            p_assert(
+            _p_assert(
                 hasattr(module, param_name),
                 f"{self.flat_param._fqns[i]} is missing",
             )
@@ -1511,7 +1506,7 @@
             prim_module,
             _,
         ) in enumerate(self.flat_param._shared_param_infos):
-            p_assert(
+            _p_assert(
                 hasattr(module, param_name),
                 f"{module_name + '.' + param_name if module_name else param_name} is missing",
             )  # did not save FQN info in `_shared_param_infos`
@@ -1793,7 +1788,7 @@
             RuntimeError: If the ``src_tensor`` does not have the expected
                 shape.
         """
-        p_assert(
+        _p_assert(
             len(expected_shape) == 1,
             f"Expects a 1D expected shape but got {expected_shape}",
         )
@@ -1935,7 +1930,7 @@
         else:
             # If in the forward, then there may be an accumulated gradient,
             # which will be in `.grad`
-            p_assert(
+            _p_assert(
                 flat_param.grad is None
                 or not self.uses_sharded_strategy
                 or self._training_state == HandleTrainingState.FORWARD,
@@ -1954,7 +1949,7 @@
         """
         if not self._use_orig_params:
             return
-        p_assert(
+        _p_assert(
             self._training_state == HandleTrainingState.BACKWARD_POST,
             "Expects to only be called in the post-backward after gradient computation",
         )
@@ -1971,16 +1966,16 @@
     # CHECKS & INVARIANTS #
     #######################
     def _check_sharded_strategy(self):
-        p_assert(self.uses_sharded_strategy, "Expects sharded strategy")
+        _p_assert(self.uses_sharded_strategy, "Expects sharded strategy")
 
     def _check_on_compute_device(self, tensor: Tensor):
-        p_assert(
+        _p_assert(
             tensor.device == self.device,
             f"Expects tensor to be on the compute device {self.device}",
         )
 
     def _check_on_cpu(self, tensor: Tensor):
-        p_assert(
+        _p_assert(
             tensor.device == torch.device("cpu"),
             f"Expects tensor to be on CPU but got {tensor.device}",
         )
@@ -1988,7 +1983,7 @@
     @staticmethod
     def _check_storage_freed(tensor: Tensor):
         storage_size: int = tensor._typed_storage()._size()
-        p_assert(
+        _p_assert(
             storage_size == 0,
             f"Expects storage to be freed but got storage with size {storage_size}",
         )
@@ -1996,37 +1991,37 @@
     @staticmethod
     def _check_storage_allocated(tensor: Tensor):
         storage_size: int = tensor._typed_storage()._size()
-        p_assert(storage_size > 0, "Expects storage to be allocated")
+        _p_assert(storage_size > 0, "Expects storage to be allocated")
 
     def _check_low_precision_shard(self):
-        p_assert(
+        _p_assert(
             self._uses_param_mixed_precision,
             "Not using low precision for parameters",
         )
-        p_assert(
+        _p_assert(
             getattr(self.flat_param, "_mp_shard", None) is not None,
             "Expects `_mp_shard` to exist",
         )
         device = self.flat_param._mp_shard.device  # type: ignore[attr-defined]
-        p_assert(
+        _p_assert(
             device == self.device,
             f"Expects the low precision shard to be on {self.device} but got {device}",
        )
 
     def _check_unsharded(self, tensor: Tensor):
         msg_prefix = "Expects tensor to be unsharded "
-        p_assert(tensor is not None, msg_prefix + "but got `None`")
+        _p_assert(tensor is not None, msg_prefix + "but got `None`")
         unsharded_size = self.flat_param._unpadded_unsharded_size
-        p_assert(
+        _p_assert(
             tensor.size() == unsharded_size,
             msg_prefix + f"with size {unsharded_size} but got {tensor.size()}",
        )
 
     def _check_sharded(self, tensor: Tensor):
         msg_prefix = "Expects tensor to be sharded "
-        p_assert(tensor is not None, msg_prefix + "but got `None`")
+        _p_assert(tensor is not None, msg_prefix + "but got `None`")
         sharded_size = self.flat_param._sharded_size  # type: ignore[attr-defined]
-        p_assert(
+        _p_assert(
             tensor.size() == sharded_size,
             msg_prefix + f"with size {sharded_size} but got {tensor.size()}",
         )

View File

@@ -77,6 +77,7 @@ from torch.distributed.fsdp.api import (
     StateDictSettings,
     StateDictType,
 )
+from torch.distributed.utils import _p_assert
 
 from ._optim_utils import (
     _broadcast_pos_dim_tensor_states,
@@ -98,7 +99,6 @@ from ._unshard_param_utils import (
     _unshard_params,
     _unshard_params_recurse,
 )
-from ._utils import p_assert
 from .flat_param import FlatParameter
 from .wrap import _FSDPPolicy
 
@@ -740,7 +740,7 @@
             self, self._handles, unshard_fn, self._fsdp_wrapped_module, args, kwargs
         )
         for handle in self._handles:
-            p_assert(
+            _p_assert(
                 handle.flat_param.device == self.compute_device,
                 "Expected `FlatParameter` to be on the compute device "
                 f"{self.compute_device} but got {handle.flat_param.device}",
@@ -830,7 +830,7 @@
         this refreshes the sharded views before exiting. This method shouuld
         only be called when using the original parameters.
         """
-        p_assert(
+        _p_assert(
             self._use_orig_params,
             "`_deregister_orig_params_ctx()` should only be called when "
             "`_use_orig_params=True`",

View File

@@ -1,4 +1,6 @@
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, Callable, Union, Set, OrderedDict
+import dataclasses
+import traceback
 
 import torch
 import torch.distributed as dist
@@ -94,6 +96,92 @@ def _recursive_to(inputs, target_gpu, use_side_stream_for_tensor_copies):
         to_map = None  # type: ignore[assignment]
     return res
 
 
+def _p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:
+    """This is used as an alternate to ``assert`` when in the backward context
+    to print the error message ``s`` since otherwise, it is swallowed."""
+    if not cond:
+        print(s)
+        traceback.print_stack()
+        if raise_assertion_error:
+            raise AssertionError(s)
+
+
+def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> bool:
+    """
+    Allocate storage for ``tensor`` with the given size.
+
+    Returns:
+        bool: ``True`` if this method allocated storage and ``False`` if the
+        storage was already allocated.
+    """
+    with torch.no_grad():
+        already_allocated = tensor._typed_storage()._size() == size.numel()
+        if not already_allocated:
+            tensor_storage_size = tensor._typed_storage()._size()
+            _p_assert(
+                tensor_storage_size == 0,
+                f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}",
+            )
+            tensor._typed_storage()._resize_(size.numel())
+    return not already_allocated
+
+
+def _free_storage(tensor: torch.Tensor) -> bool:
+    """
+    Frees the underlying storage of ``tensor``.
+
+    Returns:
+        bool: ``True`` if the method freed the storage and ``False`` if the
+        storage was already freed.
+    """
+    with torch.no_grad():
+        already_freed = tensor._typed_storage()._size() == 0
+        if not already_freed:
+            _p_assert(
+                tensor.storage_offset() == 0,
+                "Freeing a tensor's storage is unsafe when it is not the sole occupant\n"
+                f"storage offset: {tensor.storage_offset()}\n"
+                f"storage size: {tensor._typed_storage()._size()}\n"
+                f"tensor shape: {tensor.shape}",
+            )
+            tensor._typed_storage()._resize_(0)
+    return not already_freed
+
+
+def _apply_to_tensors(
+    fn: Callable,
+    container: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence],
+) -> Any:
+    """Recursively apply to all tensor in different kinds of container types."""
+
+    def apply(
+        x: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence]
+    ) -> Any:
+        if torch.is_tensor(x):
+            return fn(x)
+        elif hasattr(x, "__dataclass_fields__"):
+            dc = dataclasses.replace(x)
+            for f in dataclasses.fields(dc):
+                name = f.name
+                setattr(dc, name, apply(getattr(dc, name)))
+            return dc
+        elif isinstance(x, OrderedDict):
+            od = x.__class__()
+            for key, value in x.items():
+                od[key] = apply(value)
+            return od
+        elif isinstance(x, PackedSequence):
+            apply(x.data)
+            return x
+        elif isinstance(x, dict):
+            return {key: apply(value) for key, value in x.items()}
+        elif _is_namedtuple(x):
+            res = (apply(el) for el in x)
+            return type(x)(*res)
+        elif isinstance(x, (list, tuple, set)):
+            return type(x)(apply(el) for el in x)
+        else:
+            return x
+
+    return apply(container)
+
+
 def _to_kwargs(inputs, kwargs, device_id, use_side_stream_for_tensor_copies):
     inputs = (