import warnings
from typing import Any, List, Optional, Tuple

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import Shard
from torch.distributed.tensor.parallel._data_parallel_utils import (
    _chunk_tensor,
    _flatten_tensor,
    _pre_load_state_dict,
    _unflatten_tensor,
)

__all__ = ["enable_2d_with_fsdp"]

def enable_2d_with_fsdp() -> bool:
    """
    Register the extension needed for Tensor Parallelism (TP) to work with
    FullyShardedDataParallel (FSDP). Parameters are first parallelized within a
    module or its sub-modules according to a parallelize_plan, and FSDP then
    reshards the local tensor of each distributed parameter, which is
    essentially a DTensor.

    Returns:
        A `bool` indicating whether the extension registration succeeded.
    """
    torch._C._log_api_usage_once("torch.distributed.tensor.parallel.enable_2d_with_fsdp")

    try:
        from torch.distributed.fsdp._fsdp_extensions import (
            _set_fsdp_extensions,
            FSDPExtensions,
        )

        # FSDP extension hooks that teach FSDP how to flatten, unflatten, and
        # shard TP-parallelized (DTensor-backed) parameters.
        class DTensorExtensions(FSDPExtensions):
            def pre_flatten_transform(
                self,
                tensor: torch.Tensor,
            ) -> Tuple[torch.Tensor, Optional[Any]]:
                return _flatten_tensor(tensor)

            def post_unflatten_transform(
                self, tensor: torch.Tensor, param_extension: Any
            ) -> torch.Tensor:
                return _unflatten_tensor(tensor, param_extension)

            def chunk_tensor(
                self,
                tensor: torch.Tensor,
                rank: int,
                world_size: int,
                num_devices_per_node: int,
                pg: dist.ProcessGroup,
                device: Optional[torch.device] = None,
            ) -> torch.Tensor:
                return _chunk_tensor(tensor, rank, world_size, num_devices_per_node, pg)

            def pre_load_state_dict_transform(
                self,
                tensor: torch.Tensor,
            ) -> Tuple[torch.Tensor, List[Shard]]:
                return _pre_load_state_dict(tensor)

        _set_fsdp_extensions(DTensorExtensions())
        return True

    except BaseException as e:
        warnings.warn(
            "PyTorch doesn't have TensorFlattener extension point available. "
            "2D parallelism won't work with FSDP. "
            f"Exception: {e}"
        )
        return False
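
# Illustrative usage sketch (not part of the original module): the function
# below shows one way ``enable_2d_with_fsdp`` might be combined with the TP
# APIs and FSDP to build 2D (tensor + data) parallelism. The mesh shape, the
# TP degree, the "net1" submodule name, and the use of ``tp_mesh_dim`` and
# ``get_dim_groups`` are assumptions for demonstration only and must be
# adapted to the actual model and the PyTorch version in use.
def _example_2d_fsdp_usage(
    model: torch.nn.Module, world_size: int, tp_size: int = 2
) -> torch.nn.Module:
    from torch.distributed._tensor import DeviceMesh
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

    # Register the DTensor extension first so FSDP knows how to flatten,
    # unflatten, and shard TP-parallelized parameters.
    if not enable_2d_with_fsdp():
        raise RuntimeError("2D parallelism is not available in this environment")

    # Assumed 2D mesh: dim 0 for data parallelism (FSDP), dim 1 for tensor parallelism.
    mesh = DeviceMesh("cuda", torch.arange(world_size).reshape(-1, tp_size))

    # Tensor-parallelize a hypothetical "net1" submodule along the TP mesh dimension.
    model = parallelize_module(model, mesh, {"net1": ColwiseParallel()}, tp_mesh_dim=1)

    # Wrap with FSDP over the data-parallel subgroup of the mesh; FSDP then
    # reshards the flattened local tensors along the data-parallel dimension.
    dp_pg = mesh.get_dim_groups()[0]
    return FSDP(model, process_group=dp_pg)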