diff --git a/test/distributed/_tensor/test_dtensor_compile.py b/test/distributed/_tensor/test_dtensor_compile.py
index f0a7c877d1e..4e875241078 100644
--- a/test/distributed/_tensor/test_dtensor_compile.py
+++ b/test/distributed/_tensor/test_dtensor_compile.py
@@ -10,7 +10,12 @@ import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    PairwiseParallel,
+    parallelize_module,
+    RowwiseParallel,
+)
 from torch.distributed.tensor.parallel.fsdp import enable_2d_with_fsdp
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_utils import run_tests
@@ -180,8 +185,17 @@ class TestDTensorCompileE2E(DTensorTestBase):
     def test_tp_compile_fullgraph(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

-        model = MLPModule(self.device_type)
-        model = parallelize_module(model, mesh, PairwiseParallel())
+        model = SimpleModel(self.device_type)
+        model = parallelize_module(
+            model,
+            mesh,
+            parallelize_plan={
+                "mlp_0.net1": ColwiseParallel(),
+                "mlp_0.net2": RowwiseParallel(),
+                "mlp_1.net1": ColwiseParallel(),
+                "mlp_1.net2": RowwiseParallel(),
+            },
+        )
         inp = torch.rand(20, 10, device=self.device_type)
         out = model(inp)
         compiled_mod = torch.compile(model, backend="aot_eager", fullgraph=True)
diff --git a/test/distributed/tensor/parallel/test_tp_style.py b/test/distributed/tensor/parallel/test_tp_style.py
index 76f78e547aa..b9976d26c9c 100644
--- a/test/distributed/tensor/parallel/test_tp_style.py
+++ b/test/distributed/tensor/parallel/test_tp_style.py
@@ -33,7 +33,6 @@ class TensorParallelStyleTest(DTensorTestBase):
         input_local_tensor,
         expected_local_tensor,
         func,
-        tensor_input_only=False,
         error_msgs="device_mesh is not passed nor can be inferred",
     ) -> None:
         with self.assertRaisesRegex(RuntimeError, error_msgs):
@@ -43,13 +42,12 @@ class TensorParallelStyleTest(DTensorTestBase):
         # test 1: replicate local tensor
         dtensor = func(input_local_tensor, device_mesh)
         self.assertEqual(expected_local_tensor, dtensor.to_local())
-        if not tensor_input_only:
-            # test 2: replicate DTensor
-            dtensor = func(dtensor)
-            self.assertEqual(expected_local_tensor, dtensor.to_local())
-            # test 3: replicate DTensor with DeviceMesh passed
-            dtensor = func(dtensor, device_mesh)
-            self.assertEqual(expected_local_tensor, dtensor.to_local())
+        # test 2: replicate DTensor
+        dtensor = func(dtensor)
+        self.assertEqual(expected_local_tensor, dtensor.to_local())
+        # test 3: replicate DTensor with DeviceMesh passed
+        dtensor = func(dtensor, device_mesh)
+        self.assertEqual(expected_local_tensor, dtensor.to_local())

     @with_comms
     def test_make_input_replicate_1d(self):
@@ -189,9 +187,9 @@ class TensorParallelStyleTest(DTensorTestBase):
         dtensor = distribute_tensor(tensor, device_mesh, [Shard(0)])
         output = [dtensor]
         with self.assertRaisesRegex(
-            RuntimeError,
-            "Tensor parallel module expects DTensor or tensor when layout specified but received"
-            f" {type(output)}!",
+            AssertionError,
+            "Expect output of Tensor Parallel to be a DTensor, but found"
+            f" {type(output)}.",
         ):
             func(output, device_mesh)

@@ -209,7 +207,6 @@ class TensorParallelStyleTest(DTensorTestBase):
             tensor,
             tensor,
             rs._prepare_input,
-            error_msgs="No device mesh is currently active!",
         )
         # TODO: change output test
         output, dtensor, device_mesh = self._test_prepare_output(
@@ -235,7 +232,6 @@ class TensorParallelStyleTest(DTensorTestBase):
             tensor,
             tensor,
             cs._prepare_input,
-            error_msgs="No device mesh is currently active!",
         )
         output, dtensor, device_mesh = self._test_prepare_output(
             cs._prepare_output, [Shard(-1)]
diff --git a/torch/distributed/tensor/parallel/style.py b/torch/distributed/tensor/parallel/style.py
index eea31bb5758..4a47a5f7002 100644
--- a/torch/distributed/tensor/parallel/style.py
+++ b/torch/distributed/tensor/parallel/style.py
@@ -494,19 +494,32 @@ class RowwiseParallel(ParallelStyle):
                 "RowwiseParallel only supports single input/output."
             )

+        if _prepare_input is not None:
+            prepare_input_fn = _prepare_input
+        elif input_layouts == Shard(-1):
+            prepare_input_fn = make_input_shard_1d_last_dim
+        else:
+            prepare_input_fn = _get_prepare_input(
+                input_layouts,
+                Shard(-1),
+            )
+
+        if _prepare_output is not None:
+            prepare_output_fn = _prepare_output
+        elif output_layouts == Replicate():
+            prepare_output_fn = make_output_tensor
+        else:
+            prepare_output_fn = _get_prepare_output(
+                output_layouts,
+                use_local_output,
+            )
+
         super().__init__(
             input_layouts=input_layouts,
             output_layouts=output_layouts,
             use_local_output=use_local_output,
-            _prepare_input=_prepare_input
-            if _prepare_input is not None
-            else _get_prepare_input(
-                input_layouts,
-                Shard(-1),
-            ),
-            _prepare_output=_prepare_output
-            if _prepare_output is not None
-            else _get_prepare_output(output_layouts, use_local_output),
+            _prepare_input=prepare_input_fn,
+            _prepare_output=prepare_output_fn,
         )


@@ -568,21 +581,32 @@ class ColwiseParallel(ParallelStyle):
                 "ColwiseParallel only supports single input/output."
             )

+        if _prepare_input is not None:
+            prepare_input_fn = _prepare_input
+        elif input_layouts == Replicate():
+            prepare_input_fn = make_input_replicate_1d
+        else:
+            prepare_input_fn = _get_prepare_input(
+                input_layouts,
+                Replicate(),
+            )
+
+        if _prepare_output is not None:
+            prepare_output_fn = _prepare_output
+        elif output_layouts == Shard(-1):
+            prepare_output_fn = make_sharded_output_tensor
+        else:
+            prepare_output_fn = _get_prepare_output(
+                output_layouts,
+                use_local_output,
+            )
+
         super().__init__(
             input_layouts=input_layouts,
             output_layouts=output_layouts,
             use_local_output=use_local_output,
-            _prepare_input=_prepare_input
-            if _prepare_input is not None
-            else _get_prepare_input(
-                input_layouts,
-                [Replicate()] * len(input_layouts)  # type: ignore[arg-type]
-                if isinstance(input_layouts, tuple)
-                else Replicate(),
-            ),
-            _prepare_output=_prepare_output
-            if _prepare_output is not None
-            else _get_prepare_output(output_layouts, use_local_output),
+            _prepare_input=prepare_input_fn,
+            _prepare_output=prepare_output_fn,
         )
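
Usage note (not part of the patch): below is a minimal sketch of the per-module
parallelize_plan API that the updated test_tp_compile_fullgraph exercises. The
two-layer module and its submodule names ("net1"/"net2") are illustrative
assumptions mirroring the plan used in the test, not code taken from this patch.

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed._tensor import DeviceMesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    # Assumes the default process group is already initialized (e.g. via torchrun).
    mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()))

    class TwoLayer(nn.Module):
        # Hypothetical model, analogous to one MLP block of SimpleModel above.
        def __init__(self):
            super().__init__()
            self.net1 = nn.Linear(10, 16)
            self.net2 = nn.Linear(16, 10)

        def forward(self, x):
            return self.net2(torch.relu(self.net1(x)))

    # Pairing ColwiseParallel (net1) with RowwiseParallel (net2) keeps the
    # intermediate activation sharded on the last dim; net2's partial outputs
    # are then reduced back to a replicated tensor.
    model = parallelize_module(
        TwoLayer().to("cuda"),
        mesh,
        parallelize_plan={
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        },
    )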