[TP] Enable embedding sharding in TP API (#111177)
We see use cases where embedding sharding is also needed in the TP API, so we enable it here; DTensor already supports colwise embedding sharding.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111177
Approved by: https://github.com/wanchaol
ghstack dependencies: #111160, #111166, #111176
This commit is contained in:
parent e942fddb83
commit 25a2845d78
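In short, an nn.Embedding can now be passed to ``parallelize_module`` with a ``ColwiseParallel`` plan, which shards the embedding weight along the embedding dimension. A rough usage sketch of what the new API surface allows (``world_size``, the embedding sizes, and ``token_ids`` below are illustrative placeholders, not names from the PR, and the snippet assumes ``torch.distributed`` has already been initialized across the ranks):

# Sketch only: column-wise sharding of an nn.Embedding through the TP API.
import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.placement_types import Replicate
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

world_size = 4                            # illustrative; one rank per GPU
num_embeddings, embedding_dim = 16, 8     # illustrative sizes

device_mesh = DeviceMesh("cuda", list(range(world_size)))   # 1-D mesh over all ranks
embedding = torch.nn.Embedding(num_embeddings, embedding_dim).cuda()

# ColwiseParallel shards the (num_embeddings, embedding_dim) weight on dim 1;
# output_layouts=Replicate() gives every rank the full looked-up rows back.
embedding = parallelize_module(
    module=embedding,
    device_mesh=device_mesh,
    parallelize_plan=ColwiseParallel(output_layouts=Replicate()),
)

token_ids = torch.arange(10, device="cuda").reshape(2, 5)   # identical on every rank
output = embedding(token_ids)                               # (2, 5, embedding_dim)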
@@ -3,8 +3,13 @@
 import sys
 
 import torch
-from torch.distributed._tensor import distribute_tensor, DTensor
-from torch.distributed._tensor.placement_types import Replicate, Shard
+from torch.distributed._tensor import DTensor
+from torch.distributed._tensor.placement_types import Replicate
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    RowwiseParallel,
+)
 from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
@@ -46,7 +51,14 @@ class TestEmbeddingOp(DTensorTestBase):
 
         # Shard the parameter of local embedding and set it to sharded embedding.
         sharded_embedding.weight = torch.nn.Parameter(
-            distribute_tensor(local_embedding.weight, device_mesh, [Shard(shard_dim)])
+            local_embedding.weight.clone().detach()
         )
+        parallelize_module(
+            module=sharded_embedding,
+            device_mesh=device_mesh,
+            parallelize_plan=ColwiseParallel(output_layouts=Replicate())
+            if shard_dim == 1
+            else RowwiseParallel(),
+        )
 
         # Run sharded computation
@@ -57,12 +69,7 @@ class TestEmbeddingOp(DTensorTestBase):
         target = torch.empty(
             *inp.size(), embedding_dim, dtype=torch.float, device=self.device_type
         ).random_(0, 1)
-        placements = [Replicate()]
-        replicate_inp = DTensor.from_local(inp, device_mesh, placements)
-        sharded_output = sharded_embedding(replicate_inp)
-        output = sharded_output.redistribute(
-            sharded_output.device_mesh, [Replicate()]
-        ).to_local()
+        output = sharded_embedding(inp)
 
         # Run local computation
         local_output = local_embedding(inp)
@@ -84,7 +91,7 @@ class TestEmbeddingOp(DTensorTestBase):
         attn_dup_loss.backward()
 
         gradient = sharded_embedding.weight.grad.redistribute(
-            sharded_output.device_mesh, [Replicate()]
+            device_mesh, [Replicate()]
         ).to_local()
 
         local_grad = local_embedding.weight.grad
@@ -99,15 +106,13 @@ class TestEmbeddingOp(DTensorTestBase):
             **kwargs,
         )
         sharded_output = torch.nn.functional.embedding(
-            replicate_inp,
+            DTensor.from_local(inp, device_mesh, [Replicate()]),
            sharded_embedding.weight,
             **kwargs,
         )
         self.assertEqual(
             local_output,
-            sharded_output.redistribute(
-                sharded_output.device_mesh, [Replicate()]
-            ).to_local(),
+            sharded_output.redistribute(device_mesh, [Replicate()]).to_local(),
         )
 
     @with_comms
@@ -134,7 +139,7 @@ class TestEmbeddingOp(DTensorTestBase):
     def test_sharded_embedding_rowwise(self):
         with self.assertRaisesRegex(
             NotImplementedError,
-            "DTensor does not support row-wise sharded embedding operation yet!",
+            "Only support ColwiseParallel when parallelizing Embedding now.",
         ):
             self._run_embedding_op_test(0, [5, 12], 16, 22)
 
@@ -5,7 +5,7 @@ import torch
 from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
 from torch.distributed.tensor.parallel._utils import _create_1d_device_mesh
 from torch.distributed.tensor.parallel.api import (
-    _parallelize_linear,
+    _parallelize_linear_like_module,
     _parallelize_mlp,
     parallelize_module,
 )
@@ -264,7 +264,7 @@ class TensorParallelAPITests(DTensorTestBase):
 
         # parallelize model_tp
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-        model_tp = _parallelize_linear(model_tp, device_mesh, rowwise)
+        model_tp = _parallelize_linear_like_module(model_tp, device_mesh, rowwise)
 
         # let each rank generate unique local input
         torch.manual_seed(self.rank)
@@ -283,7 +283,7 @@ class TensorParallelAPITests(DTensorTestBase):
 
         # parallelize model_tp
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-        model_tp = _parallelize_linear(model_tp, device_mesh, colwise)
+        model_tp = _parallelize_linear_like_module(model_tp, device_mesh, colwise)
 
         self._compare_module(model, model_tp, inp_size)
 
@@ -75,8 +75,10 @@ def parallelize_module(  # type: ignore[return]
         >>>
 
     .. warning::
-        ``PairwiseParallel`` comes with constraints for now. If you need finer
-        granularity, you need to pass in a dict of module FQN and parallel style instead.
+        Currently, there are some constraints which makes it hard for complicated modules
+        like ``MultiheadAttention`` to work out of box for Tensor or Sequence Parallelism.
+        We recommend users to try ``ColwiseParallel`` and ``RowwiseParallel`` for each parameter
+        or submodule and there might be some code changes needed now.
     """
 
     torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")
@@ -104,7 +106,9 @@ def parallelize_module(  # type: ignore[return]
     if isinstance(parallelize_plan, ParallelStyle):
         # RowwiseParallel or ColwiseParallel
         if isinstance(parallelize_plan, (ColwiseParallel, RowwiseParallel)):
-            return _parallelize_linear(module, device_mesh, parallelize_plan)
+            return _parallelize_linear_like_module(
+                module, device_mesh, parallelize_plan
+            )
         elif isinstance(parallelize_plan, PrepareModuleInput):
             module.register_forward_pre_hook(lambda _, inputs: parallelize_plan._prepare_input(inputs, device_mesh))  # type: ignore[misc, call-arg]
             return module
@@ -194,20 +198,29 @@ def _rowwise_parallelize_linear_fn(
         module.register_parameter(name, dist_param)
 
 
-def _colwise_parallelize_linear_fn(
+def _colwise_parallelize_embedding_fn(
     name: str,
     module: nn.Module,
     device_mesh: DeviceMesh,
 ) -> None:
+    _colwise_parallelize_linear_fn(name, module, device_mesh, sharding_dim=1)
+
+
+def _colwise_parallelize_linear_fn(
+    name: str,
+    module: nn.Module,
+    device_mesh: DeviceMesh,
+    sharding_dim: int = 0,
+) -> None:
     """
-    This function parallelizes the input :class:`nn.Linear` module in
-    :class:`ColwiseParallel` style.
+    This function parallelizes the input :class:`nn.Linear` or :class:`nn.Embedding`
+    module in :class:`ColwiseParallel` style.
 
     Args:
         name (str):
             Name of the input module.
         module (:class:`nn.Module`):
-            The :class:`nn.Linear` module to be parallelized.
+            The :class:`nn.Linear` or :class:`nn.Embedding` module to be parallelized.
         device_mesh (:class:`DeviceMesh`):
             Object which describes the mesh topology of devices.
 
@@ -217,12 +230,12 @@ def _colwise_parallelize_linear_fn(
 
     for name, param in module.named_parameters():
         dist_param = torch.nn.Parameter(
-            distribute_tensor(param, device_mesh, [Shard(0)])
+            distribute_tensor(param, device_mesh, [Shard(sharding_dim)])
         )
         module.register_parameter(name, dist_param)
 
 
-def _parallelize_linear(
+def _parallelize_linear_like_module(
     module: nn.Module,
     device_mesh: DeviceMesh,
     parallel_style: ParallelStyle = ColwiseParallel(),
@@ -258,9 +271,9 @@ def _parallelize_linear(
         A :class:`nn.Module` object parallelized.
     """
 
-    if not isinstance(module, nn.Linear):
+    if not isinstance(module, (nn.Linear, nn.Embedding)):
         raise RuntimeError(
-            f"Expect a torch.nn.Linear module but received {type(module)}!"
+            f"Expect a torch.nn.Linear or a torch.nn.Embedding module but received {type(module)}!"
         )
 
     if not isinstance(parallel_style, ParallelStyle):
@@ -268,6 +281,13 @@ def _parallelize_linear(
             "Expect a ParallelStyle object but received" f" {type(parallel_style)}!"
         )
 
+    if isinstance(module, nn.Embedding) and not isinstance(
+        parallel_style, (ColwiseParallel)
+    ):
+        raise NotImplementedError(
+            "Only support ColwiseParallel when parallelizing Embedding now."
+        )
+
     if device_mesh.ndim > 1:
         device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)
 
@@ -283,7 +303,9 @@ def _parallelize_linear(
         distribute_module(
             module,
             device_mesh,
-            _colwise_parallelize_linear_fn,
+            _colwise_parallelize_embedding_fn
+            if isinstance(module, nn.Embedding)
+            else _colwise_parallelize_linear_fn,
             input_fn=parallel_style._prepare_input,  # type: ignore[arg-type, misc] # pyre-ignore[6]
             output_fn=parallel_style._prepare_output,  # type: ignore[arg-type, misc] # pyre-ignore[6]
         )
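Taken together, these hunks route nn.Embedding through the same path as nn.Linear: ``_parallelize_linear_like_module`` now accepts both module types, rejects anything but ``ColwiseParallel`` for embeddings (raising the ``NotImplementedError`` that the updated test expects), and dispatches embeddings to ``_colwise_parallelize_embedding_fn``, which is simply the colwise linear helper called with ``sharding_dim=1``. As a rough mental model (a sketch, not the exact library code), the weight handling reduces to:

import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, distribute_tensor
from torch.distributed._tensor.placement_types import Shard

def shard_embedding_colwise(embedding: nn.Embedding, device_mesh: DeviceMesh) -> None:
    # Shard each parameter (the (num_embeddings, embedding_dim) weight) on dim 1,
    # mirroring _colwise_parallelize_linear_fn(..., sharding_dim=1) above.
    for name, param in embedding.named_parameters():
        dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(1)]))
        embedding.register_parameter(name, dist_param)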
@@ -39,6 +39,10 @@ class ParallelStyle(ABC):
     """
     The parallel style user wants the module or submodule to be parallelized.
     Users can extend this class to build their own parallel style with customized input/output preparations.
+
+    .. warning::
+        ``_prepare_input`` and ``_prepare_output`` are only for internal usage and we will
+        remove them from ctor soon. Please use ``input_layouts`` and ``output_layouts`` instead.
     """
 
     _prepare_input: _PrepareInputType
@@ -71,11 +75,9 @@ class PairwiseParallel(ParallelStyle):
     We assume both input and output need to be replicate DTensors.
 
     .. warning::
-        PairwiseParallel does not support ``nn.MultiheadAttention``,
-        ``nn.Transformer`` well at this moment. One workaround is to apply
-        ``ColwiseParallel`` and ``RowwiseParallel`` to the components of
-        transformer. We recommend to use ``PairwiseParallel`` only
-        for even-number-layer MLP for now.
+        PairwiseParallel can be decomposed into ColwiseParallel and RowwiseParallel.
+        We recommend users to directly use latter instead and we are deprecating this
+        style and will remove it soon.
     """
 
     def __init__(
@@ -113,11 +115,9 @@ class SequenceParallel(PairwiseParallel):
     We assume both input and output need to be sharded DTensors.
 
     .. warning::
-        SequenceParallel does not support ``nn.MultiheadAttention``,
-        ``nn.Transformer`` well at this moment. One workaround is to apply
-        ``ColwiseParallel`` and ``RowwiseParallel`` to the components of
-        transformer. We recommend to use ``SequenceParallel`` only
-        for even-number-layer MLP for now.
+        SequenceParallel can be decomposed into ColwiseParallel and RowwiseParallel.
+        We recommend users to directly use latter instead and we are deprecating this
+        style and will remove it soon.
     """
 
     def __init__(
@@ -148,23 +148,8 @@ def make_input_shard_1d(
     dim: int = 0,
 ) -> DTensor:
     """
     Shard input tensor on ``dim`` over an 1-D device mesh. This function will be used in ParallelStyle.
 
-    Args:
-        input (Union[:class:`torch.Tensor`, :class:`DTensor`]):
-            Single tensor will be sharded on dimension ``dim``
-            over the 1-D :class:`DeviceMesh`.
-        device_mesh (:class:`DeviceMesh`, optional):
-            The 1-D device mesh where ``input`` will be sharded.
-            If no :class:`DeviceMesh` is passed and ``input`` is a :class:`DTensor`,
-            `input.device_mesh` will be used.
-            If :class:`DeviceMesh` is not 1-D, an exception will be thrown.
-            Default: ``None``
-        dim (int, optional): The sharding dimension of ``input`` tensor.
-            Default: 0
-
-    Returns:
-        A :class:`DTensor` sharded on dimension ``dim`` over ``device_mesh``.
+    .. warning::
+        This method was deprecated and please specify ``input_layouts`` instead.
     """
     _deprecate_warnings("make_input_shard_1d", "Specify input_layouts instead.")
     shard_spec = [Shard(dim)]
@@ -185,21 +170,8 @@ def make_input_shard_1d_last_dim(
     device_mesh: Optional[DeviceMesh] = None,
 ) -> DTensor:
     """
     Wrapper func of ``make_input_shard_1d`` with ``dim`` = -1.
 
-    Args:
-        input (Union[:class:`torch.Tensor`, :class:`DTensor`]):
-            This single tensor will be sharded on the last dimension
-            over the 1-D :class:`DeviceMesh`.
-        device_mesh (:class:`DeviceMesh`, optional):
-            The 1-D device mesh where ``input`` will be sharded.
-            If no :class:`DeviceMesh` is passed and ``input`` is a :class:`DTensor`,
-            `input.device_mesh` will be used.
-            If :class:`DeviceMesh` is not 1-D, an exception will be thrown.
-            Default: ``None``
-
-    Returns:
-        A :class:`DTensor` sharded on the last dimension over ``device_mesh``.
+    .. warning::
+        This method was deprecated and please specify ``input_layouts`` instead.
     """
     _deprecate_warnings(
         "make_input_shard_1d_last_dim", "Specify input_layouts instead."
@@ -213,22 +185,8 @@ def make_input_reshard_replicate(
     device_mesh: DeviceMesh,
 ) -> DTensor:
     """
     To construct a Sharded DTensor from a tensor on different ranks
     and then convert to a replicate DTensor.
 
-    Args:
-        input (:class:`torch.Tensor`):
-            The input tensor on each rank which consists of a global DTensor
-            sharded on dimension ``0`` over the 1-D :class:`DeviceMesh`
-            and then the sharded DTensor is converted to a replicate DTensor.
-        device_mesh (:class:`DeviceMesh`, optional):
-            The 1-D device mesh where ``input`` will be sharded.
-            If :class:`DeviceMesh` is not 1-D, an exception will be thrown.
-            Default: ``None``
-
-    Returns:
-        A :class:`DTensor` sharded on dimension ``0`` over ``device_mesh``
-        and then converted to replicate.
+    .. warning::
+        This method was deprecated and please specify ``input_layouts`` instead.
     """
     _deprecate_warnings(
         "make_input_reshard_replicate", "Specify input_layouts instead."
@@ -244,20 +202,8 @@ def make_input_replicate_1d(
     device_mesh: Optional[DeviceMesh] = None,
 ) -> DTensor:
     """
     Replicate input tensor over an 1-D device mesh. This function will be used in ParallelStyle.
 
-    Args:
-        input (Union[:class:`torch.Tensor`, :class:`DTensor`]):
-            This input tensor will be replicated over the 1-D :class:`DeviceMesh`.
-        device_mesh (:class:`DeviceMesh`, optional):
-            The 1-D device mesh where ``input`` will be replicated.
-            If no :class:`DeviceMesh` is passed and ``input`` is a :class:`DTensor`,
-            ``input.device_mesh`` will be used.
-            If :class:`DeviceMesh` is not 1-D, an exception will be thrown.
-            Default: ``None``
-
-    Returns:
-        A :class:`DTensor` replicated over ``device_mesh``.
+    .. warning::
+        This method was deprecated and please specify ``input_layouts`` instead.
     """
     _deprecate_warnings("make_input_replicate_1d", "Specify input_layouts instead.")
     replicate = [Replicate()]
@@ -277,20 +223,8 @@ def make_output_shard_1d(
     output: DTensor, device_mesh: Optional[DeviceMesh] = None, dim: int = 0
 ) -> DTensor:
     """
     Convert Output DTensor to a sharded DTensor. This will be used in ParallelStyle.
 
-    Args:
-        output (:class:`DTensor`):
-            Output of module to be converted.
-        device_mesh (:class:`DeviceMesh`, optional):
-            Object needed to shard the output and it needs to be a 1D ``device_mesh``
-            and we will throw exceptions if a non-1D ``device_mesh`` is passed in.
-            If no ``device_mesh`` is passed in, we will reuse the one from output.
-            Default: ``None``
-        dim (int): Sharding dim for output. Default: 0
-
-    Return:
-        A :class:`DTensor` object sharded on the given dim.
+    .. warning::
+        This method was deprecated and please specify ``output_layouts`` instead.
     """
     _deprecate_warnings("make_output_shard_1d", "Specify output_layouts instead.")
     return output.redistribute(device_mesh, [Shard(dim)])
@@ -301,19 +235,8 @@ def make_output_replicate_1d(
     output: DTensor, device_mesh: Optional[DeviceMesh] = None
 ) -> DTensor:
     """
     Convert Output DTensor to a replicated DTensor. This will be used in ParallelStyle.
 
-    Args:
-        output (:class:`DTensor`):
-            Output of module to be converted.
-        device_mesh (:class:`DeviceMesh`, optional):
-            Object needed to replicate the output and it needs to be a 1D ``device_mesh``
-            and we will throw exceptions if a non-1D ``device_mesh`` is passed in.
-            If no ``device_mesh`` is passed in, we will reuse the one from output.
-            Default: ``None``
-
-    Return:
-        A :class:`DTensor` object made replicate.
+    .. warning::
+        This method was deprecated and please specify ``output_layouts`` instead.
     """
     _deprecate_warnings("make_output_replicate_1d", "Specify output_layouts instead.")
     return output.redistribute(device_mesh, [Replicate()])
@@ -324,20 +247,8 @@ def make_output_tensor(
     output: DTensor, device_mesh: Optional[DeviceMesh] = None
 ) -> torch.Tensor:
     """
     Convert Output DTensor to a replicated DTensor first and then convert it to Tensor.
 
-    Args:
-        output (:class:`DTensor`):
-            Output of module to be converted.
-        device_mesh (:class:`DeviceMesh`, optional):
-            Object which is needed to replicate the output and it needs to be
-            a 1D ``device_mesh`` and we will throw exceptions if a non-1D
-            ``device_mesh`` is passed in. If no ``device_mesh`` is passed in,
-            we will reuse the one from output.
-            Default: ``None``
-
-    Return:
-        A :class:`torch.Tensor` object converted from output DTensor.
+    .. warning::
+        This method was deprecated and please specify ``output_layouts`` instead.
     """
     _deprecate_warnings("make_output_tensor", "Specify output_layouts instead.")
     return make_output_replicate_1d(  # type: ignore[attr-defined, misc]
@@ -350,17 +261,8 @@ def make_sharded_output_tensor(
     output: DTensor, _device_mesh: Optional[DeviceMesh] = None
 ) -> torch.Tensor:
     """
     Convert sharded Output DTensor to torch.Tensor.
 
-    Args:
-        output (:class:`DTensor`):
-            Output of module to be converted.
-
-    Return:
-        A :class:`torch.Tensor` object converted from output DTensor.
-
-    ``_device_mesh`` is not needed and is just kept to match with
-    the signature in its callsite in ``distribute_module``.
+    .. warning::
+        This method was deprecated and please specify ``output_layouts`` instead.
     """
     _deprecate_warnings("make_sharded_output_tensor", "Specify output_layouts instead.")
     return output.to_local()  # type: ignore[call-arg]
@@ -372,19 +274,8 @@ def make_output_reshard_tensor(
     device_mesh: Optional[DeviceMesh] = None,
 ) -> torch.Tensor:
     """
     Convert Output DTensor to a sharded DTensor and return the local tensor.
 
-    Args:
-        output (:class:`DTensor`):
-            Output of module to be converted.
-        device_mesh (:class:`DeviceMesh`, optional):
-            Object needed to shard the output and it needs to be a 1D ``device_mesh``
-            and we will throw exceptions if a non-1D ``device_mesh`` is passed in.
-            If no ``device_mesh`` is passed in, we will reuse the one from output.
-            Default: ``None``
-
-    Return:
-        A :class:`torch.Tensor` object converted from output DTensor.
+    .. warning::
+        This method was deprecated and please specify ``output_layouts`` instead.
     """
     _deprecate_warnings("make_output_reshard_tensor", "Specify output_layouts instead.")
     return make_output_shard_1d(output, device_mesh).to_local()  # type: ignore[call-arg, attr-defined, misc]
@@ -552,6 +443,41 @@ class RowwiseParallel(ParallelStyle):
     """
     Partitioning the row of a module.
     We assume the input to be a sharded :class:`DTensor` and output to be a :class:`torch.Tensor`.
+
+    Args:
+        input_layouts (Union[Placement, Tuple[Placement, ...]]):
+            The layout of input tensor(s) which DTensor will be created upon.
+        output_layouts (Union[Placement, Tuple[Placement, ...]]):
+            The layout of input tensor(s) which created DTensor will be redistributed to.
+        use_local_output (bool):
+            Whether to convert the DTensor to local :class:`torch.Tensor`.
+
+    Returns:
+        None.
+
+    .. warning::
+        RowwiseParallel now only support ``nn.Linear``. Users can compose it with ColwiseParallel
+        to achieve the sharding of more complicated modules.
+
+    .. warning::
+        ``_prepare_input`` and ``_prepare_output`` are only for internal usage and we will
+        remove them from ctor soon. Please use ``input_layouts`` and ``output_layouts`` instead.
+
+    Example::
+        >>> # xdoctest: +SKIP(failing)
+        >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
+        >>> ...
+        >>> parallelize_plan = {
+        >>>     "wo": RowwiseParallel(),   # The input of Linear will be converted to Sharded DTensor
+        >>>                                # and we will return a replicate :class:`torch.Tensor` as output.
+        >>>     ...
+        >>> }
+        >>> parallelize_module(
+        >>>     module=block, # this can be a submodule or module
+        >>>     ...,
+        >>>     parallelize_plan=parallelize_plan,
+        >>> )
+        >>> ...
     """
 
     def __init__(
@@ -588,6 +514,41 @@ class ColwiseParallel(ParallelStyle):
     """
     Partitioning the column of a tensor or module.
    We assume the input to be a replicated :class:`DTensor` and output to be a sharded :class:`torch.Tensor`.
+
+    Args:
+        input_layouts (Union[Placement, Tuple[Placement, ...]]):
+            The layout of input tensor(s) which DTensor will be created upon.
+        output_layouts (Union[Placement, Tuple[Placement, ...]]):
+            The layout of input tensor(s) which created DTensor will be redistributed to.
+        use_local_output (bool):
+            Whether to convert the DTensor to local :class:`torch.Tensor`.
+
+    Returns:
+        None.
+
+    .. warning::
+        ColwiseParallel now only support ``nn.Linear`` and ``nn.Embedding``. Users can compose
+        it with RowwiseParallel to achieve the sharding of more complicated modules.
+
+    .. warning::
+        ``_prepare_input`` and ``_prepare_output`` are only for internal usage and we will
+        remove them from ctor soon. Please use ``input_layouts`` and ``output_layouts`` instead.
+
+    Example::
+        >>> # xdoctest: +SKIP(failing)
+        >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
+        >>> ...
+        >>> parallelize_plan = {
+        >>>     "w1": ColwiseParallel(),   # The input of Linear will be converted to Replicated DTensor
+        >>>                                # and we will return a sharded :class:`torch.Tensor` as output.
+        >>>     ...
+        >>> }
+        >>> parallelize_module(
+        >>>     module=block, # this can be a submodule or module
+        >>>     ...,
+        >>>     parallelize_plan=parallelize_plan,
+        >>> )
+        >>> ...
     """
 
     def __init__(
@@ -599,6 +560,9 @@ class ColwiseParallel(ParallelStyle):
         output_layouts=Shard(-1),
         use_local_output=True,
     ) -> None:
+        """
+
+        """
         if isinstance(input_layouts, tuple) or isinstance(output_layouts, tuple):
             raise NotImplementedError(
                 "ColwiseParallel only supports single input/output."
@@ -647,7 +611,7 @@ class PrepareModuleInput(ParallelStyle):
         output_layouts (Union[Placement, Tuple[Placement, ...]]):
             The layout of input tensor(s) which created DTensor will be redistributed to.
         use_local_output (bool):
-            Whether to convert the DTensor to local tensor.
+            Whether to convert the DTensor to local :class:`torch.Tensor`.
 
     Returns:
         None.
@@ -686,7 +650,7 @@ class PrepareModuleOutput(ParallelStyle):
     with ``output_layouts`` and ``use_local_output`` so that each output can be converted to
     :class:`DTensor` or :class:`torch.Tensor` based on the annotation. Specifically, a DTensor
     will be redistributed to another DTensor based on ``output_layouts`` and the flag ``use_local_output``
-    to decide whether to convert the DTensor to local tensor.
+    to decide whether to convert the DTensor to local :class:`torch.Tensor`.
 
     When the output is not a :class:`DTensor`, if no layout is specified, it will be
     a no-op. Otherwise, it will throw an error.