diff --git a/test/distributed/_tensor/test_embedding_ops.py b/test/distributed/_tensor/test_embedding_ops.py
index 91d5b0530ac..1bdf61877a1 100644
--- a/test/distributed/_tensor/test_embedding_ops.py
+++ b/test/distributed/_tensor/test_embedding_ops.py
@@ -3,8 +3,13 @@
 import sys
 
 import torch
-from torch.distributed._tensor import distribute_tensor, DTensor
-from torch.distributed._tensor.placement_types import Replicate, Shard
+from torch.distributed._tensor import DTensor
+from torch.distributed._tensor.placement_types import Replicate
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    RowwiseParallel,
+)
 from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
@@ -46,7 +51,14 @@ class TestEmbeddingOp(DTensorTestBase):
 
         # Shard the parameter of local embedding and set it to sharded embedding.
         sharded_embedding.weight = torch.nn.Parameter(
-            distribute_tensor(local_embedding.weight, device_mesh, [Shard(shard_dim)])
+            local_embedding.weight.clone().detach()
+        )
+        parallelize_module(
+            module=sharded_embedding,
+            device_mesh=device_mesh,
+            parallelize_plan=ColwiseParallel(output_layouts=Replicate())
+            if shard_dim == 1
+            else RowwiseParallel(),
         )
 
         # Run sharded computation
@@ -57,12 +69,7 @@ class TestEmbeddingOp(DTensorTestBase):
         target = torch.empty(
             *inp.size(), embedding_dim, dtype=torch.float, device=self.device_type
         ).random_(0, 1)
-        placements = [Replicate()]
-        replicate_inp = DTensor.from_local(inp, device_mesh, placements)
-        sharded_output = sharded_embedding(replicate_inp)
-        output = sharded_output.redistribute(
-            sharded_output.device_mesh, [Replicate()]
-        ).to_local()
+        output = sharded_embedding(inp)
 
         # Run local computation
         local_output = local_embedding(inp)
@@ -84,7 +91,7 @@ class TestEmbeddingOp(DTensorTestBase):
         attn_dup_loss.backward()
 
         gradient = sharded_embedding.weight.grad.redistribute(
-            sharded_output.device_mesh, [Replicate()]
+            device_mesh, [Replicate()]
         ).to_local()
 
         local_grad = local_embedding.weight.grad
@@ -99,15 +106,13 @@ class TestEmbeddingOp(DTensorTestBase):
             **kwargs,
         )
         sharded_output = torch.nn.functional.embedding(
-            replicate_inp,
+            DTensor.from_local(inp, device_mesh, [Replicate()]),
             sharded_embedding.weight,
             **kwargs,
         )
         self.assertEqual(
             local_output,
-            sharded_output.redistribute(
-                sharded_output.device_mesh, [Replicate()]
-            ).to_local(),
+            sharded_output.redistribute(device_mesh, [Replicate()]).to_local(),
         )
 
     @with_comms
@@ -134,7 +139,7 @@ class TestEmbeddingOp(DTensorTestBase):
     def test_sharded_embedding_rowwise(self):
         with self.assertRaisesRegex(
             NotImplementedError,
-            "DTensor does not support row-wise sharded embedding operation yet!",
+            "Only ColwiseParallel is supported when parallelizing an Embedding for now.",
         ):
             self._run_embedding_op_test(0, [5, 12], 16, 22)
 
diff --git a/test/distributed/tensor/parallel/test_parallelize_api.py b/test/distributed/tensor/parallel/test_parallelize_api.py
index 800c4f85291..91fb2b50662 100644
--- a/test/distributed/tensor/parallel/test_parallelize_api.py
+++ b/test/distributed/tensor/parallel/test_parallelize_api.py
@@ -5,7 +5,7 @@ import torch
 from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
 from torch.distributed.tensor.parallel._utils import _create_1d_device_mesh
 from torch.distributed.tensor.parallel.api import (
-    _parallelize_linear,
+    _parallelize_linear_like_module,
     _parallelize_mlp,
     parallelize_module,
 )
@@ -264,7 +264,7 @@ class TensorParallelAPITests(DTensorTestBase):
 
         # parallelize model_tp
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-        model_tp = _parallelize_linear(model_tp, device_mesh, rowwise)
+        model_tp = _parallelize_linear_like_module(model_tp, device_mesh, rowwise)
 
         # let each rank generate unique local input
         torch.manual_seed(self.rank)
@@ -283,7 +283,7 @@ class TensorParallelAPITests(DTensorTestBase):
 
         # parallelize model_tp
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-        model_tp = _parallelize_linear(model_tp, device_mesh, colwise)
+        model_tp = _parallelize_linear_like_module(model_tp, device_mesh, colwise)
 
         self._compare_module(model, model_tp, inp_size)
 
diff --git a/torch/distributed/tensor/parallel/api.py b/torch/distributed/tensor/parallel/api.py
index 35ab0f8b4d1..1f4667a4f90 100644
--- a/torch/distributed/tensor/parallel/api.py
+++ b/torch/distributed/tensor/parallel/api.py
@@ -75,8 +75,10 @@ def parallelize_module(  # type: ignore[return]
     >>>
 
     .. warning::
-        ``PairwiseParallel`` comes with constraints for now. If you need finer
-        granularity, you need to pass in a dict of module FQN and parallel style instead.
+        Currently, there are some constraints that make it hard for complicated modules
+        such as ``MultiheadAttention`` to work out of the box with Tensor or Sequence Parallelism.
+        We recommend applying ``ColwiseParallel`` and ``RowwiseParallel`` to each parameter
+        or submodule instead, which may require some code changes for now.
     """
 
     torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")
@@ -104,7 +106,9 @@ def parallelize_module(  # type: ignore[return]
     if isinstance(parallelize_plan, ParallelStyle):
         # RowwiseParallel or ColwiseParallel
         if isinstance(parallelize_plan, (ColwiseParallel, RowwiseParallel)):
-            return _parallelize_linear(module, device_mesh, parallelize_plan)
+            return _parallelize_linear_like_module(
+                module, device_mesh, parallelize_plan
+            )
         elif isinstance(parallelize_plan, PrepareModuleInput):
             module.register_forward_pre_hook(lambda _, inputs: parallelize_plan._prepare_input(inputs, device_mesh))  # type: ignore[misc, call-arg]
             return module
@@ -194,20 +198,29 @@ def _rowwise_parallelize_linear_fn(
         module.register_parameter(name, dist_param)
 
 
-def _colwise_parallelize_linear_fn(
+def _colwise_parallelize_embedding_fn(
     name: str,
     module: nn.Module,
     device_mesh: DeviceMesh,
+) -> None:
+    _colwise_parallelize_linear_fn(name, module, device_mesh, sharding_dim=1)
+
+
+def _colwise_parallelize_linear_fn(
+    name: str,
+    module: nn.Module,
+    device_mesh: DeviceMesh,
+    sharding_dim: int = 0,
 ) -> None:
     """
-    This function parallelizes the input :class:`nn.Linear` module in
-    :class:`ColwiseParallel` style.
+    This function parallelizes the input :class:`nn.Linear` or :class:`nn.Embedding`
+    module in :class:`ColwiseParallel` style.
 
     Args:
         name (str): Name of the input module.
         module (:class:`nn.Module`):
-            The :class:`nn.Linear` module to be parallelized.
+            The :class:`nn.Linear` or :class:`nn.Embedding` module to be parallelized.
         device_mesh (:class:`DeviceMesh`):
             Object which describes the mesh topology of devices.
@@ -217,12 +230,12 @@ def _colwise_parallelize_linear_fn(
 
     for name, param in module.named_parameters():
         dist_param = torch.nn.Parameter(
-            distribute_tensor(param, device_mesh, [Shard(0)])
+            distribute_tensor(param, device_mesh, [Shard(sharding_dim)])
         )
         module.register_parameter(name, dist_param)
 
 
-def _parallelize_linear(
+def _parallelize_linear_like_module(
     module: nn.Module,
     device_mesh: DeviceMesh,
     parallel_style: ParallelStyle = ColwiseParallel(),
@@ -258,9 +271,9 @@ def _parallelize_linear(
         A :class:`nn.Module` object parallelized.
     """
 
-    if not isinstance(module, nn.Linear):
+    if not isinstance(module, (nn.Linear, nn.Embedding)):
         raise RuntimeError(
-            f"Expect a torch.nn.Linear module but received {type(module)}!"
+            f"Expect a torch.nn.Linear or a torch.nn.Embedding module but received {type(module)}!"
         )
 
     if not isinstance(parallel_style, ParallelStyle):
@@ -268,6 +281,13 @@ def _parallelize_linear(
             "Expect a ParallelStyle object but received"
             f" {type(parallel_style)}!"
         )
+    if isinstance(module, nn.Embedding) and not isinstance(
+        parallel_style, ColwiseParallel
+    ):
+        raise NotImplementedError(
+            "Only ColwiseParallel is supported when parallelizing an Embedding for now."
+        )
+
     if device_mesh.ndim > 1:
         device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)
@@ -283,7 +303,9 @@ def _parallelize_linear(
         distribute_module(
             module,
             device_mesh,
-            _colwise_parallelize_linear_fn,
+            _colwise_parallelize_embedding_fn
+            if isinstance(module, nn.Embedding)
+            else _colwise_parallelize_linear_fn,
             input_fn=parallel_style._prepare_input,  # type: ignore[arg-type, misc] # pyre-ignore[6]
             output_fn=parallel_style._prepare_output,  # type: ignore[arg-type, misc] # pyre-ignore[6]
         )
diff --git a/torch/distributed/tensor/parallel/style.py b/torch/distributed/tensor/parallel/style.py
index 1df02931d3b..8a379a8d686 100644
--- a/torch/distributed/tensor/parallel/style.py
+++ b/torch/distributed/tensor/parallel/style.py
@@ -39,6 +39,10 @@ class ParallelStyle(ABC):
     """
    The parallel style user wants the module or submodule to be parallelized.
     Users can extend this class to build their own parallel style with customized input/output preparations.
+
+    .. warning::
+        ``_prepare_input`` and ``_prepare_output`` are for internal use only and will be
+        removed from the constructor soon. Please use ``input_layouts`` and ``output_layouts`` instead.
     """
 
     _prepare_input: _PrepareInputType
@@ -71,11 +75,9 @@ class PairwiseParallel(ParallelStyle):
     We assume both input and output need to be replicate DTensors.
 
     .. warning::
-        PairwiseParallel does not support ``nn.MultiheadAttention``,
-        ``nn.Transformer`` well at this moment. One workaround is to apply
-        ``ColwiseParallel`` and ``RowwiseParallel`` to the components of
-        transformer. We recommend to use ``PairwiseParallel`` only
-        for even-number-layer MLP for now.
+        ``PairwiseParallel`` can be decomposed into ``ColwiseParallel`` and ``RowwiseParallel``.
+        We recommend using the latter directly instead; this style is deprecated
+        and will be removed soon.
     """
 
     def __init__(
@@ -113,11 +115,9 @@ class SequenceParallel(PairwiseParallel):
     We assume both input and output need to be sharded DTensors.
 
     .. warning::
-        SequenceParallel does not support ``nn.MultiheadAttention``,
-        ``nn.Transformer`` well at this moment. One workaround is to apply
-        ``ColwiseParallel`` and ``RowwiseParallel`` to the components of
-        transformer. We recommend to use ``SequenceParallel`` only
-        for even-number-layer MLP for now.
+        ``SequenceParallel`` can be decomposed into ``ColwiseParallel`` and ``RowwiseParallel``.
+        We recommend using the latter directly instead; this style is deprecated
+        and will be removed soon.
     """
 
     def __init__(
@@ -148,23 +148,8 @@ def make_input_shard_1d(
     dim: int = 0,
 ) -> DTensor:
     """
-    Shard input tensor on ``dim`` over an 1-D device mesh. This function will be used in ParallelStyle.
-
-    Args:
-        input (Union[:class:`torch.Tensor`, :class:`DTensor`]):
-            Single tensor will be sharded on dimension ``dim``
-            over the 1-D :class:`DeviceMesh`.
-        device_mesh (:class:`DeviceMesh`, optional):
-            The 1-D device mesh where ``input`` will be sharded.
-            If no :class:`DeviceMesh` is passed and ``input`` is a :class:`DTensor`,
-            `input.device_mesh` will be used.
-            If :class:`DeviceMesh` is not 1-D, an exception will be thrown.
-            Default: ``None``
-        dim (int, optional): The sharding dimension of ``input`` tensor.
-            Default: 0
-
-    Returns:
-        A :class:`DTensor` sharded on dimension ``dim`` over ``device_mesh``.
+    .. warning::
+        This method is deprecated; please specify ``input_layouts`` instead.
     """
     _deprecate_warnings("make_input_shard_1d", "Specify input_layouts instead.")
     shard_spec = [Shard(dim)]
@@ -185,21 +170,8 @@ def make_input_shard_1d_last_dim(
     device_mesh: Optional[DeviceMesh] = None,
 ) -> DTensor:
     """
-    Wrapper func of ``make_input_shard_1d`` with ``dim`` = -1.
-
-    Args:
-        input (Union[:class:`torch.Tensor`, :class:`DTensor`]):
-            This single tensor will be sharded on the last dimension
-            over the 1-D :class:`DeviceMesh`.
-        device_mesh (:class:`DeviceMesh`, optional):
-            The 1-D device mesh where ``input`` will be sharded.
-            If no :class:`DeviceMesh` is passed and ``input`` is a :class:`DTensor`,
-            `input.device_mesh` will be used.
-            If :class:`DeviceMesh` is not 1-D, an exception will be thrown.
-            Default: ``None``
-
-    Returns:
-        A :class:`DTensor` sharded on the last dimension over ``device_mesh``.
+    .. warning::
+        This method is deprecated; please specify ``input_layouts`` instead.
     """
     _deprecate_warnings(
         "make_input_shard_1d_last_dim", "Specify input_layouts instead."
     )
@@ -213,22 +185,8 @@ def make_input_reshard_replicate(
     device_mesh: DeviceMesh,
 ) -> DTensor:
     """
-    To construct a Sharded DTensor from a tensor on different ranks
-    and then convert to a replicate DTensor.
-
-    Args:
-        input (:class:`torch.Tensor`):
-            The input tensor on each rank which consists of a global DTensor
-            sharded on dimension ``0`` over the 1-D :class:`DeviceMesh`
-            and then the sharded DTensor is converted to a replicate DTensor.
-        device_mesh (:class:`DeviceMesh`, optional):
-            The 1-D device mesh where ``input`` will be sharded.
-            If :class:`DeviceMesh` is not 1-D, an exception will be thrown.
-            Default: ``None``
-
-    Returns:
-        A :class:`DTensor` sharded on dimension ``0`` over ``device_mesh``
-        and then converted to replicate.
+    .. warning::
+        This method is deprecated; please specify ``input_layouts`` instead.
     """
     _deprecate_warnings(
         "make_input_reshard_replicate", "Specify input_layouts instead."
     )
@@ -244,20 +202,8 @@ def make_input_replicate_1d(
     device_mesh: Optional[DeviceMesh] = None,
 ) -> DTensor:
     """
-    Replicate input tensor over an 1-D device mesh. This function will be used in ParallelStyle.
-
-    Args:
-        input (Union[:class:`torch.Tensor`, :class:`DTensor`]):
-            This input tensor will be replicated over the 1-D :class:`DeviceMesh`.
-        device_mesh (:class:`DeviceMesh`, optional):
-            The 1-D device mesh where ``input`` will be replicated.
-            If no :class:`DeviceMesh` is passed and ``input`` is a :class:`DTensor`,
-            ``input.device_mesh`` will be used.
-            If :class:`DeviceMesh` is not 1-D, an exception will be thrown.
-            Default: ``None``
-
-    Returns:
-        A :class:`DTensor` replicated over ``device_mesh``.
+    .. warning::
+        This method is deprecated; please specify ``input_layouts`` instead.
     """
     _deprecate_warnings("make_input_replicate_1d", "Specify input_layouts instead.")
     replicate = [Replicate()]
@@ -277,20 +223,8 @@ def make_output_shard_1d(
     output: DTensor, device_mesh: Optional[DeviceMesh] = None, dim: int = 0
 ) -> DTensor:
     """
-    Convert Output DTensor to a sharded DTensor. This will be used in ParallelStyle.
-
-    Args:
-        output (:class:`DTensor`):
-            Output of module to be converted.
-        device_mesh (:class:`DeviceMesh`, optional):
-            Object needed to shard the output and it needs to be a 1D ``device_mesh``
-            and we will throw exceptions if a non-1D ``device_mesh`` is passed in.
-            If no ``device_mesh`` is passed in, we will reuse the one from output.
-            Default: ``None``
-        dim (int): Sharding dim for output. Default: 0
-
-    Return:
-        A :class:`DTensor` object sharded on the given dim.
+    .. warning::
+        This method is deprecated; please specify ``output_layouts`` instead.
     """
     _deprecate_warnings("make_output_shard_1d", "Specify output_layouts instead.")
     return output.redistribute(device_mesh, [Shard(dim)])
@@ -301,19 +235,8 @@ def make_output_replicate_1d(
     output: DTensor, device_mesh: Optional[DeviceMesh] = None
 ) -> DTensor:
     """
-    Convert Output DTensor to a replicated DTensor. This will be used in ParallelStyle.
-
-    Args:
-        output (:class:`DTensor`):
-            Output of module to be converted.
-        device_mesh (:class:`DeviceMesh`, optional):
-            Object needed to replicate the output and it needs to be a 1D ``device_mesh``
-            and we will throw exceptions if a non-1D ``device_mesh`` is passed in.
-            If no ``device_mesh`` is passed in, we will reuse the one from output.
-            Default: ``None``
-
-    Return:
-        A :class:`DTensor` object made replicate.
+    .. warning::
+        This method is deprecated; please specify ``output_layouts`` instead.
     """
     _deprecate_warnings("make_output_replicate_1d", "Specify output_layouts instead.")
     return output.redistribute(device_mesh, [Replicate()])
@@ -324,20 +247,8 @@ def make_output_tensor(
     output: DTensor, device_mesh: Optional[DeviceMesh] = None
 ) -> torch.Tensor:
     """
-    Convert Output DTensor to a replicated DTensor first and then convert it to Tensor.
-
-    Args:
-        output (:class:`DTensor`):
-            Output of module to be converted.
-        device_mesh (:class:`DeviceMesh`, optional):
-            Object which is needed to replicate the output and it needs to be
-            a 1D ``device_mesh`` and we will throw exceptions if a non-1D
-            ``device_mesh`` is passed in. If no ``device_mesh`` is passed in,
-            we will reuse the one from output.
-            Default: ``None``
-
-    Return:
-        A :class:`torch.Tensor` object converted from output DTensor.
+    .. warning::
+        This method is deprecated; please specify ``output_layouts`` instead.
     """
     _deprecate_warnings("make_output_tensor", "Specify output_layouts instead.")
     return make_output_replicate_1d(  # type: ignore[attr-defined, misc]
@@ -350,17 +261,8 @@ def make_sharded_output_tensor(
     output: DTensor, _device_mesh: Optional[DeviceMesh] = None
 ) -> torch.Tensor:
     """
-    Convert sharded Output DTensor to torch.Tensor.
-
-    Args:
-        output (:class:`DTensor`):
-            Output of module to be converted.
-
-    Return:
-        A :class:`torch.Tensor` object converted from output DTensor.
-
-    ``_device_mesh`` is not needed and is just kept to match with
-    the signature in its callsite in ``distribute_module``.
+    .. warning::
+        This method is deprecated; please specify ``output_layouts`` instead.
     """
     _deprecate_warnings("make_sharded_output_tensor", "Specify output_layouts instead.")
     return output.to_local()  # type: ignore[call-arg]
@@ -372,19 +274,8 @@ def make_output_reshard_tensor(
     device_mesh: Optional[DeviceMesh] = None,
 ) -> torch.Tensor:
     """
-    Convert Output DTensor to a sharded DTensor and return the local tensor.
-
-    Args:
-        output (:class:`DTensor`):
-            Output of module to be converted.
-        device_mesh (:class:`DeviceMesh`, optional):
-            Object needed to shard the output and it needs to be a 1D ``device_mesh``
-            and we will throw exceptions if a non-1D ``device_mesh`` is passed in.
-            If no ``device_mesh`` is passed in, we will reuse the one from output.
-            Default: ``None``
-
-    Return:
-        A :class:`torch.Tensor` object converted from output DTensor.
+    .. warning::
+        This method is deprecated; please specify ``output_layouts`` instead.
     """
     _deprecate_warnings("make_output_reshard_tensor", "Specify output_layouts instead.")
     return make_output_shard_1d(output, device_mesh).to_local()  # type: ignore[call-arg, attr-defined, misc]
@@ -552,6 +443,41 @@ class RowwiseParallel(ParallelStyle):
     """
     Partitioning the row of a module. We assume the input to be a sharded :class:`DTensor`
     and output to be a :class:`torch.Tensor`.
+
+    Args:
+        input_layouts (Union[Placement, Tuple[Placement, ...]]):
+            The layout of the input tensor(s) from which the input DTensor(s) will be created.
+        output_layouts (Union[Placement, Tuple[Placement, ...]]):
+            The layout that the output DTensor(s) will be redistributed to.
+        use_local_output (bool):
+            Whether to convert the DTensor to a local :class:`torch.Tensor`.
+
+    Returns:
+        None.
+
+    .. warning::
+        RowwiseParallel currently only supports ``nn.Linear``. Users can compose it with
+        ``ColwiseParallel`` to shard more complicated modules.
+
+    .. warning::
+        ``_prepare_input`` and ``_prepare_output`` are for internal use only and will be
+        removed from the constructor soon. Please use ``input_layouts`` and ``output_layouts`` instead.
+
+    Example::
+        >>> # xdoctest: +SKIP(failing)
+        >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
+        >>> ...
+        >>> parallelize_plan = {
+        >>>     "wo": RowwiseParallel(),  # The input of the Linear layer will be converted to a sharded DTensor
+        >>>                               # and a replicated :class:`torch.Tensor` will be returned as output.
+        >>>     ...
+        >>> }
+        >>> parallelize_module(
+        >>>     module=block, # this can be a submodule or module
+        >>>     ...,
+        >>>     parallelize_plan=parallelize_plan,
+        >>> )
+        >>> ...
     """
 
     def __init__(
@@ -588,6 +514,41 @@ class ColwiseParallel(ParallelStyle):
     """
     Partitioning the column of a tensor or module.
     We assume the input to be a replicated :class:`DTensor` and output to be a sharded :class:`torch.Tensor`.
+
+    Args:
+        input_layouts (Union[Placement, Tuple[Placement, ...]]):
+            The layout of the input tensor(s) from which the input DTensor(s) will be created.
+        output_layouts (Union[Placement, Tuple[Placement, ...]]):
+            The layout that the output DTensor(s) will be redistributed to.
+        use_local_output (bool):
+            Whether to convert the DTensor to a local :class:`torch.Tensor`.
+
+    Returns:
+        None.
+
+    .. warning::
+        ColwiseParallel currently only supports ``nn.Linear`` and ``nn.Embedding``. Users can
+        compose it with ``RowwiseParallel`` to shard more complicated modules.
+
+    .. warning::
+        ``_prepare_input`` and ``_prepare_output`` are for internal use only and will be
+        removed from the constructor soon. Please use ``input_layouts`` and ``output_layouts`` instead.
+
+    Example::
+        >>> # xdoctest: +SKIP(failing)
+        >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
+        >>> ...
+        >>> parallelize_plan = {
+        >>>     "w1": ColwiseParallel(),  # The input of the Linear layer will be converted to a replicated DTensor
+        >>>                               # and a sharded :class:`torch.Tensor` will be returned as output.
+        >>>     ...
+        >>> }
+        >>> parallelize_module(
+        >>>     module=block, # this can be a submodule or module
+        >>>     ...,
+        >>>     parallelize_plan=parallelize_plan,
+        >>> )
+        >>> ...
     """
 
     def __init__(
@@ -647,7 +611,7 @@ class PrepareModuleInput(ParallelStyle):
         output_layouts (Union[Placement, Tuple[Placement, ...]]):
             The layout of input tensor(s) which created DTensor will be redistributed to.
         use_local_output (bool):
-            Whether to convert the DTensor to local tensor.
+            Whether to convert the DTensor to local :class:`torch.Tensor`.
 
     Returns:
         None.
@@ -686,7 +650,7 @@ class PrepareModuleOutput(ParallelStyle):
     with ``output_layouts`` and ``use_local_output`` so that each output can be converted to
     :class:`DTensor` or :class:`torch.Tensor` based on the annotation. Specifically, a DTensor
     will be redistributed to another DTensor based on ``output_layouts`` and the flag ``use_local_output``
-    to decide whether to convert the DTensor to local tensor.
+    to decide whether to convert the DTensor to local :class:`torch.Tensor`.
     When the output is not a :class:`DTensor`, if no layout is specified, it will be a no-op.
     Otherwise, it will throw an error.
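
A rough usage sketch of the embedding path exercised by the updated test above: ``parallelize_module`` with ``ColwiseParallel(output_layouts=Replicate())`` shards ``nn.Embedding.weight`` on the embedding dimension and returns a plain replicated tensor. The sizes, backend, and launch setup here are illustrative only and assume the script is started with torchrun so the process-group environment variables are set:

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.placement_types import Replicate
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

dist.init_process_group("gloo")
mesh = DeviceMesh("cpu", list(range(dist.get_world_size())))

# Made-up sizes: num_embeddings=16, embedding_dim=8; seed so every rank builds the
# same weight before it is sharded on dim 1 (the embedding dimension).
torch.manual_seed(0)
emb = torch.nn.Embedding(16, 8)
parallelize_module(
    module=emb,
    device_mesh=mesh,
    parallelize_plan=ColwiseParallel(output_layouts=Replicate()),
)

inp = torch.randint(0, 16, (4, 6))  # plain index tensor, same on every rank
out = emb(inp)                      # plain torch.Tensor, replicated on every rank
assert out.shape == (4, 6, 8)

dist.destroy_process_group()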
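
The deprecation notes for ``PairwiseParallel`` and ``SequenceParallel`` point users at composing ``ColwiseParallel`` and ``RowwiseParallel`` directly. Below is a minimal sketch of that composition for a two-layer MLP; the module and the FQNs ("net1", "net2") are invented for illustration and the same torchrun launch assumption applies:

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)


class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = torch.nn.Linear(8, 16)
        self.net2 = torch.nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(torch.relu(self.net1(x)))


dist.init_process_group("gloo")
mesh = DeviceMesh("cpu", list(range(dist.get_world_size())))

torch.manual_seed(0)  # identical weights and input on every rank
model = MLP()

# Colwise shards net1's output features, Rowwise shards net2's input features, so the
# pair behaves like the deprecated PairwiseParallel: replicated input in, replicated output out.
parallelize_module(
    model,
    mesh,
    {"net1": ColwiseParallel(), "net2": RowwiseParallel()},
)

out = model(torch.rand(4, 8))
assert out.shape == (4, 8)

dist.destroy_process_group()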