[Resubmit] Move helpers to torch.dist.utils (#95025)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95025
Approved by: https://github.com/fegin
Rohan Varma authored on 2023-02-17 02:55:24 +00:00; committed by PyTorch MergeBot
parent 2aa806608b
commit c43e88665a
7 changed files with 159 additions and 178 deletions
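
Summary: this change moves the helper functions _apply_to_tensors, _alloc_storage, _free_storage, and p_assert (renamed to _p_assert) out of torch.distributed.fsdp._utils and into torch.distributed.utils, updating all call sites accordingly. Below is a minimal usage sketch, assuming this commit is applied; these are private, underscore-prefixed helpers rather than public API.

# Minimal sketch (assumes this commit): the helpers now come from
# torch.distributed.utils instead of torch.distributed.fsdp._utils.
import torch
from torch.distributed.utils import (
    _alloc_storage,
    _apply_to_tensors,
    _free_storage,
    _p_assert,
)

t = torch.empty(0)
_alloc_storage(t, torch.Size([8]))  # True: resized the storage from 0 to 8 elements
_free_storage(t)                    # True: resized the storage back to 0

doubled = _apply_to_tensors(lambda x: x * 2, {"a": [torch.ones(2), 3]})
_p_assert(doubled["a"][0].sum().item() == 4.0, "expected the tensor to be doubled")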


@ -11,10 +11,9 @@ from typing import List
import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed.fsdp._utils import _apply_to_tensors
from torch.distributed.fsdp._wrap_utils import _get_fully_sharded_module_to_states
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.utils import _replace_by_prefix
from torch.distributed.utils import _apply_to_tensors, _replace_by_prefix
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,


@ -26,11 +26,7 @@ from torch.distributed.fsdp._common_utils import (
TrainingState,
)
from torch.distributed.fsdp._init_utils import HYBRID_SHARDING_STRATEGIES
from torch.distributed.fsdp._utils import (
_apply_to_tensors,
_no_dispatch_record_stream,
p_assert,
)
from torch.distributed.fsdp._utils import _no_dispatch_record_stream
from torch.distributed.fsdp.api import BackwardPrefetch
from torch.distributed.fsdp.flat_param import (
_HandlesKey,
@ -39,7 +35,7 @@ from torch.distributed.fsdp.flat_param import (
HandleShardingStrategy,
HandleTrainingState,
)
from torch.distributed.utils import _to_kwargs
from torch.distributed.utils import _apply_to_tensors, _p_assert, _to_kwargs
RESHARD_AFTER_FORWARD_STRATEGIES = {
HandleShardingStrategy.FULL_SHARD,
@ -221,7 +217,7 @@ def _share_state_and_init_handle_attrs(
attr_name_to_values[attr_name] = set()
for fsdp_state in traversal_utils._get_fsdp_states(root_module):
for attr_name in HOMOGENEOUS_ATTR_NAMES:
p_assert(
_p_assert(
hasattr(fsdp_state, attr_name),
f"FSDP state missing attribute {attr_name}",
)
@ -246,7 +242,7 @@ def _share_state_and_init_handle_attrs(
# Relax the assert for non-root FSDP instances in case the nested
# initialized module is wrapped again in FSDP later (e.g. after
# training to run inference)
p_assert(
_p_assert(
fsdp_state._is_root is None or not fsdp_state._is_root,
"Non-root FSDP instance's `_is_root` should not have been "
"set yet or should have been set to `False`",
@ -344,7 +340,7 @@ def _reshard(
"""
if not handles:
return
p_assert(
_p_assert(
len(handles) == len(free_unsharded_flat_params),
"Expects both lists to have equal length but got "
f"{len(handles)} and {len(free_unsharded_flat_params)}",
@ -518,7 +514,7 @@ def _root_pre_forward(
may not be the root. If not, then this method does not do anything.
"""
_lazy_init(state, module)
p_assert(state._is_root is not None, "Expects a root FSDP to have been set")
_p_assert(state._is_root is not None, "Expects a root FSDP to have been set")
if not state._is_root:
return args, kwargs
if state.forward_prefetch:
@ -675,7 +671,7 @@ def _post_backward_hook(
# the same `FlatParameter`, the post-backward hook may run multiple
# times in one backward, in which case we permit the state to already
# be in `BACKWARD_POST`.
p_assert(
_p_assert(
handle._training_state
in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.BACKWARD_POST),
f"Expects `BACKWARD_PRE` or `BACKWARD_POST` state but got {handle._training_state}",
@ -855,8 +851,8 @@ def _check_comm_hook(
comm_hook: Any,
comm_hook_state: Any,
) -> None:
p_assert(comm_hook is not None, "Communication hook should not be `None`")
p_assert(
_p_assert(comm_hook is not None, "Communication hook should not be `None`")
_p_assert(
comm_hook_state is not None, "Communication hook state should not be `None`"
)
@ -865,13 +861,13 @@ def _check_grad_to_accumulate(
new_sharded_grad: torch.Tensor,
accumulated_grad: torch.Tensor,
) -> None:
p_assert(
_p_assert(
accumulated_grad.shape == new_sharded_grad.shape,
"Shape mismatch when accumulating gradients: "
f"existing gradient shape={accumulated_grad.shape} "
f"new gradient shape={new_sharded_grad.shape}",
)
p_assert(
_p_assert(
accumulated_grad.device == new_sharded_grad.device,
"Device mismatch when accumulating gradients: "
f"existing gradient device={accumulated_grad.device} "
@ -895,7 +891,7 @@ def _post_backward_final_callback(
This runs at the end of the entire backward pass and should only be called
on the root FSDP instance.
"""
p_assert(
_p_assert(
state._is_root,
"The post-backward callback should only be called on the root FSDP instance",
)
@ -952,7 +948,7 @@ def _catch_all_reshard(
if handles_to_reshard:
_reshard(state, handles_to_reshard, free_unsharded_flat_params)
except Exception as e:
p_assert(
_p_assert(
False,
f"Got exception in the catch-all reshard for {state}: {str(e)}",
raise_assertion_error=False,
@ -969,7 +965,7 @@ def _finalize_params(
flat_param = handle.flat_param
if flat_param.requires_grad:
if hasattr(flat_param, "_post_backward_hook_state"):
p_assert(
_p_assert(
len(flat_param._post_backward_hook_state) == 2,
f"Invalid: ``_post_backward_hook_state``: {flat_param._post_backward_hook_state}",
)
@ -982,7 +978,7 @@ def _finalize_params(
# sharded gradient from the last synchronized iteration
continue
handle.prepare_gradient_for_optim()
p_assert(
_p_assert(
hasattr(flat_param, "_post_backward_called"),
"Expects `_post_backward_called` to be set on the `FlatParameter`",
)
@ -1029,7 +1025,7 @@ def _get_handles_to_prefetch(
HandleTrainingState.BACKWARD_POST,
HandleTrainingState.FORWARD,
)
p_assert(
_p_assert(
training_state in valid_training_states,
f"Prefetching is only supported in {valid_training_states} but "
f"currently in {training_state}",
@ -1067,9 +1063,9 @@ def _get_training_state(
handles_key: _HandlesKey,
) -> HandleTrainingState:
"""Returns the training state of the handles in ``handles_key``."""
p_assert(len(handles_key) > 0, "Expects a non-empty handles key")
_p_assert(len(handles_key) > 0, "Expects a non-empty handles key")
training_states = {handle._training_state for handle in handles_key}
p_assert(
_p_assert(
len(training_states) == 1,
f"Expects uniform training state but got {training_states}",
)
@ -1233,7 +1229,7 @@ def _register_post_backward_hooks(
continue
# Get the `AccumulateGrad` object
temp_flat_param = flat_param.expand_as(flat_param)
p_assert(
_p_assert(
temp_flat_param.grad_fn is not None,
"The `grad_fn` is needed to access the `AccumulateGrad` and "
"register the post-backward hook",
@ -1255,7 +1251,7 @@ def _register_post_backward_final_callback(
backward pass. This should be called from the root FSDP instance at the
beginning of the pre-backward.
"""
p_assert(
_p_assert(
state._is_root,
"Only the root FSDP instance should register the post-backward callback",
)
@ -1309,7 +1305,7 @@ def _get_buffers_and_dtypes_for_computation(
is either ``None`` if buffer mixed precision is not enabled or the buffer
low precision dtype otherwise.
"""
p_assert(state._is_root, "Expects the root to cast buffers")
_p_assert(state._is_root, "Expects the root to cast buffers")
buffers: List[torch.Tensor] = []
buffer_dtypes: List[Optional[torch.dtype]] = []
if _is_composable(state):
@ -1344,7 +1340,7 @@ def _get_buffer_dtypes(
"""
buffer_dtypes: List[torch.dtype] = []
for buffer_name in buffer_names:
p_assert(
_p_assert(
buffer_name in state._buffer_name_to_orig_dtype,
f"{buffer_name} is missing from pre-computed dict on rank "
f"{state.rank}, which only has keys "
@ -1364,7 +1360,7 @@ def _cast_buffers_to_dtype_and_device(
to ``device``. If an element in ``buffer_dtypes`` is ``None``, then the
corresponding buffer is only moved to ``device``.
"""
p_assert(
_p_assert(
buffer_dtypes is None or len(buffers) == len(buffer_dtypes),
f"Expects `buffers` and `buffer_dtypes` to have the same length if "
f"`buffer_dtypes` is specified but got {len(buffers)} and "

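For reference, a small illustrative sketch (not part of the diff) of the two _p_assert modes exercised above: by default it prints the message and a stack trace and then raises AssertionError, while raise_assertion_error=False, as used in _catch_all_reshard, only reports the failure and continues.

from torch.distributed.utils import _p_assert

try:
    _p_assert(False, "demo failure")  # prints the message and a stack trace, then raises
except AssertionError:
    pass

# Non-raising mode, as in the catch-all reshard: report the failure but keep going.
_p_assert(False, "reported only", raise_assertion_error=False)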

@ -21,7 +21,7 @@ from torch.distributed.fsdp._runtime_utils import (
_unshard,
_unshard_grads,
)
from ._utils import p_assert
from torch.distributed.utils import _p_assert
from .flat_param import FlatParamHandle
FLAT_PARAM = "_flat_param"
@ -336,7 +336,7 @@ def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None:
Deregisters the original parameters; registers the ``FlatParameter``.
"""
handles = _module_handles(state, module)
p_assert(
_p_assert(
len(handles) <= 1,
"Expects <=1 handle per FSDP instance; needs to be refactored "
"for >1 handle (e.g. non-recursive wrapping)",
@ -344,7 +344,7 @@ def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None:
if not handles:
return
handle = handles[0]
p_assert(
_p_assert(
handle._use_orig_params,
f"Inconsistent `_use_orig_params` -- FSDP: {state._use_orig_params} "
f"handle: {handle._use_orig_params}",


@ -1,14 +1,7 @@
import dataclasses
import traceback
from collections import OrderedDict
from typing import Any, Callable, cast, Dict, List, Set, Tuple, Union
from typing import cast
import torch
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parallel.scatter_gather import ( # type: ignore[attr-defined]
_is_namedtuple,
)
from torch.nn.utils.rnn import PackedSequence
from torch.utils._mode_utils import no_dispatch
@ -22,102 +15,12 @@ def _override_batchnorm_mixed_precision(module):
mod._wrap_overrides = {"mixed_precision": None} # type: ignore[assignment]
def _apply_to_tensors(
fn: Callable,
container: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence],
) -> Any:
"""Recursively apply to all tensor in different kinds of container types."""
def apply(
x: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence]
) -> Any:
if torch.is_tensor(x):
return fn(x)
elif hasattr(x, "__dataclass_fields__"):
dc = dataclasses.replace(x)
for f in dataclasses.fields(dc):
name = f.name
setattr(dc, name, apply(getattr(dc, name)))
return dc
elif isinstance(x, OrderedDict):
od = x.__class__()
for key, value in x.items():
od[key] = apply(value)
return od
elif isinstance(x, PackedSequence):
apply(x.data)
return x
elif isinstance(x, dict):
return {key: apply(value) for key, value in x.items()}
elif _is_namedtuple(x):
res = (apply(el) for el in x)
return type(x)(*res)
elif isinstance(x, (list, tuple, set)):
return type(x)(apply(el) for el in x)
else:
return x
return apply(container)
@torch.no_grad()
def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> bool:
"""
Allocate storage for ``tensor`` with the given size.
Returns:
bool: ``True`` if this method allocated storage and ``False`` if the
storage was already allocated.
"""
already_allocated = tensor._typed_storage()._size() == size.numel()
if not already_allocated:
tensor_storage_size = tensor._typed_storage()._size()
p_assert(
tensor_storage_size == 0,
f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}",
)
tensor._typed_storage()._resize_(size.numel())
return not already_allocated
@torch.no_grad()
def _free_storage(tensor: torch.Tensor) -> bool:
"""
Frees the underlying storage of ``tensor``.
Returns:
bool: ``True`` if the method freed the storage and ``False`` if the
storage was already freed.
"""
already_freed = tensor._typed_storage()._size() == 0
if not already_freed:
p_assert(
tensor.storage_offset() == 0,
"Freeing a tensor's storage is unsafe when it is not the sole occupant\n"
f"storage offset: {tensor.storage_offset()}\n"
f"storage size: {tensor._typed_storage()._size()}\n"
f"tensor shape: {tensor.shape}",
)
tensor._typed_storage()._resize_(0)
return not already_freed
def _same_storage(x: torch.Tensor, y: torch.Tensor) -> bool:
"""Returns if ``x`` and ``y`` share the same storage."""
# NOTE: CPU and GPU tensors are ensured to have different data pointers.
return x._typed_storage()._data_ptr() == y._typed_storage()._data_ptr()
def p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:
"""This is used as an alternate to ``assert`` when in the backward context
to print the error message ``s`` since otherwise, it is swallowed."""
if not cond:
print(s)
traceback.print_stack()
if raise_assertion_error:
raise AssertionError(s)
def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.cuda.Stream) -> None:
with no_dispatch():
tensor.record_stream(cast(torch._C.Stream, stream))

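For context, an illustrative sketch (not part of this commit) of the nested containers _apply_to_tensors handles, based on the implementation removed above and re-added under torch.distributed.utils below; the Point namedtuple is a made-up container for the example.

from collections import OrderedDict, namedtuple

import torch
from torch.distributed.utils import _apply_to_tensors

Point = namedtuple("Point", ["x", "y"])  # hypothetical container for illustration
nested = OrderedDict(
    a=torch.zeros(3),
    b=[Point(torch.ones(2), 1.0), {"c": (torch.full((2,), 2.0),)}],
)
# Tensors are converted; non-tensor leaves (the floats here) pass through unchanged.
moved = _apply_to_tensors(lambda t: t.to(torch.float64), nested)
assert moved["b"][0].x.dtype == torch.float64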

@ -27,15 +27,10 @@ from torch.distributed.fsdp._common_utils import (
_set_fsdp_flattened,
HandleTrainingState,
)
from torch.distributed.utils import _alloc_storage, _free_storage, _p_assert
from ._fsdp_extensions import _ext_post_unflatten_transform, _ext_pre_flatten_transform
from ._utils import (
_alloc_storage,
_free_storage,
_no_dispatch_record_stream,
_same_storage,
p_assert,
)
from ._utils import _no_dispatch_record_stream, _same_storage
__all__ = [
"FlatParameter",
@ -558,7 +553,7 @@ class FlatParamHandle:
if not self.uses_sharded_strategy:
self._init_shard_metadata(0, 0, flat_param.numel() - 1)
else:
p_assert(
_p_assert(
flat_param.storage_offset() == 0,
"The `FlatParameter` is not the sole occupant of its storage",
)
@ -600,8 +595,8 @@ class FlatParamHandle:
"""
self.flat_param._sharded_size = self.flat_param.size() # type: ignore[attr-defined]
sharded_flat_param_numel = self.flat_param.numel() # includes `numel_padded`
p_assert(start >= 0 and start <= end, f"start: {start} end: {end}")
p_assert(
_p_assert(start >= 0 and start <= end, f"start: {start} end: {end}")
_p_assert(
numel_padded <= sharded_flat_param_numel,
f"numel_padded: {numel_padded} "
f"sharded_flat_param_numel: {sharded_flat_param_numel}",
@ -792,7 +787,7 @@ class FlatParamHandle:
self._orig_param_dtype = flat_param.dtype
cpu_device = torch.device("cpu")
if self._offload_params:
p_assert(
_p_assert(
flat_param.device == cpu_device,
f"Expects the `FlatParameter` to be on CPU when parameter CPU "
f"offloading is enabled, not {flat_param.device}",
@ -957,7 +952,7 @@ class FlatParamHandle:
# tensor as the all-gather destination to preserve the invariant
# that `_full_param_padded` is in the low precision
unsharded_flat_param = flat_param._full_prec_full_param_padded # type: ignore[attr-defined]
p_assert(
_p_assert(
unsharded_flat_param.dtype != self._fwd_bwd_param_dtype,
f"Expects full precision but got {self._fwd_bwd_param_dtype}",
)
@ -974,13 +969,13 @@ class FlatParamHandle:
``padded_unsharded_flat_param``, and switches to using the all-gathered
tensor.
"""
p_assert(
_p_assert(
hasattr(self, "process_group") and hasattr(self, "world_size"),
"Expects a process group and world size to have been set via `shard()`",
)
sharded_flat_param = self.flat_param.data
expected_numel = sharded_flat_param.numel() * self.world_size
p_assert(
_p_assert(
padded_unsharded_flat_param.numel() == expected_numel,
f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
)
@ -1111,7 +1106,7 @@ class FlatParamHandle:
clearing any existing sharded gradient in ``.grad`` to enable computing
a new unsharded gradient.
"""
p_assert(
_p_assert(
self._training_state
in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.IDLE),
"Expects to be in `BACKWARD_PRE` or `IDLE` (if prefetching)",
@ -1123,7 +1118,7 @@ class FlatParamHandle:
):
self._check_on_compute_device(self.flat_param)
grad_offloaded = flat_param.grad.device != self.device
p_assert(
_p_assert(
not grad_offloaded or self._offload_params,
f"Expects the sharded gradient to be on {self.device} "
f"but got {flat_param.grad.device}",
@ -1142,7 +1137,7 @@ class FlatParamHandle:
flat_param._saved_grad_shard = flat_param.grad.data # type: ignore[attr-defined]
sharded_grad = flat_param._saved_grad_shard # type: ignore[attr-defined]
else:
p_assert(
_p_assert(
hasattr(flat_param, "_cpu_grad"),
"`_cpu_grad` should be defined if the gradient is on CPU",
)
@ -1162,7 +1157,7 @@ class FlatParamHandle:
sharded_grad.data = sharded_grad.to(local_shard_dtype)
else:
padded_unsharded_size = flat_param._padded_unsharded_size # type: ignore[attr-defined]
p_assert(
_p_assert(
flat_param.grad.size() == padded_unsharded_size,
"Expects `.grad` to be the unsharded gradient in "
f"`no_sync()` with size {padded_unsharded_size} "
@ -1203,7 +1198,7 @@ class FlatParamHandle:
flat_param.grad = flat_param._saved_grad_shard # type: ignore[attr-defined]
cast_grad_to_param_dtype_if_needed(flat_param)
else:
p_assert(
_p_assert(
not self.uses_sharded_strategy
or not flat_param._post_backward_called, # type: ignore[attr-defined]
"All sharded parameters that received a gradient in the "
@ -1229,7 +1224,7 @@ class FlatParamHandle:
Postcondition: Same as the precondition.
"""
self._check_sharded_strategy()
p_assert(
_p_assert(
self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
)
@ -1242,7 +1237,7 @@ class FlatParamHandle:
padded_storage_ptr = (
self._get_padded_unsharded_flat_param()._typed_storage()._data_ptr()
)
p_assert(
_p_assert(
unpadded_storage_ptr == padded_storage_ptr,
"Expects the unpadded parameter to be a view into the padded parameter",
)
@ -1251,7 +1246,7 @@ class FlatParamHandle:
try:
yield
finally:
p_assert(
_p_assert(
self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
)
@ -1314,7 +1309,7 @@ class FlatParamHandle:
flat_param = self.flat_param
if self._offload_params:
device = flat_param._local_shard.device # type: ignore[attr-defined]
p_assert(
_p_assert(
device == torch.device("cpu"),
f"Expects the local shard to be on CPU but got {device}",
)
@ -1357,7 +1352,7 @@ class FlatParamHandle:
"""
if tensor is None:
tensor = flat_param
p_assert(
_p_assert(
tensor.numel() == flat_param._unpadded_unsharded_size.numel(),
f"Expects {flat_param._unpadded_unsharded_size.numel()} numel but got "
f"{tensor.numel()} numel",
@ -1416,7 +1411,7 @@ class FlatParamHandle:
# hook fires (e.g. for reentrant AC)
assert self.flat_param._tensors is not None # mypy
tensor = self.flat_param._tensors[i]
p_assert(
_p_assert(
tensor is not None,
"Expects `Tensor` to have been saved in forward",
)
@ -1439,14 +1434,14 @@ class FlatParamHandle:
) in enumerate(self.flat_param._shared_param_infos):
if hasattr(module, param_name):
delattr(module, param_name)
p_assert(
_p_assert(
hasattr(prim_module, prim_param_name),
f"Module {prim_module_name} is missing parameter {prim_param_name}",
)
prim_param: Union[Tensor, nn.Parameter] = getattr(
prim_module, prim_param_name
)
p_assert(
_p_assert(
not as_params or isinstance(prim_param, nn.Parameter),
f"as_params={as_params} type(prim_param)={type(prim_param)}",
)
@ -1485,7 +1480,7 @@ class FlatParamHandle:
for i, (view, (param_name, module, _)) in enumerate(
zip(views, self.flat_param._param_infos)
):
p_assert(
_p_assert(
hasattr(module, param_name),
f"{self.flat_param._fqns[i]} is missing",
)
@ -1511,7 +1506,7 @@ class FlatParamHandle:
prim_module,
_,
) in enumerate(self.flat_param._shared_param_infos):
p_assert(
_p_assert(
hasattr(module, param_name),
f"{module_name + '.' + param_name if module_name else param_name} is missing",
) # did not save FQN info in `_shared_param_infos`
@ -1793,7 +1788,7 @@ class FlatParamHandle:
RuntimeError: If the ``src_tensor`` does not have the expected
shape.
"""
p_assert(
_p_assert(
len(expected_shape) == 1,
f"Expects a 1D expected shape but got {expected_shape}",
)
@ -1935,7 +1930,7 @@ class FlatParamHandle:
else:
# If in the forward, then there may be an accumulated gradient,
# which will be in `.grad`
p_assert(
_p_assert(
flat_param.grad is None
or not self.uses_sharded_strategy
or self._training_state == HandleTrainingState.FORWARD,
@ -1954,7 +1949,7 @@ class FlatParamHandle:
"""
if not self._use_orig_params:
return
p_assert(
_p_assert(
self._training_state == HandleTrainingState.BACKWARD_POST,
"Expects to only be called in the post-backward after gradient computation",
)
@ -1971,16 +1966,16 @@ class FlatParamHandle:
# CHECKS & INVARIANTS #
#######################
def _check_sharded_strategy(self):
p_assert(self.uses_sharded_strategy, "Expects sharded strategy")
_p_assert(self.uses_sharded_strategy, "Expects sharded strategy")
def _check_on_compute_device(self, tensor: Tensor):
p_assert(
_p_assert(
tensor.device == self.device,
f"Expects tensor to be on the compute device {self.device}",
)
def _check_on_cpu(self, tensor: Tensor):
p_assert(
_p_assert(
tensor.device == torch.device("cpu"),
f"Expects tensor to be on CPU but got {tensor.device}",
)
@ -1988,7 +1983,7 @@ class FlatParamHandle:
@staticmethod
def _check_storage_freed(tensor: Tensor):
storage_size: int = tensor._typed_storage()._size()
p_assert(
_p_assert(
storage_size == 0,
f"Expects storage to be freed but got storage with size {storage_size}",
)
@ -1996,37 +1991,37 @@ class FlatParamHandle:
@staticmethod
def _check_storage_allocated(tensor: Tensor):
storage_size: int = tensor._typed_storage()._size()
p_assert(storage_size > 0, "Expects storage to be allocated")
_p_assert(storage_size > 0, "Expects storage to be allocated")
def _check_low_precision_shard(self):
p_assert(
_p_assert(
self._uses_param_mixed_precision,
"Not using low precision for parameters",
)
p_assert(
_p_assert(
getattr(self.flat_param, "_mp_shard", None) is not None,
"Expects `_mp_shard` to exist",
)
device = self.flat_param._mp_shard.device # type: ignore[attr-defined]
p_assert(
_p_assert(
device == self.device,
f"Expects the low precision shard to be on {self.device} but got {device}",
)
def _check_unsharded(self, tensor: Tensor):
msg_prefix = "Expects tensor to be unsharded "
p_assert(tensor is not None, msg_prefix + "but got `None`")
_p_assert(tensor is not None, msg_prefix + "but got `None`")
unsharded_size = self.flat_param._unpadded_unsharded_size
p_assert(
_p_assert(
tensor.size() == unsharded_size,
msg_prefix + f"with size {unsharded_size} but got {tensor.size()}",
)
def _check_sharded(self, tensor: Tensor):
msg_prefix = "Expects tensor to be sharded "
p_assert(tensor is not None, msg_prefix + "but got `None`")
_p_assert(tensor is not None, msg_prefix + "but got `None`")
sharded_size = self.flat_param._sharded_size # type: ignore[attr-defined]
p_assert(
_p_assert(
tensor.size() == sharded_size,
msg_prefix + f"with size {sharded_size} but got {tensor.size()}",
)


@ -77,6 +77,7 @@ from torch.distributed.fsdp.api import (
StateDictSettings,
StateDictType,
)
from torch.distributed.utils import _p_assert
from ._optim_utils import (
_broadcast_pos_dim_tensor_states,
@ -98,7 +99,6 @@ from ._unshard_param_utils import (
_unshard_params,
_unshard_params_recurse,
)
from ._utils import p_assert
from .flat_param import FlatParameter
from .wrap import _FSDPPolicy
@ -740,7 +740,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
self, self._handles, unshard_fn, self._fsdp_wrapped_module, args, kwargs
)
for handle in self._handles:
p_assert(
_p_assert(
handle.flat_param.device == self.compute_device,
"Expected `FlatParameter` to be on the compute device "
f"{self.compute_device} but got {handle.flat_param.device}",
@ -830,7 +830,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
this refreshes the sharded views before exiting. This method should
only be called when using the original parameters.
"""
p_assert(
_p_assert(
self._use_orig_params,
"`_deregister_orig_params_ctx()` should only be called when "
"`_use_orig_params=True`",


@ -1,4 +1,6 @@
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Tuple, Callable, Union, Set, OrderedDict
import dataclasses
import traceback
import torch
import torch.distributed as dist
@ -94,6 +96,92 @@ def _recursive_to(inputs, target_gpu, use_side_stream_for_tensor_copies):
to_map = None # type: ignore[assignment]
return res
def _p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:
"""This is used as an alternate to ``assert`` when in the backward context
to print the error message ``s`` since otherwise, it is swallowed."""
if not cond:
print(s)
traceback.print_stack()
if raise_assertion_error:
raise AssertionError(s)
def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> bool:
"""
Allocate storage for ``tensor`` with the given size.
Returns:
bool: ``True`` if this method allocated storage and ``False`` if the
storage was already allocated.
"""
with torch.no_grad():
already_allocated = tensor._typed_storage()._size() == size.numel()
if not already_allocated:
tensor_storage_size = tensor._typed_storage()._size()
_p_assert(
tensor_storage_size == 0,
f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}",
)
tensor._typed_storage()._resize_(size.numel())
return not already_allocated
def _free_storage(tensor: torch.Tensor) -> bool:
"""
Frees the underlying storage of ``tensor``.
Returns:
bool: ``True`` if the method freed the storage and ``False`` if the
storage was already freed.
"""
with torch.no_grad():
already_freed = tensor._typed_storage()._size() == 0
if not already_freed:
_p_assert(
tensor.storage_offset() == 0,
"Freeing a tensor's storage is unsafe when it is not the sole occupant\n"
f"storage offset: {tensor.storage_offset()}\n"
f"storage size: {tensor._typed_storage()._size()}\n"
f"tensor shape: {tensor.shape}",
)
tensor._typed_storage()._resize_(0)
return not already_freed
def _apply_to_tensors(
fn: Callable,
container: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence],
) -> Any:
"""Recursively apply to all tensor in different kinds of container types."""
def apply(
x: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence]
) -> Any:
if torch.is_tensor(x):
return fn(x)
elif hasattr(x, "__dataclass_fields__"):
dc = dataclasses.replace(x)
for f in dataclasses.fields(dc):
name = f.name
setattr(dc, name, apply(getattr(dc, name)))
return dc
elif isinstance(x, OrderedDict):
od = x.__class__()
for key, value in x.items():
od[key] = apply(value)
return od
elif isinstance(x, PackedSequence):
apply(x.data)
return x
elif isinstance(x, dict):
return {key: apply(value) for key, value in x.items()}
elif _is_namedtuple(x):
res = (apply(el) for el in x)
return type(x)(*res)
elif isinstance(x, (list, tuple, set)):
return type(x)(apply(el) for el in x)
else:
return x
return apply(container)
def _to_kwargs(inputs, kwargs, device_id, use_side_stream_for_tensor_copies):
inputs = (