[TP] Enable embedding sharding in TP API (#111177)
We have seen use cases where embedding sharding is also needed in the TP API, so this change enables it there, since DTensor already supports colwise embedding sharding.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111177
Approved by: https://github.com/wanchaol
ghstack dependencies: #111160, #111166, #111176
parent e942fddb83
commit 25a2845d78
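As a quick illustration of what this change enables, the sketch below mirrors the updated TestEmbeddingOp test from this commit: an nn.Embedding is passed straight through the TP API with ColwiseParallel. It is only a sketch, assuming an already-initialized process group (e.g. via torchrun), one CUDA device per rank, and arbitrary embedding sizes.

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.placement_types import Replicate
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

# Assumption: the default process group is already initialized on every rank.
device_mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))

emb = torch.nn.Embedding(num_embeddings=16, embedding_dim=8).cuda()

# Column-wise shard the embedding weight and replicate the module output,
# as the updated test in this commit does.
emb = parallelize_module(
    module=emb,
    device_mesh=device_mesh,
    parallelize_plan=ColwiseParallel(output_layouts=Replicate()),
)

inp = torch.randint(0, 16, (4, 12), device="cuda")  # plain torch.Tensor input
out = emb(inp)  # replicated plain torch.Tensor output on every rank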
@@ -3,8 +3,13 @@
 import sys
 
 import torch
-from torch.distributed._tensor import distribute_tensor, DTensor
-from torch.distributed._tensor.placement_types import Replicate, Shard
+from torch.distributed._tensor import DTensor
+from torch.distributed._tensor.placement_types import Replicate
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    RowwiseParallel,
+)
 from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
@@ -46,7 +51,14 @@ class TestEmbeddingOp(DTensorTestBase):
 
         # Shard the parameter of local embedding and set it to sharded embedding.
         sharded_embedding.weight = torch.nn.Parameter(
-            distribute_tensor(local_embedding.weight, device_mesh, [Shard(shard_dim)])
+            local_embedding.weight.clone().detach()
+        )
+        parallelize_module(
+            module=sharded_embedding,
+            device_mesh=device_mesh,
+            parallelize_plan=ColwiseParallel(output_layouts=Replicate())
+            if shard_dim == 1
+            else RowwiseParallel(),
         )
 
         # Run sharded computation
@@ -57,12 +69,7 @@ class TestEmbeddingOp(DTensorTestBase):
         target = torch.empty(
             *inp.size(), embedding_dim, dtype=torch.float, device=self.device_type
         ).random_(0, 1)
-        placements = [Replicate()]
-        replicate_inp = DTensor.from_local(inp, device_mesh, placements)
-        sharded_output = sharded_embedding(replicate_inp)
-        output = sharded_output.redistribute(
-            sharded_output.device_mesh, [Replicate()]
-        ).to_local()
+        output = sharded_embedding(inp)
 
         # Run local computation
         local_output = local_embedding(inp)
@@ -84,7 +91,7 @@ class TestEmbeddingOp(DTensorTestBase):
         attn_dup_loss.backward()
 
         gradient = sharded_embedding.weight.grad.redistribute(
-            sharded_output.device_mesh, [Replicate()]
+            device_mesh, [Replicate()]
         ).to_local()
 
         local_grad = local_embedding.weight.grad
@@ -99,15 +106,13 @@ class TestEmbeddingOp(DTensorTestBase):
             **kwargs,
         )
         sharded_output = torch.nn.functional.embedding(
-            replicate_inp,
+            DTensor.from_local(inp, device_mesh, [Replicate()]),
             sharded_embedding.weight,
             **kwargs,
         )
         self.assertEqual(
             local_output,
-            sharded_output.redistribute(
-                sharded_output.device_mesh, [Replicate()]
-            ).to_local(),
+            sharded_output.redistribute(device_mesh, [Replicate()]).to_local(),
         )
 
     @with_comms
@@ -134,7 +139,7 @@ class TestEmbeddingOp(DTensorTestBase):
     def test_sharded_embedding_rowwise(self):
         with self.assertRaisesRegex(
             NotImplementedError,
-            "DTensor does not support row-wise sharded embedding operation yet!",
+            "Only support ColwiseParallel when parallelizing Embedding now.",
         ):
             self._run_embedding_op_test(0, [5, 12], 16, 22)
@@ -5,7 +5,7 @@ import torch
 from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
 from torch.distributed.tensor.parallel._utils import _create_1d_device_mesh
 from torch.distributed.tensor.parallel.api import (
-    _parallelize_linear,
+    _parallelize_linear_like_module,
     _parallelize_mlp,
     parallelize_module,
 )
@@ -264,7 +264,7 @@ class TensorParallelAPITests(DTensorTestBase):
 
         # parallelize model_tp
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-        model_tp = _parallelize_linear(model_tp, device_mesh, rowwise)
+        model_tp = _parallelize_linear_like_module(model_tp, device_mesh, rowwise)
 
         # let each rank generate unique local input
         torch.manual_seed(self.rank)
@@ -283,7 +283,7 @@ class TensorParallelAPITests(DTensorTestBase):
 
         # parallelize model_tp
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-        model_tp = _parallelize_linear(model_tp, device_mesh, colwise)
+        model_tp = _parallelize_linear_like_module(model_tp, device_mesh, colwise)
 
         self._compare_module(model, model_tp, inp_size)
@@ -75,8 +75,10 @@ def parallelize_module( # type: ignore[return]
         >>>
 
     .. warning::
-        ``PairwiseParallel`` comes with constraints for now. If you need finer
-        granularity, you need to pass in a dict of module FQN and parallel style instead.
+        Currently, there are some constraints which makes it hard for complicated modules
+        like ``MultiheadAttention`` to work out of box for Tensor or Sequence Parallelism.
+        We recommend users to try ``ColwiseParallel`` and ``RowwiseParallel`` for each parameter
+        or submodule and there might be some code changes needed now.
    """
 
     torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")
@@ -104,7 +106,9 @@ def parallelize_module( # type: ignore[return]
     if isinstance(parallelize_plan, ParallelStyle):
         # RowwiseParallel or ColwiseParallel
         if isinstance(parallelize_plan, (ColwiseParallel, RowwiseParallel)):
-            return _parallelize_linear(module, device_mesh, parallelize_plan)
+            return _parallelize_linear_like_module(
+                module, device_mesh, parallelize_plan
+            )
         elif isinstance(parallelize_plan, PrepareModuleInput):
             module.register_forward_pre_hook(lambda _, inputs: parallelize_plan._prepare_input(inputs, device_mesh)) # type: ignore[misc, call-arg]
             return module
@@ -194,20 +198,29 @@ def _rowwise_parallelize_linear_fn(
         module.register_parameter(name, dist_param)
 
 
-def _colwise_parallelize_linear_fn(
+def _colwise_parallelize_embedding_fn(
     name: str,
     module: nn.Module,
     device_mesh: DeviceMesh,
+) -> None:
+    _colwise_parallelize_linear_fn(name, module, device_mesh, sharding_dim=1)
+
+
+def _colwise_parallelize_linear_fn(
+    name: str,
+    module: nn.Module,
+    device_mesh: DeviceMesh,
+    sharding_dim: int = 0,
 ) -> None:
     """
-    This function parallelizes the input :class:`nn.Linear` module in
-    :class:`ColwiseParallel` style.
+    This function parallelizes the input :class:`nn.Linear` or :class:`nn.Embedding`
+    module in :class:`ColwiseParallel` style.
 
     Args:
         name (str):
             Name of the input module.
         module (:class:`nn.Module`):
-            The :class:`nn.Linear` module to be parallelized.
+            The :class:`nn.Linear` or :class:`nn.Embedding` module to be parallelized.
         device_mesh (:class:`DeviceMesh`):
             Object which describes the mesh topology of devices.
 
@@ -217,12 +230,12 @@ def _colwise_parallelize_linear_fn(
 
     for name, param in module.named_parameters():
         dist_param = torch.nn.Parameter(
-            distribute_tensor(param, device_mesh, [Shard(0)])
+            distribute_tensor(param, device_mesh, [Shard(sharding_dim)])
         )
         module.register_parameter(name, dist_param)
 
 
-def _parallelize_linear(
+def _parallelize_linear_like_module(
     module: nn.Module,
     device_mesh: DeviceMesh,
     parallel_style: ParallelStyle = ColwiseParallel(),
@@ -258,9 +271,9 @@ def _parallelize_linear(
         A :class:`nn.Module` object parallelized.
     """
 
-    if not isinstance(module, nn.Linear):
+    if not isinstance(module, (nn.Linear, nn.Embedding)):
         raise RuntimeError(
-            f"Expect a torch.nn.Linear module but received {type(module)}!"
+            f"Expect a torch.nn.Linear or a torch.nn.Embedding module but received {type(module)}!"
         )
 
     if not isinstance(parallel_style, ParallelStyle):
@@ -268,6 +281,13 @@ def _parallelize_linear(
             "Expect a ParallelStyle object but received" f" {type(parallel_style)}!"
         )
 
+    if isinstance(module, nn.Embedding) and not isinstance(
+        parallel_style, (ColwiseParallel)
+    ):
+        raise NotImplementedError(
+            "Only support ColwiseParallel when parallelizing Embedding now."
+        )
+
     if device_mesh.ndim > 1:
         device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)
 
@@ -283,7 +303,9 @@ def _parallelize_linear(
         distribute_module(
             module,
             device_mesh,
-            _colwise_parallelize_linear_fn,
+            _colwise_parallelize_embedding_fn
+            if isinstance(module, nn.Embedding)
+            else _colwise_parallelize_linear_fn,
             input_fn=parallel_style._prepare_input, # type: ignore[arg-type, misc] # pyre-ignore[6]
             output_fn=parallel_style._prepare_output, # type: ignore[arg-type, misc] # pyre-ignore[6]
         )
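For reference, rowwise sharding of an embedding is still rejected after this change; the guard added in the hunk above raises NotImplementedError, which is exactly what the updated test asserts. A minimal sketch of the error path, under the same assumptions as the earlier sketch (initialized process group, one CUDA device per rank):

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import RowwiseParallel, parallelize_module

device_mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))
try:
    # RowwiseParallel on nn.Embedding hits the new NotImplementedError guard.
    parallelize_module(torch.nn.Embedding(16, 8), device_mesh, RowwiseParallel())
except NotImplementedError as err:
    print(err)  # "Only support ColwiseParallel when parallelizing Embedding now."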
@@ -39,6 +39,10 @@ class ParallelStyle(ABC):
     """
     The parallel style user wants the module or submodule to be parallelized.
     Users can extend this class to build their own parallel style with customized input/output preparations.
+
+    .. warning::
+        ``_prepare_input`` and ``_prepare_output`` are only for internal usage and we will
+        remove them from ctor soon. Please use ``input_layouts`` and ``output_layouts`` instead.
     """
 
     _prepare_input: _PrepareInputType
@@ -71,11 +75,9 @@ class PairwiseParallel(ParallelStyle):
     We assume both input and output need to be replicate DTensors.
 
     .. warning::
-        PairwiseParallel does not support ``nn.MultiheadAttention``,
-        ``nn.Transformer`` well at this moment. One workaround is to apply
-        ``ColwiseParallel`` and ``RowwiseParallel`` to the components of
-        transformer. We recommend to use ``PairwiseParallel`` only
-        for even-number-layer MLP for now.
+        PairwiseParallel can be decomposed into ColwiseParallel and RowwiseParallel.
+        We recommend users to directly use latter instead and we are deprecating this
+        style and will remove it soon.
     """
 
     def __init__(
@@ -113,11 +115,9 @@ class SequenceParallel(PairwiseParallel):
     We assume both input and output need to be sharded DTensors.
 
    .. warning::
-        SequenceParallel does not support ``nn.MultiheadAttention``,
-        ``nn.Transformer`` well at this moment. One workaround is to apply
-        ``ColwiseParallel`` and ``RowwiseParallel`` to the components of
-        transformer. We recommend to use ``SequenceParallel`` only
-        for even-number-layer MLP for now.
+        SequenceParallel can be decomposed into ColwiseParallel and RowwiseParallel.
+        We recommend users to directly use latter instead and we are deprecating this
+        style and will remove it soon.
     """
 
     def __init__(
@@ -148,23 +148,8 @@ def make_input_shard_1d(
     dim: int = 0,
 ) -> DTensor:
     """
-    Shard input tensor on ``dim`` over an 1-D device mesh. This function will be used in ParallelStyle.
-
-    Args:
-        input (Union[:class:`torch.Tensor`, :class:`DTensor`]):
-            Single tensor will be sharded on dimension ``dim``
-            over the 1-D :class:`DeviceMesh`.
-        device_mesh (:class:`DeviceMesh`, optional):
-            The 1-D device mesh where ``input`` will be sharded.
-            If no :class:`DeviceMesh` is passed and ``input`` is a :class:`DTensor`,
-            `input.device_mesh` will be used.
-            If :class:`DeviceMesh` is not 1-D, an exception will be thrown.
-            Default: ``None``
-        dim (int, optional): The sharding dimension of ``input`` tensor.
-            Default: 0
-
-    Returns:
-        A :class:`DTensor` sharded on dimension ``dim`` over ``device_mesh``.
+    .. warning::
+        This method was deprecated and please specify ``input_layouts`` instead.
     """
     _deprecate_warnings("make_input_shard_1d", "Specify input_layouts instead.")
     shard_spec = [Shard(dim)]
@@ -185,21 +170,8 @@ def make_input_shard_1d_last_dim(
     device_mesh: Optional[DeviceMesh] = None,
 ) -> DTensor:
     """
-    Wrapper func of ``make_input_shard_1d`` with ``dim`` = -1.
-
-    Args:
-        input (Union[:class:`torch.Tensor`, :class:`DTensor`]):
-            This single tensor will be sharded on the last dimension
-            over the 1-D :class:`DeviceMesh`.
-        device_mesh (:class:`DeviceMesh`, optional):
-            The 1-D device mesh where ``input`` will be sharded.
-            If no :class:`DeviceMesh` is passed and ``input`` is a :class:`DTensor`,
-            `input.device_mesh` will be used.
-            If :class:`DeviceMesh` is not 1-D, an exception will be thrown.
-            Default: ``None``
-
-    Returns:
-        A :class:`DTensor` sharded on the last dimension over ``device_mesh``.
+    .. warning::
+        This method was deprecated and please specify ``input_layouts`` instead.
     """
     _deprecate_warnings(
         "make_input_shard_1d_last_dim", "Specify input_layouts instead."
@@ -213,22 +185,8 @@ def make_input_reshard_replicate(
     device_mesh: DeviceMesh,
 ) -> DTensor:
     """
-    To construct a Sharded DTensor from a tensor on different ranks
-    and then convert to a replicate DTensor.
-
-    Args:
-        input (:class:`torch.Tensor`):
-            The input tensor on each rank which consists of a global DTensor
-            sharded on dimension ``0`` over the 1-D :class:`DeviceMesh`
-            and then the sharded DTensor is converted to a replicate DTensor.
-        device_mesh (:class:`DeviceMesh`, optional):
-            The 1-D device mesh where ``input`` will be sharded.
-            If :class:`DeviceMesh` is not 1-D, an exception will be thrown.
-            Default: ``None``
-
-    Returns:
-        A :class:`DTensor` sharded on dimension ``0`` over ``device_mesh``
-        and then converted to replicate.
+    .. warning::
+        This method was deprecated and please specify ``input_layouts`` instead.
     """
     _deprecate_warnings(
         "make_input_reshard_replicate", "Specify input_layouts instead."
@@ -244,20 +202,8 @@ def make_input_replicate_1d(
     device_mesh: Optional[DeviceMesh] = None,
 ) -> DTensor:
     """
-    Replicate input tensor over an 1-D device mesh. This function will be used in ParallelStyle.
-
-    Args:
-        input (Union[:class:`torch.Tensor`, :class:`DTensor`]):
-            This input tensor will be replicated over the 1-D :class:`DeviceMesh`.
-        device_mesh (:class:`DeviceMesh`, optional):
-            The 1-D device mesh where ``input`` will be replicated.
-            If no :class:`DeviceMesh` is passed and ``input`` is a :class:`DTensor`,
-            ``input.device_mesh`` will be used.
-            If :class:`DeviceMesh` is not 1-D, an exception will be thrown.
-            Default: ``None``
-
-    Returns:
-        A :class:`DTensor` replicated over ``device_mesh``.
+    .. warning::
+        This method was deprecated and please specify ``input_layouts`` instead.
     """
     _deprecate_warnings("make_input_replicate_1d", "Specify input_layouts instead.")
     replicate = [Replicate()]
@@ -277,20 +223,8 @@ def make_output_shard_1d(
     output: DTensor, device_mesh: Optional[DeviceMesh] = None, dim: int = 0
 ) -> DTensor:
     """
-    Convert Output DTensor to a sharded DTensor. This will be used in ParallelStyle.
-
-    Args:
-        output (:class:`DTensor`):
-            Output of module to be converted.
-        device_mesh (:class:`DeviceMesh`, optional):
-            Object needed to shard the output and it needs to be a 1D ``device_mesh``
-            and we will throw exceptions if a non-1D ``device_mesh`` is passed in.
-            If no ``device_mesh`` is passed in, we will reuse the one from output.
-            Default: ``None``
-        dim (int): Sharding dim for output. Default: 0
-
-    Return:
-        A :class:`DTensor` object sharded on the given dim.
+    .. warning::
+        This method was deprecated and please specify ``output_layouts`` instead.
     """
     _deprecate_warnings("make_output_shard_1d", "Specify output_layouts instead.")
     return output.redistribute(device_mesh, [Shard(dim)])
@@ -301,19 +235,8 @@ def make_output_replicate_1d(
     output: DTensor, device_mesh: Optional[DeviceMesh] = None
 ) -> DTensor:
     """
-    Convert Output DTensor to a replicated DTensor. This will be used in ParallelStyle.
-
-    Args:
-        output (:class:`DTensor`):
-            Output of module to be converted.
-        device_mesh (:class:`DeviceMesh`, optional):
-            Object needed to replicate the output and it needs to be a 1D ``device_mesh``
-            and we will throw exceptions if a non-1D ``device_mesh`` is passed in.
-            If no ``device_mesh`` is passed in, we will reuse the one from output.
-            Default: ``None``
-
-    Return:
-        A :class:`DTensor` object made replicate.
+    .. warning::
+        This method was deprecated and please specify ``output_layouts`` instead.
     """
     _deprecate_warnings("make_output_replicate_1d", "Specify output_layouts instead.")
     return output.redistribute(device_mesh, [Replicate()])
@@ -324,20 +247,8 @@ def make_output_tensor(
     output: DTensor, device_mesh: Optional[DeviceMesh] = None
 ) -> torch.Tensor:
     """
-    Convert Output DTensor to a replicated DTensor first and then convert it to Tensor.
-
-    Args:
-        output (:class:`DTensor`):
-            Output of module to be converted.
-        device_mesh (:class:`DeviceMesh`, optional):
-            Object which is needed to replicate the output and it needs to be
-            a 1D ``device_mesh`` and we will throw exceptions if a non-1D
-            ``device_mesh`` is passed in. If no ``device_mesh`` is passed in,
-            we will reuse the one from output.
-            Default: ``None``
-
-    Return:
-        A :class:`torch.Tensor` object converted from output DTensor.
+    .. warning::
+        This method was deprecated and please specify ``output_layouts`` instead.
     """
     _deprecate_warnings("make_output_tensor", "Specify output_layouts instead.")
     return make_output_replicate_1d(  # type: ignore[attr-defined, misc]
@@ -350,17 +261,8 @@ def make_sharded_output_tensor(
     output: DTensor, _device_mesh: Optional[DeviceMesh] = None
 ) -> torch.Tensor:
     """
-    Convert sharded Output DTensor to torch.Tensor.
-
-    Args:
-        output (:class:`DTensor`):
-            Output of module to be converted.
-
-    Return:
-        A :class:`torch.Tensor` object converted from output DTensor.
-
-    ``_device_mesh`` is not needed and is just kept to match with
-    the signature in its callsite in ``distribute_module``.
+    .. warning::
+        This method was deprecated and please specify ``output_layouts`` instead.
     """
     _deprecate_warnings("make_sharded_output_tensor", "Specify output_layouts instead.")
     return output.to_local()  # type: ignore[call-arg]
@@ -372,19 +274,8 @@ def make_output_reshard_tensor(
     device_mesh: Optional[DeviceMesh] = None,
 ) -> torch.Tensor:
     """
-    Convert Output DTensor to a sharded DTensor and return the local tensor.
-
-    Args:
-        output (:class:`DTensor`):
-            Output of module to be converted.
-        device_mesh (:class:`DeviceMesh`, optional):
-            Object needed to shard the output and it needs to be a 1D ``device_mesh``
-            and we will throw exceptions if a non-1D ``device_mesh`` is passed in.
-            If no ``device_mesh`` is passed in, we will reuse the one from output.
-            Default: ``None``
-
-    Return:
-        A :class:`torch.Tensor` object converted from output DTensor.
+    .. warning::
+        This method was deprecated and please specify ``output_layouts`` instead.
     """
     _deprecate_warnings("make_output_reshard_tensor", "Specify output_layouts instead.")
     return make_output_shard_1d(output, device_mesh).to_local()  # type: ignore[call-arg, attr-defined, misc]
@@ -552,6 +443,41 @@ class RowwiseParallel(ParallelStyle):
     """
     Partitioning the row of a module.
     We assume the input to be a sharded :class:`DTensor` and output to be a :class:`torch.Tensor`.
+
+    Args:
+        input_layouts (Union[Placement, Tuple[Placement, ...]]):
+            The layout of input tensor(s) which DTensor will be created upon.
+        output_layouts (Union[Placement, Tuple[Placement, ...]]):
+            The layout of input tensor(s) which created DTensor will be redistributed to.
+        use_local_output (bool):
+            Whether to convert the DTensor to local :class:`torch.Tensor`.
+
+    Returns:
+        None.
+
+    .. warning::
+        RowwiseParallel now only support ``nn.Linear``. Users can compose it with ColwiseParallel
+        to achieve the sharding of more complicated modules.
+
+    .. warning::
+        ``_prepare_input`` and ``_prepare_output`` are only for internal usage and we will
+        remove them from ctor soon. Please use ``input_layouts`` and ``output_layouts`` instead.
+
+    Example::
+        >>> # xdoctest: +SKIP(failing)
+        >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
+        >>> ...
+        >>> parallelize_plan = {
+        >>>     "wo": RowwiseParallel(),   # The input of Linear will be converted to Sharded DTensor
+        >>>                                # and we will return a replicate :class:`torch.Tensor` as output.
+        >>>     ...
+        >>> }
+        >>> parallelize_module(
+        >>>     module=block, # this can be a submodule or module
+        >>>     ...,
+        >>>     parallelize_plan=parallelize_plan,
+        >>> )
+        >>> ...
     """
 
     def __init__(
@@ -588,6 +514,41 @@ class ColwiseParallel(ParallelStyle):
     """
     Partitioning the column of a tensor or module.
     We assume the input to be a replicated :class:`DTensor` and output to be a sharded :class:`torch.Tensor`.
+
+    Args:
+        input_layouts (Union[Placement, Tuple[Placement, ...]]):
+            The layout of input tensor(s) which DTensor will be created upon.
+        output_layouts (Union[Placement, Tuple[Placement, ...]]):
+            The layout of input tensor(s) which created DTensor will be redistributed to.
+        use_local_output (bool):
+            Whether to convert the DTensor to local :class:`torch.Tensor`.
+
+    Returns:
+        None.
+
+    .. warning::
+        ColwiseParallel now only support ``nn.Linear`` and ``nn.Embedding``. Users can compose
+        it with RowwiseParallel to achieve the sharding of more complicated modules.
+
+    .. warning::
+        ``_prepare_input`` and ``_prepare_output`` are only for internal usage and we will
+        remove them from ctor soon. Please use ``input_layouts`` and ``output_layouts`` instead.
+
+    Example::
+        >>> # xdoctest: +SKIP(failing)
+        >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
+        >>> ...
+        >>> parallelize_plan = {
+        >>>     "w1": ColwiseParallel(),   # The input of Linear will be converted to Replicated DTensor
+        >>>                                # and we will return a sharded :class:`torch.Tensor` as output.
+        >>>     ...
+        >>> }
+        >>> parallelize_module(
+        >>>     module=block, # this can be a submodule or module
+        >>>     ...,
+        >>>     parallelize_plan=parallelize_plan,
+        >>> )
+        >>> ...
     """
 
     def __init__(
@@ -599,6 +560,9 @@ class ColwiseParallel(ParallelStyle):
         output_layouts=Shard(-1),
         use_local_output=True,
     ) -> None:
+        """
+
+        """
         if isinstance(input_layouts, tuple) or isinstance(output_layouts, tuple):
             raise NotImplementedError(
                 "ColwiseParallel only supports single input/output."
@@ -647,7 +611,7 @@ class PrepareModuleInput(ParallelStyle):
         output_layouts (Union[Placement, Tuple[Placement, ...]]):
             The layout of input tensor(s) which created DTensor will be redistributed to.
         use_local_output (bool):
-            Whether to convert the DTensor to local tensor.
+            Whether to convert the DTensor to local :class:`torch.Tensor`.
 
     Returns:
         None.
@@ -686,7 +650,7 @@ class PrepareModuleOutput(ParallelStyle):
     with ``output_layouts`` and ``use_local_output`` so that each output can be converted to
     :class:`DTensor` or :class:`torch.Tensor` based on the annotation. Specifically, a DTensor
     will be redistributed to another DTensor based on ``output_layouts`` and the flag ``use_local_output``
-    to decide whether to convert the DTensor to local tensor.
+    to decide whether to convert the DTensor to local :class:`torch.Tensor`.
 
     When the output is not a :class:`DTensor`, if no layout is specified, it will be
     a no-op. Otherwise, it will throw an error.
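Since the updated docstrings above steer users away from PairwiseParallel and SequenceParallel, here is a hedged sketch of the recommended composition: ColwiseParallel and RowwiseParallel applied per submodule through a plan dict. The module names "net1"/"net2" and sizes are illustrative, not from this PR, and the same process-group assumptions as the earlier sketches apply.

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class MLP(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.net1 = torch.nn.Linear(8, 32)
        self.net2 = torch.nn.Linear(32, 8)

    def forward(self, x):
        return self.net2(torch.relu(self.net1(x)))


device_mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))

# Colwise for the first linear, rowwise for the second: the pairwise pattern
# expressed explicitly per submodule, as the updated warnings recommend.
model = parallelize_module(
    MLP().cuda(),
    device_mesh,
    {"net1": ColwiseParallel(), "net2": RowwiseParallel()},
)

out = model(torch.randn(4, 8, device="cuda"))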