# Copyright (c) Meta Platforms, Inc. and affiliates
from abc import ABC, abstractmethod
from typing import Optional, Union

import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.distributed.tensor.parallel._utils import (
    _prepare_input_validate,
    _prepare_output_validate,
    _PrepareInputType,
    _PrepareOutputType,
)

__all__ = [
    "ParallelStyle",
    "RowwiseParallel",
    "ColwiseParallel",
    "PairwiseParallel",
    "make_input_replicate_1d",
    "make_input_shard_1d",
    "make_input_shard_1d_dim_last",
    "make_output_replicate_1d",
    "make_output_tensor",
    "make_output_shard_1d",
]


class ParallelStyle(ABC):
    """
    The parallel style that the user wants the module or submodule to be
    parallelized with.
    Users can extend this class to build their own parallel style with
    customized input/output preparations.
    """

    _prepare_input: _PrepareInputType
    _prepare_output: _PrepareOutputType

    @abstractmethod
    def __init__(self, _prepare_input, _prepare_output) -> None:
        self._prepare_input = _prepare_input  # type: ignore[assignment, misc]
        self._prepare_output = _prepare_output  # type: ignore[assignment, misc]

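# Example (illustrative sketch, not part of this module's API): a custom
# style that shards module inputs on dim 0 and replicates module outputs,
# built from the prepare helpers defined below in this file.
#
#     class _ShardInReplicateOut(ParallelStyle):
#         def __init__(self) -> None:
#             super().__init__(make_input_shard_1d, make_output_replicate_1d)
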
class PairwiseParallel(ParallelStyle):
    """
    PairwiseParallel concatenates colwise and rowwise styles as a fixed
    pair, like what Megatron-LM (https://arxiv.org/abs/1909.08053) does.
    We assume both input and output need to be replicated DTensors.

    .. warning::
        PairwiseParallel only supports ``nn.MultiheadAttention``,
        ``nn.Transformer`` or even-number-layer MLP for now.
    """

    def __init__(self) -> None:
        super().__init__(make_input_replicate_1d, make_output_tensor)

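# Example (illustrative sketch): applying PairwiseParallel to a module via
# the companion ``parallelize_module`` API from
# ``torch.distributed.tensor.parallel``, assuming an initialized process
# group of ``world_size`` ranks.
#
#     from torch.distributed.tensor.parallel import parallelize_module
#
#     mesh = DeviceMesh("cuda", list(range(world_size)))
#     model = parallelize_module(model, mesh, PairwiseParallel())
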
class RowwiseParallel(ParallelStyle):
    """
    Partition the rows of a module.
    We assume the input to be a sharded :class:`DTensor` and output to be a
    replicated :class:`DTensor`.
    """

    def __init__(self) -> None:
        super().__init__(make_input_shard_1d_dim_last, make_output_replicate_1d)

class ColwiseParallel(ParallelStyle):
    """
    Partition the columns of a tensor or module.
    We assume the input to be a replicated :class:`DTensor` and output to be a
    sharded :class:`DTensor`.
    """

    def __init__(self) -> None:
        super().__init__(make_input_replicate_1d, make_output_replicate_1d)

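# Example (illustrative sketch): colwise and rowwise styles are typically
# paired on consecutive linear layers, so the intermediate activation stays
# sharded and only the final output needs a collective. The module paths
# below ("net1", "net2") are hypothetical.
#
#     plan = {"net1": ColwiseParallel(), "net2": RowwiseParallel()}
#     model = parallelize_module(model, mesh, plan)
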
@_prepare_input_validate  # type: ignore[arg-type] # pyre-ignore[56]
def make_input_shard_1d(
    input: Union[torch.Tensor, DTensor],
    device_mesh: Optional[DeviceMesh] = None,
    dim: int = 0,
) -> DTensor:
    """
    Shard input tensor on ``dim`` over a 1-D device mesh. This function will
    be used in ParallelStyle.

    Args:
        input (Union[:class:`torch.Tensor`, :class:`DTensor`]):
            Single tensor will be sharded on dimension ``dim``
            over the 1-D :class:`DeviceMesh`.
        device_mesh (:class:`DeviceMesh`, optional):
            The 1-D device mesh where ``input`` will be sharded.
            If no :class:`DeviceMesh` is passed and ``input`` is a :class:`DTensor`,
            ``input.device_mesh`` will be used.
            If :class:`DeviceMesh` is not 1-D, an exception will be thrown.
            Default: ``None``
        dim (int, optional): The sharding dimension of ``input`` tensor.
            Default: 0

    Returns:
        A :class:`DTensor` sharded on dimension ``dim`` over ``device_mesh``.
    """
    shard_spec = [Shard(dim)]
    if isinstance(input, DTensor):
        return input.redistribute(device_mesh, shard_spec)
    elif isinstance(input, torch.Tensor):
        return DTensor.from_local(input, device_mesh, shard_spec, run_check=False)
    else:
        raise RuntimeError(
            "Tensor parallel module expects torch.Tensor or DTensor input but"
            f" received {type(input)}!"
        )


def make_input_shard_1d_dim_last(
    input: Union[torch.Tensor, DTensor],
    device_mesh: Optional[DeviceMesh] = None,
) -> DTensor:
    """
    Wrapper of ``make_input_shard_1d`` with ``dim = -1``.

    Args:
        input (Union[:class:`torch.Tensor`, :class:`DTensor`]):
            This single tensor will be sharded on its last dimension
            over the 1-D :class:`DeviceMesh`.
        device_mesh (:class:`DeviceMesh`, optional):
            The 1-D device mesh where ``input`` will be sharded.
            If no :class:`DeviceMesh` is passed and ``input`` is a :class:`DTensor`,
            ``input.device_mesh`` will be used.
            If :class:`DeviceMesh` is not 1-D, an exception will be thrown.
            Default: ``None``

    Returns:
        A :class:`DTensor` sharded on its last dimension over ``device_mesh``.
    """
    return make_input_shard_1d(input, device_mesh, dim=-1)  # type: ignore[call-arg]

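# Example (illustrative sketch; assumes an initialized process group and a
# 1-D mesh such as ``mesh = DeviceMesh("cuda", list(range(world_size)))``):
#
#     t = torch.randn(8, 16)
#     row_sharded = make_input_shard_1d(t, mesh, dim=0)    # placed as Shard(0)
#     col_sharded = make_input_shard_1d_dim_last(t, mesh)  # placed as Shard(-1)
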
@_prepare_input_validate  # type: ignore[arg-type] # pyre-ignore[56]
def make_input_replicate_1d(
    input: Union[torch.Tensor, DTensor],
    device_mesh: Optional[DeviceMesh] = None,
) -> DTensor:
    """
    Replicate input tensor over a 1-D device mesh. This function will be used
    in ParallelStyle.

    Args:
        input (Union[:class:`torch.Tensor`, :class:`DTensor`]):
            This input tensor will be replicated over the 1-D :class:`DeviceMesh`.
        device_mesh (:class:`DeviceMesh`, optional):
            The 1-D device mesh where ``input`` will be replicated.
            If no :class:`DeviceMesh` is passed and ``input`` is a :class:`DTensor`,
            ``input.device_mesh`` will be used.
            If :class:`DeviceMesh` is not 1-D, an exception will be thrown.
            Default: ``None``

    Returns:
        A :class:`DTensor` replicated over ``device_mesh``.
    """
    replicate = [Replicate()]
    if isinstance(input, DTensor):
        return input.redistribute(device_mesh, replicate)
    elif isinstance(input, torch.Tensor):
        return DTensor.from_local(input, device_mesh, replicate, run_check=False)
    else:
        raise RuntimeError(
            "Tensor parallel module expects torch.Tensor or DTensor input but"
            f" received {type(input)}!"
        )


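# Example (illustrative sketch; ``mesh`` as above). Note that
# ``DTensor.from_local`` with ``run_check=False`` trusts every rank to pass
# the same local tensor, so the caller is responsible for that invariant.
#
#     t = torch.randn(8, 16)
#     replicated = make_input_replicate_1d(t, mesh)  # placed as Replicate()
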
@_prepare_output_validate  # type: ignore[arg-type] # pyre-ignore[56]
def make_output_shard_1d(
    output: DTensor, device_mesh: Optional[DeviceMesh] = None, dim: int = 0
) -> DTensor:
    """
    Convert the output DTensor to a sharded DTensor. This will be used in
    ParallelStyle.

    Args:
        output (:class:`DTensor`):
            Output of module to be converted.
        device_mesh (:class:`DeviceMesh`, optional):
            The 1-D device mesh needed to shard the output. An exception will
            be thrown if a non-1D ``device_mesh`` is passed in. If no
            ``device_mesh`` is passed in, the one from ``output`` will be reused.
            Default: ``None``
        dim (int): Sharding dim for output. Default: 0

    Returns:
        A :class:`DTensor` object sharded on the given dim.
    """
    return output.redistribute(device_mesh, [Shard(dim)])

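# Example (illustrative sketch; ``replicated`` is a replicated DTensor over
# a 1-D mesh, e.g. the result of ``make_input_replicate_1d`` above):
#
#     out_sharded = make_output_shard_1d(replicated, dim=0)  # now Shard(0)
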
@_prepare_output_validate  # type: ignore[arg-type] # pyre-ignore[56]
def make_output_replicate_1d(
    output: DTensor, device_mesh: Optional[DeviceMesh] = None
) -> DTensor:
    """
    Convert the output DTensor to a replicated DTensor. This will be used in
    ParallelStyle.

    Args:
        output (:class:`DTensor`):
            Output of module to be converted.
        device_mesh (:class:`DeviceMesh`, optional):
            The 1-D device mesh needed to replicate the output. An exception
            will be thrown if a non-1D ``device_mesh`` is passed in. If no
            ``device_mesh`` is passed in, the one from ``output`` will be reused.
            Default: ``None``

    Returns:
        A :class:`DTensor` object made replicate.
    """
    return output.redistribute(device_mesh, [Replicate()])

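# Example (illustrative sketch; ``out_sharded`` is a sharded or partial
# DTensor, e.g. the raw output of a row/colwise-parallel linear layer). The
# redistribute triggers the necessary collective under the hood.
#
#     out_replicated = make_output_replicate_1d(out_sharded)
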
@_prepare_output_validate  # type: ignore[arg-type] # pyre-ignore[56]
def make_output_tensor(
    output: DTensor, device_mesh: Optional[DeviceMesh] = None
) -> torch.Tensor:
    """
    Convert the output DTensor to a replicated DTensor first and then convert
    it to a torch.Tensor.

    Args:
        output (:class:`DTensor`):
            Output of module to be converted.
        device_mesh (:class:`DeviceMesh`, optional):
            The 1-D device mesh needed to replicate the output. An exception
            will be thrown if a non-1D ``device_mesh`` is passed in. If no
            ``device_mesh`` is passed in, the one from ``output`` will be reused.
            Default: ``None``

    Returns:
        A :class:`torch.Tensor` object converted from output DTensor.
    """
    return make_output_replicate_1d(  # type: ignore[attr-defined]
        output, device_mesh
    ).to_local()  # type: ignore[call-arg]
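# Example (illustrative sketch): PairwiseParallel uses this as its output
# preparation, so callers receive a plain torch.Tensor holding the full
# replicated result on every rank.
#
#     local_out = make_output_tensor(out_sharded)  # torch.Tensor, same on all ranks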