This applies `ufmt` to all of the FSDP files in the `torch/distributed/fsdp/` directory.

**Test Plan**
CI

**Notes**
For VSCode users:
- Install `ufmt`: https://pypi.org/project/ufmt/
- Install the VSCode `ufmt` extension: https://marketplace.visualstudio.com/items?itemName=omnilib.ufmt
- Include in `settings.json`:
```
{
    "[python]": {
        "editor.defaultFormatter": "omnilib.ufmt",
        "editor.formatOnSave": true,
    },
}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87811
Approved by: https://github.com/rohan-varma, https://github.com/fegin
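For reference, the same reformat can be reproduced locally from the repository root with `ufmt`'s command-line interface (a sketch assuming `ufmt` is installed; the directory path is the one named above):
```
ufmt format torch/distributed/fsdp/
```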
116 lines · 3.1 KiB · Python
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor


class FSDPExtensions(ABC):
    """
    This defines customizable hooks to enable composability with tensor
    parallelism. To activate these hooks, use :func:`_set_fsdp_extensions` to
    set a custom :class:`FSDPExtensions` that implements the hooks.
    """

    @abstractmethod
    def pre_flatten_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        """E.g. converting ``DistributedTensor`` to local tensor."""
        ...

    @abstractmethod
    def post_unflatten_transform(
        self,
        tensor: torch.Tensor,
        param_extension: Any,
    ) -> torch.Tensor:
        """E.g. converting local tensor to ``DistributedTensor``."""
        ...

    @abstractmethod
    def chunk_tensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        world_size: int,
        num_devices_per_node: int,
        pg: dist.ProcessGroup,
    ) -> torch.Tensor:
        """Shards a tensor into chunks and returns the local chunk."""
        ...

    @abstractmethod
    def pre_load_state_dict_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, List[Shard]]:
        """
        This is to be called before loading a *sharded* model state dict and
        should return the tensor and list of shards from which to load data.
        """
        ...
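

# The subclass below is an illustrative sketch, not part of the upstream
# module. It shows the shape of a hypothetical ``FSDPExtensions``
# implementation that performs no tensor-parallel conversion: the flatten
# hooks are pass-throughs, chunking defers to ``_create_chunk_sharded_tensor``,
# and the state-dict hook assumes it receives a ``ShardedTensor``. A real
# integration (e.g. tensor parallelism) would convert to and from its
# distributed tensor type in these hooks instead.
class _NoOpFSDPExtensions(FSDPExtensions):
    def pre_flatten_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        # No extension metadata to record for a plain tensor
        return tensor, None

    def post_unflatten_transform(
        self,
        tensor: torch.Tensor,
        param_extension: Any,
    ) -> torch.Tensor:
        # Nothing to reconstruct; return the local tensor unchanged
        return tensor

    def chunk_tensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        world_size: int,
        num_devices_per_node: int,
        pg: dist.ProcessGroup,
    ) -> torch.Tensor:
        # Reuse the default chunking helper imported above
        return _create_chunk_sharded_tensor(
            tensor, rank, world_size, num_devices_per_node, pg
        )

    def pre_load_state_dict_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, List[Shard]]:
        # Mirror the default path: load from the local shards of a ShardedTensor
        assert type(tensor) is ShardedTensor
        return tensor, tensor.local_shards()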


_extensions: Optional[FSDPExtensions] = None


def _set_fsdp_extensions(flattener: FSDPExtensions) -> None:
    global _extensions
    _extensions = flattener
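

# Illustrative usage (hypothetical): a tensor-parallelism integration would
# register its extensions once during setup, e.g.
#
#     _set_fsdp_extensions(_NoOpFSDPExtensions())
#
# after which the ``_ext_*`` wrappers below dispatch to the registered hooks
# instead of the default code paths.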


def _ext_pre_flatten_transform(
    tensor: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[Any]]:
    if _extensions is not None:
        new_tensor, extension = _extensions.pre_flatten_transform(tensor)
        if extension is not None:
            return new_tensor, extension
    return tensor, None


def _ext_post_unflatten_transform(
    tensor: torch.Tensor,
    param_extension: Any,
) -> torch.Tensor:
    if _extensions is not None and param_extension is not None:
        return _extensions.post_unflatten_transform(tensor, param_extension)
    return tensor


def _ext_chunk_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
) -> torch.Tensor:
    chunk_tensor_fn = (
        _extensions.chunk_tensor
        if _extensions is not None
        else _create_chunk_sharded_tensor
    )
    return chunk_tensor_fn(
        tensor,
        rank,
        world_size,
        num_devices_per_node,
        pg,
    )


def _ext_pre_load_state_dict_transform(
    tensor: torch.Tensor,
) -> Tuple[torch.Tensor, List[Shard]]:
    if _extensions is not None:
        return _extensions.pre_load_state_dict_transform(tensor)

    assert type(tensor) is ShardedTensor
    shards = tensor.local_shards()
    return (tensor, shards)
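

# Minimal behavior sketch (illustrative only): with no extensions registered,
# the wrappers above fall back to identity transforms. The guard keeps this
# from running on import.
if __name__ == "__main__":
    t = torch.zeros(4)
    flat, ext = _ext_pre_flatten_transform(t)
    assert flat is t and ext is None  # no extension metadata by default
    assert _ext_post_unflatten_transform(flat, ext) is flat  # identity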