Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
Move DTensor into the public namespace, and formally add a documentation page that covers all of its public APIs. This includes:

* many path renames and import-path fixes
* a dedicated doc page, without much content yet (to be added in the next PRs)
* a shim module that redirects calls on the old `torch.distributed._tensor` path to the new module, to preserve backward compatibility for users still on the old path

That backward compatibility is preserved is evidenced by the fact that all DTensor tests still pass without changing their public imports, so the changes are safe to land.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133113
Approved by: https://github.com/XilunWu
ghstack dependencies: #133305, #133306
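For illustration, a minimal sketch of what such an import shim can look like, assuming the old module simply re-exports the new public one (the actual shim in the PR may differ):

```python
# torch/distributed/_tensor/__init__.py -- hypothetical BC shim
# Re-export the public DTensor API under the old private path so that
# imports like `from torch.distributed._tensor import DTensor` keep working.
from torch.distributed.tensor import *  # noqa: F401,F403
```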
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed.fsdp._shard_utils import (
    _all_gather_dtensor,
    _create_chunk_dtensor,
    _create_chunk_sharded_tensor,
)
from torch.distributed.tensor import DeviceMesh, DTensor


class FSDPExtensions(ABC):
    """
    This defines customizable hooks to enable composability with tensor
    parallelism. To activate these hooks, use :func:`_set_fsdp_extensions` to
    set a custom :class:`FSDPExtensions` that implements the hooks. (A
    hypothetical implementation sketch follows this class definition.)
    """

    @abstractmethod
    def pre_flatten_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        """E.g. converting ``DistributedTensor`` to local tensor."""
        ...

    @abstractmethod
    def post_unflatten_transform(
        self,
        tensor: torch.Tensor,
        param_extension: Any,
    ) -> torch.Tensor:
        """E.g. converting local tensor to ``DistributedTensor``."""
        ...

    @abstractmethod
    def chunk_tensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        world_size: int,
        num_devices_per_node: int,
        pg: dist.ProcessGroup,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """Shards a tensor into chunks and returns the local chunk."""
        ...

    @abstractmethod
    def chunk_dtensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        device_mesh: DeviceMesh,
    ) -> torch.Tensor:
        """Shards a tensor or DTensor into a DTensor and returns the local DTensor."""
        ...

    @abstractmethod
    def pre_load_state_dict_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, List[Shard]]:
        """
        This is to be called before loading a *sharded* model state dict and
        should return the tensor and list of shards from which to load data.
        """
        ...

    @abstractmethod
    def all_gather_dtensor(
        self,
        tensor: DTensor,
        parent_mesh: Optional[DeviceMesh],
    ) -> torch.Tensor:
        """
        This is to be called before loading a *sharded* DTensor state dict.
        It gathers the tensor along the FSDP dimension and returns the local
        tensor of the TP DTensor.
        """
        ...
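
# Hypothetical sketch (not part of the original file) of how an extension
# might implement these hooks for a tensor-parallel setup in which
# parameters arrive as DTensors; the class name and details below are
# illustrative only:
#
#     class _MyTPExtensions(FSDPExtensions):
#         def pre_flatten_transform(self, tensor):
#             if isinstance(tensor, DTensor):
#                 # Hand FSDP the local shard; stash the spec so the
#                 # DTensor can be rebuilt after unflattening.
#                 return tensor.to_local(), tensor._spec
#             return tensor, None
#
#         def post_unflatten_transform(self, tensor, param_extension):
#             # Rebuild the DTensor from the local shard and saved spec.
#             return DTensor.from_local(
#                 tensor, param_extension.mesh, param_extension.placements
#             )
#
#         def pre_load_state_dict_transform(self, tensor):
#             # Mirror the default path below: expose the local shards of a
#             # ShardedTensor to the state-dict loader.
#             return tensor, tensor.local_shards()
#
#         # ... remaining abstract methods omitted for brevity ...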


_extensions: Optional[FSDPExtensions] = None


def _set_fsdp_extensions(flattener: FSDPExtensions) -> None:
    global _extensions
    _extensions = flattener
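
# Usage sketch (hypothetical): a composability layer registers its extension
# once at setup time, e.g. `_set_fsdp_extensions(_MyTPExtensions())`. FSDP
# internals can then pass the registered extension into the `_ext_*` helpers
# below; each helper falls back to a default ShardedTensor/DTensor
# implementation when no extension is given.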


def _ext_pre_flatten_transform(
    tensor: torch.Tensor,
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> Tuple[torch.Tensor, Optional[Any]]:
    if fsdp_extension is not None:
        new_tensor, param_extension = fsdp_extension.pre_flatten_transform(tensor)
        if param_extension is not None:
            return new_tensor, param_extension
    return tensor, None


def _ext_post_unflatten_transform(
    tensor: torch.Tensor,
    param_extension: Any,
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> torch.Tensor:
    if fsdp_extension is not None and param_extension is not None:
        return fsdp_extension.post_unflatten_transform(tensor, param_extension)
    return tensor


def _ext_chunk_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> torch.Tensor:
    chunk_tensor_fn = (
        fsdp_extension.chunk_tensor
        if fsdp_extension is not None
        else _create_chunk_sharded_tensor
    )
    return chunk_tensor_fn(
        tensor,
        rank,
        world_size,
        num_devices_per_node,
        pg,
    )


def _ext_chunk_dtensor(
    tensor: torch.Tensor,
    rank: int,
    device_mesh: DeviceMesh,
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> torch.Tensor:
    chunk_dtensor_fn = (
        fsdp_extension.chunk_dtensor
        if fsdp_extension is not None
        else _create_chunk_dtensor
    )
    return chunk_dtensor_fn(
        tensor,
        rank,
        device_mesh,
    )


def _ext_pre_load_state_dict_transform(
    tensor: torch.Tensor,
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> Tuple[torch.Tensor, List[Shard]]:
    if fsdp_extension is not None:
        return fsdp_extension.pre_load_state_dict_transform(tensor)

    assert type(tensor) is ShardedTensor
    shards = tensor.local_shards()
    return (tensor, shards)


def _ext_all_gather_dtensor(
    tensor: DTensor,
    parent_mesh: Optional[DeviceMesh],
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> torch.Tensor:
    all_gather_dtensor_fn = (
        fsdp_extension.all_gather_dtensor
        if fsdp_extension is not None
        else _all_gather_dtensor
    )
    return all_gather_dtensor_fn(tensor, parent_mesh)