[TP] Enable embedding sharding in TP API (#111177)

We have seen use cases where embedding sharding is also needed in the TP API, so we enable it here, since DTensor already supports colwise embedding sharding.
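For illustration, a minimal sketch of the usage this PR enables, following the pattern in the updated test below. The device mesh construction, module sizes, and variable names are illustrative assumptions, not part of this PR.

# Hedged sketch: colwise-shard an nn.Embedding through the TP API.
# Assumes torch.distributed is already initialized, one GPU per rank.
import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.placement_types import Replicate
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

device_mesh = DeviceMesh("cuda", list(range(torch.distributed.get_world_size())))
embedding = torch.nn.Embedding(16, 10, device="cuda")

# ColwiseParallel shards the embedding weight on dim 1 (the embedding dim);
# output_layouts=Replicate() gathers each rank's slice back into a full output.
sharded_embedding = parallelize_module(
    module=embedding,
    device_mesh=device_mesh,
    parallelize_plan=ColwiseParallel(output_layouts=Replicate()),
)

inp = torch.randint(0, 16, (5, 4), device="cuda")  # assumed identical on all ranks
out = sharded_embedding(inp)  # a regular torch.Tensor, matching the unsharded embedding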

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111177
Approved by: https://github.com/wanchaol
ghstack dependencies: #111160, #111166, #111176
Authored by fduwjj on 2023-10-14 23:21:29 -07:00, committed by PyTorch MergeBot
parent e942fddb83
commit 25a2845d78
4 changed files with 160 additions and 169 deletions

View File

@@ -3,8 +3,13 @@
import sys
import torch
from torch.distributed._tensor import distribute_tensor, DTensor
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.distributed._tensor import DTensor
from torch.distributed._tensor.placement_types import Replicate
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
RowwiseParallel,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
@@ -46,7 +51,14 @@ class TestEmbeddingOp(DTensorTestBase):
# Shard the parameter of local embedding and set it to sharded embedding.
sharded_embedding.weight = torch.nn.Parameter(
distribute_tensor(local_embedding.weight, device_mesh, [Shard(shard_dim)])
local_embedding.weight.clone().detach()
)
parallelize_module(
module=sharded_embedding,
device_mesh=device_mesh,
parallelize_plan=ColwiseParallel(output_layouts=Replicate())
if shard_dim == 1
else RowwiseParallel(),
)
# Run sharded computation
@@ -57,12 +69,7 @@ class TestEmbeddingOp(DTensorTestBase):
target = torch.empty(
*inp.size(), embedding_dim, dtype=torch.float, device=self.device_type
).random_(0, 1)
placements = [Replicate()]
replicate_inp = DTensor.from_local(inp, device_mesh, placements)
sharded_output = sharded_embedding(replicate_inp)
output = sharded_output.redistribute(
sharded_output.device_mesh, [Replicate()]
).to_local()
output = sharded_embedding(inp)
# Run local computation
local_output = local_embedding(inp)
@@ -84,7 +91,7 @@ class TestEmbeddingOp(DTensorTestBase):
attn_dup_loss.backward()
gradient = sharded_embedding.weight.grad.redistribute(
sharded_output.device_mesh, [Replicate()]
device_mesh, [Replicate()]
).to_local()
local_grad = local_embedding.weight.grad
@@ -99,15 +106,13 @@ class TestEmbeddingOp(DTensorTestBase):
**kwargs,
)
sharded_output = torch.nn.functional.embedding(
replicate_inp,
DTensor.from_local(inp, device_mesh, [Replicate()]),
sharded_embedding.weight,
**kwargs,
)
self.assertEqual(
local_output,
sharded_output.redistribute(
sharded_output.device_mesh, [Replicate()]
).to_local(),
sharded_output.redistribute(device_mesh, [Replicate()]).to_local(),
)
@with_comms
@@ -134,7 +139,7 @@ class TestEmbeddingOp(DTensorTestBase):
def test_sharded_embedding_rowwise(self):
with self.assertRaisesRegex(
NotImplementedError,
"DTensor does not support row-wise sharded embedding operation yet!",
"Only support ColwiseParallel when parallelizing Embedding now.",
):
self._run_embedding_op_test(0, [5, 12], 16, 22)
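For comparison, a rough DTensor-level sketch of what ColwiseParallel now does for an embedding, mirroring the manual sharding the test above used to perform by hand; the mesh construction, sizes, and names are illustrative assumptions.

# Hedged sketch: manual colwise embedding sharding with raw DTensor APIs.
import torch
from torch.distributed._tensor import DeviceMesh, DTensor, distribute_tensor
from torch.distributed._tensor.placement_types import Replicate, Shard

device_mesh = DeviceMesh("cuda", list(range(torch.distributed.get_world_size())))
local_embedding = torch.nn.Embedding(16, 10, device="cuda")
sharded_embedding = torch.nn.Embedding(16, 10, device="cuda")

# Colwise embedding sharding means the weight is sharded on dim 1 (the embedding dim).
sharded_embedding.weight = torch.nn.Parameter(
    distribute_tensor(local_embedding.weight, device_mesh, [Shard(1)])
)

inp = torch.randint(0, 16, (5, 4), device="cuda")  # assumed identical on all ranks
replicate_inp = DTensor.from_local(inp, device_mesh, [Replicate()])

# Each rank produces a slice of the embedding dim; redistribute to Replicate for the full result.
sharded_output = sharded_embedding(replicate_inp)
output = sharded_output.redistribute(device_mesh, [Replicate()]).to_local()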

View File

@@ -5,7 +5,7 @@ import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.distributed.tensor.parallel._utils import _create_1d_device_mesh
from torch.distributed.tensor.parallel.api import (
_parallelize_linear,
_parallelize_linear_like_module,
_parallelize_mlp,
parallelize_module,
)
@@ -264,7 +264,7 @@ class TensorParallelAPITests(DTensorTestBase):
# parallelize model_tp
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
model_tp = _parallelize_linear(model_tp, device_mesh, rowwise)
model_tp = _parallelize_linear_like_module(model_tp, device_mesh, rowwise)
# let each rank generate unique local input
torch.manual_seed(self.rank)
@@ -283,7 +283,7 @@ class TensorParallelAPITests(DTensorTestBase):
# parallelize model_tp
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
model_tp = _parallelize_linear(model_tp, device_mesh, colwise)
model_tp = _parallelize_linear_like_module(model_tp, device_mesh, colwise)
self._compare_module(model, model_tp, inp_size)

View File

@@ -75,8 +75,10 @@ def parallelize_module( # type: ignore[return]
>>>
.. warning::
``PairwiseParallel`` comes with constraints for now. If you need finer
granularity, you need to pass in a dict of module FQN and parallel style instead.
Currently, there are some constraints which makes it hard for complicated modules
like ``MultiheadAttention`` to work out of box for Tensor or Sequence Parallelism.
We recommend users to try ``ColwiseParallel`` and ``RowwiseParallel`` for each parameter
or submodule and there might be some code changes needed now.
"""
torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")
@@ -104,7 +106,9 @@ def parallelize_module( # type: ignore[return]
if isinstance(parallelize_plan, ParallelStyle):
# RowwiseParallel or ColwiseParallel
if isinstance(parallelize_plan, (ColwiseParallel, RowwiseParallel)):
return _parallelize_linear(module, device_mesh, parallelize_plan)
return _parallelize_linear_like_module(
module, device_mesh, parallelize_plan
)
elif isinstance(parallelize_plan, PrepareModuleInput):
module.register_forward_pre_hook(lambda _, inputs: parallelize_plan._prepare_input(inputs, device_mesh)) # type: ignore[misc, call-arg]
return module
@@ -194,20 +198,29 @@ def _rowwise_parallelize_linear_fn(
module.register_parameter(name, dist_param)
def _colwise_parallelize_linear_fn(
def _colwise_parallelize_embedding_fn(
name: str,
module: nn.Module,
device_mesh: DeviceMesh,
) -> None:
_colwise_parallelize_linear_fn(name, module, device_mesh, sharding_dim=1)
def _colwise_parallelize_linear_fn(
name: str,
module: nn.Module,
device_mesh: DeviceMesh,
sharding_dim: int = 0,
) -> None:
"""
This function parallelizes the input :class:`nn.Linear` module in
:class:`ColwiseParallel` style.
This function parallelizes the input :class:`nn.Linear` or :class:`nn.Embedding`
module in :class:`ColwiseParallel` style.
Args:
name (str):
Name of the input module.
module (:class:`nn.Module`):
The :class:`nn.Linear` module to be parallelized.
The :class:`nn.Linear` or :class:`nn.Embedding` module to be parallelized.
device_mesh (:class:`DeviceMesh`):
Object which describes the mesh topology of devices.
@@ -217,12 +230,12 @@ def _colwise_parallelize_linear_fn(
for name, param in module.named_parameters():
dist_param = torch.nn.Parameter(
distribute_tensor(param, device_mesh, [Shard(0)])
distribute_tensor(param, device_mesh, [Shard(sharding_dim)])
)
module.register_parameter(name, dist_param)
def _parallelize_linear(
def _parallelize_linear_like_module(
module: nn.Module,
device_mesh: DeviceMesh,
parallel_style: ParallelStyle = ColwiseParallel(),
@@ -258,9 +271,9 @@ def _parallelize_linear(
A :class:`nn.Module` object parallelized.
"""
if not isinstance(module, nn.Linear):
if not isinstance(module, (nn.Linear, nn.Embedding)):
raise RuntimeError(
f"Expect a torch.nn.Linear module but received {type(module)}!"
f"Expect a torch.nn.Linear or a torch.nn.Embedding module but received {type(module)}!"
)
if not isinstance(parallel_style, ParallelStyle):
@@ -268,6 +281,13 @@ def _parallelize_linear(
"Expect a ParallelStyle object but received" f" {type(parallel_style)}!"
)
if isinstance(module, nn.Embedding) and not isinstance(
parallel_style, (ColwiseParallel)
):
raise NotImplementedError(
"Only support ColwiseParallel when parallelizing Embedding now."
)
if device_mesh.ndim > 1:
device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)
@@ -283,7 +303,9 @@ def _parallelize_linear(
distribute_module(
module,
device_mesh,
_colwise_parallelize_linear_fn,
_colwise_parallelize_embedding_fn
if isinstance(module, nn.Embedding)
else _colwise_parallelize_linear_fn,
input_fn=parallel_style._prepare_input, # type: ignore[arg-type, misc] # pyre-ignore[6]
output_fn=parallel_style._prepare_output, # type: ignore[arg-type, misc] # pyre-ignore[6]
)
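The new type check in _parallelize_linear_like_module rejects row-wise embedding sharding explicitly. A small sketch of what a caller would see, matching the error exercised by the updated test; the mesh construction and sizes are illustrative assumptions.

# Hedged sketch: RowwiseParallel on nn.Embedding is expected to raise NotImplementedError.
import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import RowwiseParallel, parallelize_module

device_mesh = DeviceMesh("cuda", list(range(torch.distributed.get_world_size())))
embedding = torch.nn.Embedding(16, 10, device="cuda")

try:
    parallelize_module(embedding, device_mesh, RowwiseParallel())
except NotImplementedError as err:
    # "Only support ColwiseParallel when parallelizing Embedding now."
    print(err)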

View File

@@ -39,6 +39,10 @@ class ParallelStyle(ABC):
"""
The parallel style user wants the module or submodule to be parallelized.
Users can extend this class to build their own parallel style with customized input/output preparations.
.. warning::
``_prepare_input`` and ``_prepare_output`` are only for internal usage and we will
remove them from ctor soon. Please use ``input_layouts`` and ``output_layouts`` instead.
"""
_prepare_input: _PrepareInputType
@@ -71,11 +75,9 @@ class PairwiseParallel(ParallelStyle):
We assume both input and output need to be replicate DTensors.
.. warning::
PairwiseParallel does not support ``nn.MultiheadAttention``,
``nn.Transformer`` well at this moment. One workaround is to apply
``ColwiseParallel`` and ``RowwiseParallel`` to the components of
transformer. We recommend to use ``PairwiseParallel`` only
for even-number-layer MLP for now.
PairwiseParallel can be decomposed into ColwiseParallel and RowwiseParallel.
We recommend users to directly use latter instead and we are deprecating this
style and will remove it soon.
"""
def __init__(
@@ -113,11 +115,9 @@ class SequenceParallel(PairwiseParallel):
We assume both input and output need to be sharded DTensors.
.. warning::
SequenceParallel does not support ``nn.MultiheadAttention``,
``nn.Transformer`` well at this moment. One workaround is to apply
``ColwiseParallel`` and ``RowwiseParallel`` to the components of
transformer. We recommend to use ``SequenceParallel`` only
for even-number-layer MLP for now.
SequenceParallel can be decomposed into ColwiseParallel and RowwiseParallel.
We recommend users to directly use latter instead and we are deprecating this
style and will remove it soon.
"""
def __init__(
@@ -148,23 +148,8 @@ def make_input_shard_1d(
dim: int = 0,
) -> DTensor:
"""
Shard input tensor on ``dim`` over an 1-D device mesh. This function will be used in ParallelStyle.
Args:
input (Union[:class:`torch.Tensor`, :class:`DTensor`]):
Single tensor will be sharded on dimension ``dim``
over the 1-D :class:`DeviceMesh`.
device_mesh (:class:`DeviceMesh`, optional):
The 1-D device mesh where ``input`` will be sharded.
If no :class:`DeviceMesh` is passed and ``input`` is a :class:`DTensor`,
`input.device_mesh` will be used.
If :class:`DeviceMesh` is not 1-D, an exception will be thrown.
Default: ``None``
dim (int, optional): The sharding dimension of ``input`` tensor.
Default: 0
Returns:
A :class:`DTensor` sharded on dimension ``dim`` over ``device_mesh``.
.. warning::
This method was deprecated and please specify ``input_layouts`` instead.
"""
_deprecate_warnings("make_input_shard_1d", "Specify input_layouts instead.")
shard_spec = [Shard(dim)]
@@ -185,21 +170,8 @@ def make_input_shard_1d_last_dim(
device_mesh: Optional[DeviceMesh] = None,
) -> DTensor:
"""
Wrapper func of ``make_input_shard_1d`` with ``dim`` = -1.
Args:
input (Union[:class:`torch.Tensor`, :class:`DTensor`]):
This single tensor will be sharded on the last dimension
over the 1-D :class:`DeviceMesh`.
device_mesh (:class:`DeviceMesh`, optional):
The 1-D device mesh where ``input`` will be sharded.
If no :class:`DeviceMesh` is passed and ``input`` is a :class:`DTensor`,
`input.device_mesh` will be used.
If :class:`DeviceMesh` is not 1-D, an exception will be thrown.
Default: ``None``
Returns:
A :class:`DTensor` sharded on the last dimension over ``device_mesh``.
.. warning::
This method was deprecated and please specify ``input_layouts`` instead.
"""
_deprecate_warnings(
"make_input_shard_1d_last_dim", "Specify input_layouts instead."
@@ -213,22 +185,8 @@ def make_input_reshard_replicate(
device_mesh: DeviceMesh,
) -> DTensor:
"""
To construct a Sharded DTensor from a tensor on different ranks
and then convert to a replicate DTensor.
Args:
input (:class:`torch.Tensor`):
The input tensor on each rank which consists of a global DTensor
sharded on dimension ``0`` over the 1-D :class:`DeviceMesh`
and then the sharded DTensor is converted to a replicate DTensor.
device_mesh (:class:`DeviceMesh`, optional):
The 1-D device mesh where ``input`` will be sharded.
If :class:`DeviceMesh` is not 1-D, an exception will be thrown.
Default: ``None``
Returns:
A :class:`DTensor` sharded on dimension ``0`` over ``device_mesh``
and then converted to replicate.
.. warning::
This method was deprecated and please specify ``input_layouts`` instead.
"""
_deprecate_warnings(
"make_input_reshard_replicate", "Specify input_layouts instead."
@@ -244,20 +202,8 @@ def make_input_replicate_1d(
device_mesh: Optional[DeviceMesh] = None,
) -> DTensor:
"""
Replicate input tensor over an 1-D device mesh. This function will be used in ParallelStyle.
Args:
input (Union[:class:`torch.Tensor`, :class:`DTensor`]):
This input tensor will be replicated over the 1-D :class:`DeviceMesh`.
device_mesh (:class:`DeviceMesh`, optional):
The 1-D device mesh where ``input`` will be replicated.
If no :class:`DeviceMesh` is passed and ``input`` is a :class:`DTensor`,
``input.device_mesh`` will be used.
If :class:`DeviceMesh` is not 1-D, an exception will be thrown.
Default: ``None``
Returns:
A :class:`DTensor` replicated over ``device_mesh``.
.. warning::
This method was deprecated and please specify ``input_layouts`` instead.
"""
_deprecate_warnings("make_input_replicate_1d", "Specify input_layouts instead.")
replicate = [Replicate()]
@@ -277,20 +223,8 @@ def make_output_shard_1d(
output: DTensor, device_mesh: Optional[DeviceMesh] = None, dim: int = 0
) -> DTensor:
"""
Convert Output DTensor to a sharded DTensor. This will be used in ParallelStyle.
Args:
output (:class:`DTensor`):
Output of module to be converted.
device_mesh (:class:`DeviceMesh`, optional):
Object needed to shard the output and it needs to be a 1D ``device_mesh``
and we will throw exceptions if a non-1D ``device_mesh`` is passed in.
If no ``device_mesh`` is passed in, we will reuse the one from output.
Default: ``None``
dim (int): Sharding dim for output. Default: 0
Return:
A :class:`DTensor` object sharded on the given dim.
.. warning::
This method was deprecated and please specify ``output_layouts`` instead.
"""
_deprecate_warnings("make_output_shard_1d", "Specify output_layouts instead.")
return output.redistribute(device_mesh, [Shard(dim)])
@@ -301,19 +235,8 @@ def make_output_replicate_1d(
output: DTensor, device_mesh: Optional[DeviceMesh] = None
) -> DTensor:
"""
Convert Output DTensor to a replicated DTensor. This will be used in ParallelStyle.
Args:
output (:class:`DTensor`):
Output of module to be converted.
device_mesh (:class:`DeviceMesh`, optional):
Object needed to replicate the output and it needs to be a 1D ``device_mesh``
and we will throw exceptions if a non-1D ``device_mesh`` is passed in.
If no ``device_mesh`` is passed in, we will reuse the one from output.
Default: ``None``
Return:
A :class:`DTensor` object made replicate.
.. warning::
This method was deprecated and please specify ``output_layouts`` instead.
"""
_deprecate_warnings("make_output_replicate_1d", "Specify output_layouts instead.")
return output.redistribute(device_mesh, [Replicate()])
@@ -324,20 +247,8 @@ def make_output_tensor(
output: DTensor, device_mesh: Optional[DeviceMesh] = None
) -> torch.Tensor:
"""
Convert Output DTensor to a replicated DTensor first and then convert it to Tensor.
Args:
output (:class:`DTensor`):
Output of module to be converted.
device_mesh (:class:`DeviceMesh`, optional):
Object which is needed to replicate the output and it needs to be
a 1D ``device_mesh`` and we will throw exceptions if a non-1D
``device_mesh`` is passed in. If no ``device_mesh`` is passed in,
we will reuse the one from output.
Default: ``None``
Return:
A :class:`torch.Tensor` object converted from output DTensor.
.. warning::
This method was deprecated and please specify ``output_layouts`` instead.
"""
_deprecate_warnings("make_output_tensor", "Specify output_layouts instead.")
return make_output_replicate_1d( # type: ignore[attr-defined, misc]
@@ -350,17 +261,8 @@ def make_sharded_output_tensor(
output: DTensor, _device_mesh: Optional[DeviceMesh] = None
) -> torch.Tensor:
"""
Convert sharded Output DTensor to torch.Tensor.
Args:
output (:class:`DTensor`):
Output of module to be converted.
Return:
A :class:`torch.Tensor` object converted from output DTensor.
``_device_mesh`` is not needed and is just kept to match with
the signature in its callsite in ``distribute_module``.
.. warning::
This method was deprecated and please specify ``output_layouts`` instead.
"""
_deprecate_warnings("make_sharded_output_tensor", "Specify output_layouts instead.")
return output.to_local() # type: ignore[call-arg]
@@ -372,19 +274,8 @@ def make_output_reshard_tensor(
device_mesh: Optional[DeviceMesh] = None,
) -> torch.Tensor:
"""
Convert Output DTensor to a sharded DTensor and return the local tensor.
Args:
output (:class:`DTensor`):
Output of module to be converted.
device_mesh (:class:`DeviceMesh`, optional):
Object needed to shard the output and it needs to be a 1D ``device_mesh``
and we will throw exceptions if a non-1D ``device_mesh`` is passed in.
If no ``device_mesh`` is passed in, we will reuse the one from output.
Default: ``None``
Return:
A :class:`torch.Tensor` object converted from output DTensor.
.. warning::
This method was deprecated and please specify ``output_layouts`` instead.
"""
_deprecate_warnings("make_output_reshard_tensor", "Specify output_layouts instead.")
return make_output_shard_1d(output, device_mesh).to_local() # type: ignore[call-arg, attr-defined, misc]
@@ -552,6 +443,41 @@ class RowwiseParallel(ParallelStyle):
"""
Partitioning the row of a module.
We assume the input to be a sharded :class:`DTensor` and output to be a :class:`torch.Tensor`.
Args:
input_layouts (Union[Placement, Tuple[Placement, ...]]):
The layout of input tensor(s) which DTensor will be created upon.
output_layouts (Union[Placement, Tuple[Placement, ...]]):
The layout of input tensor(s) which created DTensor will be redistributed to.
use_local_output (bool):
Whether to convert the DTensor to local :class:`torch.Tensor`.
Returns:
None.
.. warning::
RowwiseParallel now only support ``nn.Linear``. Users can compose it with ColwiseParallel
to achieve the sharding of more complicated modules.
.. warning::
``_prepare_input`` and ``_prepare_output`` are only for internal usage and we will
remove them from ctor soon. Please use ``input_layouts`` and ``output_layouts`` instead.
Example::
>>> # xdoctest: +SKIP(failing)
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
>>> ...
>>> parallelize_plan = {
>>> "wo": RowwiseParallel(), # The input of Linear will be converted to Sharded DTensor
>>> # and we will return a replicate :class:`torch.Tensor` as output.
>>> ...
>>> }
>>> parallelize_module(
>>> module=block, # this can be a submodule or module
>>> ...,
>>> parallelize_plan=parallelize_plan,
>>> )
>>> ...
"""
def __init__(
@@ -588,6 +514,41 @@ class ColwiseParallel(ParallelStyle):
"""
Partitioning the column of a tensor or module.
We assume the input to be a replicated :class:`DTensor` and output to be a sharded :class:`torch.Tensor`.
Args:
input_layouts (Union[Placement, Tuple[Placement, ...]]):
The layout of input tensor(s) which DTensor will be created upon.
output_layouts (Union[Placement, Tuple[Placement, ...]]):
The layout of input tensor(s) which created DTensor will be redistributed to.
use_local_output (bool):
Whether to convert the DTensor to local :class:`torch.Tensor`.
Returns:
None.
.. warning::
ColwiseParallel now only support ``nn.Linear`` and ``nn.Embedding``. Users can compose
it with RowwiseParallel to achieve the sharding of more complicated modules.
.. warning::
``_prepare_input`` and ``_prepare_output`` are only for internal usage and we will
remove them from ctor soon. Please use ``input_layouts`` and ``output_layouts`` instead.
Example::
>>> # xdoctest: +SKIP(failing)
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
>>> ...
>>> parallelize_plan = {
>>> "w1": ColwiseParallel(), # The input of Linear will be converted to Replicated DTensor
>>> # and we will return a sharded :class:`torch.Tensor` as output.
>>> ...
>>> }
>>> parallelize_module(
>>> module=block, # this can be a submodule or module
>>> ...,
>>> parallelize_plan=parallelize_plan,
>>> )
>>> ...
"""
def __init__(
@@ -599,6 +560,9 @@ class ColwiseParallel(ParallelStyle):
output_layouts=Shard(-1),
use_local_output=True,
) -> None:
"""
"""
if isinstance(input_layouts, tuple) or isinstance(output_layouts, tuple):
raise NotImplementedError(
"ColwiseParallel only supports single input/output."
@@ -647,7 +611,7 @@ class PrepareModuleInput(ParallelStyle):
output_layouts (Union[Placement, Tuple[Placement, ...]]):
The layout of input tensor(s) which created DTensor will be redistributed to.
use_local_output (bool):
Whether to convert the DTensor to local tensor.
Whether to convert the DTensor to local :class:`torch.Tensor`.
Returns:
None.
@@ -686,7 +650,7 @@ class PrepareModuleOutput(ParallelStyle):
with ``output_layouts`` and ``use_local_output`` so that each output can be converted to
:class:`DTensor` or :class:`torch.Tensor` based on the annotation. Specifically, a DTensor
will be redistributed to another DTensor based on ``output_layouts`` and the flag ``use_local_output``
to decide whether to convert the DTensor to local tensor.
to decide whether to convert the DTensor to local :class:`torch.Tensor`.
When the output is not a :class:`DTensor`, if no layout is specified, it will be
a no-op. Otherwise, it will throw an error.
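
As the updated warnings recommend, a PairwiseParallel plan can be written out explicitly as ColwiseParallel followed by RowwiseParallel. A sketch for a two-layer MLP; the MLP class, its attribute names, and the mesh construction are illustrative assumptions, not code from this PR.

# Hedged sketch: the Colwise/Rowwise decomposition suggested by the deprecation warnings above.
import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

class MLP(torch.nn.Module):
    def __init__(self, dim=16, hidden=64):
        super().__init__()
        self.net1 = torch.nn.Linear(dim, hidden)
        self.net2 = torch.nn.Linear(hidden, dim)

    def forward(self, x):
        return self.net2(torch.relu(self.net1(x)))

device_mesh = DeviceMesh("cuda", list(range(torch.distributed.get_world_size())))
model = MLP().cuda()

# net1 is column-wise sharded (its output stays sharded on the last dim);
# net2 is row-wise sharded (its matmul reduces the sharded dim and yields a replicated output).
parallelize_plan = {
    "net1": ColwiseParallel(),
    "net2": RowwiseParallel(),
}
model = parallelize_module(model, device_mesh, parallelize_plan)

out = model(torch.rand(8, 16, device="cuda"))  # input assumed identical on all ranks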