[tp] fix torch compile regression (#111521)

The most recent TP refactor
(https://github.com/pytorch/pytorch/pull/111160) breaks the torch.compile
path, so this reverts the behavior by:
1. using the old default prepare_input/output functions
2. adding colwise/rowwise parallel tests instead of the PairwiseParallel one (a minimal usage sketch follows the commit metadata below)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111521
Approved by: https://github.com/fduwjj
Wanchao Liang 2023-10-18 22:21:11 -07:00 committed by PyTorch MergeBot
parent 894b9957c8
commit 03e28bde2e
3 changed files with 70 additions and 36 deletions
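For context, a minimal sketch of the path this change restores, assuming a process group is already initialized (e.g. via torchrun); the TinyMLP module and its submodule names are illustrative stand-ins for the SimpleModel/MLPModule used in the test below, while ColwiseParallel, RowwiseParallel, parallelize_module, and the torch.compile call mirror the diff:

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)


class TinyMLP(nn.Module):
    # Illustrative stand-in for the MLP used by the test.
    def __init__(self, device):
        super().__init__()
        self.net1 = nn.Linear(10, 16, device=device)
        self.net2 = nn.Linear(16, 10, device=device)

    def forward(self, x):
        return self.net2(torch.relu(self.net1(x)))


# Assumes torch.distributed is already initialized (e.g. via torchrun).
mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()))

# Shard the first linear column-wise and the second row-wise, keyed by module FQN.
model = parallelize_module(
    TinyMLP("cuda"),
    mesh,
    parallelize_plan={
        "net1": ColwiseParallel(),
        "net2": RowwiseParallel(),
    },
)

# The regression fixed here: tracing the TP-wrapped model with fullgraph=True.
compiled_mod = torch.compile(model, backend="aot_eager", fullgraph=True)
out = compiled_mod(torch.rand(20, 10, device="cuda"))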

View File

@@ -10,7 +10,12 @@ import torch.distributed as dist
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module
from torch.distributed.tensor.parallel import (
ColwiseParallel,
PairwiseParallel,
parallelize_module,
RowwiseParallel,
)
from torch.distributed.tensor.parallel.fsdp import enable_2d_with_fsdp
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
@@ -180,8 +185,17 @@ class TestDTensorCompileE2E(DTensorTestBase):
def test_tp_compile_fullgraph(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
model = MLPModule(self.device_type)
model = parallelize_module(model, mesh, PairwiseParallel())
model = SimpleModel(self.device_type)
model = parallelize_module(
model,
mesh,
parallelize_plan={
"mlp_0.net1": ColwiseParallel(),
"mlp_0.net2": RowwiseParallel(),
"mlp_1.net1": ColwiseParallel(),
"mlp_1.net2": RowwiseParallel(),
},
)
inp = torch.rand(20, 10, device=self.device_type)
out = model(inp)
compiled_mod = torch.compile(model, backend="aot_eager", fullgraph=True)

View File

@@ -33,7 +33,6 @@ class TensorParallelStyleTest(DTensorTestBase):
input_local_tensor,
expected_local_tensor,
func,
tensor_input_only=False,
error_msgs="device_mesh is not passed nor can be inferred",
) -> None:
with self.assertRaisesRegex(RuntimeError, error_msgs):
@@ -43,13 +42,12 @@ class TensorParallelStyleTest(DTensorTestBase):
# test 1: replicate local tensor
dtensor = func(input_local_tensor, device_mesh)
self.assertEqual(expected_local_tensor, dtensor.to_local())
if not tensor_input_only:
# test 2: replicate DTensor
dtensor = func(dtensor)
self.assertEqual(expected_local_tensor, dtensor.to_local())
# test 3: replicate DTensor with DeviceMesh passed
dtensor = func(dtensor, device_mesh)
self.assertEqual(expected_local_tensor, dtensor.to_local())
# test 2: replicate DTensor
dtensor = func(dtensor)
self.assertEqual(expected_local_tensor, dtensor.to_local())
# test 3: replicate DTensor with DeviceMesh passed
dtensor = func(dtensor, device_mesh)
self.assertEqual(expected_local_tensor, dtensor.to_local())
@with_comms
def test_make_input_replicate_1d(self):
@@ -189,9 +187,9 @@ class TensorParallelStyleTest(DTensorTestBase):
dtensor = distribute_tensor(tensor, device_mesh, [Shard(0)])
output = [dtensor]
with self.assertRaisesRegex(
RuntimeError,
"Tensor parallel module expects DTensor or tensor when layout specified but received"
f" {type(output)}!",
AssertionError,
"Expect output of Tensor Parallel to be a DTensor, but found"
f" {type(output)}.",
):
func(output, device_mesh)
@@ -209,7 +207,6 @@ class TensorParallelStyleTest(DTensorTestBase):
tensor,
tensor,
rs._prepare_input,
error_msgs="No device mesh is currently active!",
)
# TODO: change output test
output, dtensor, device_mesh = self._test_prepare_output(
@@ -235,7 +232,6 @@ class TensorParallelStyleTest(DTensorTestBase):
tensor,
tensor,
cs._prepare_input,
error_msgs="No device mesh is currently active!",
)
output, dtensor, device_mesh = self._test_prepare_output(
cs._prepare_output, [Shard(-1)]

View File

@@ -494,19 +494,32 @@ class RowwiseParallel(ParallelStyle):
"RowwiseParallel only supports single input/output."
)
if _prepare_input is not None:
prepare_input_fn = _prepare_input
elif input_layouts == Shard(-1):
prepare_input_fn = make_input_shard_1d_last_dim
else:
prepare_input_fn = _get_prepare_input(
input_layouts,
Shard(-1),
)
if _prepare_output is not None:
prepare_output_fn = _prepare_output
elif output_layouts == Replicate():
prepare_output_fn = make_output_tensor
else:
prepare_output_fn = _get_prepare_output(
output_layouts,
use_local_output,
)
super().__init__(
input_layouts=input_layouts,
output_layouts=output_layouts,
use_local_output=use_local_output,
_prepare_input=_prepare_input
if _prepare_input is not None
else _get_prepare_input(
input_layouts,
Shard(-1),
),
_prepare_output=_prepare_output
if _prepare_output is not None
else _get_prepare_output(output_layouts, use_local_output),
_prepare_input=prepare_input_fn,
_prepare_output=prepare_output_fn,
)
@@ -568,21 +581,32 @@ class ColwiseParallel(ParallelStyle):
"ColwiseParallel only supports single input/output."
)
if _prepare_input is not None:
prepare_input_fn = _prepare_input
elif input_layouts == Replicate():
prepare_input_fn = make_input_replicate_1d
else:
prepare_input_fn = _get_prepare_input(
input_layouts,
Replicate(),
)
if _prepare_output is not None:
prepare_output_fn = _prepare_output
elif output_layouts == Shard(-1):
prepare_output_fn = make_sharded_output_tensor
else:
prepare_output_fn = _get_prepare_output(
output_layouts,
use_local_output,
)
super().__init__(
input_layouts=input_layouts,
output_layouts=output_layouts,
use_local_output=use_local_output,
_prepare_input=_prepare_input
if _prepare_input is not None
else _get_prepare_input(
input_layouts,
[Replicate()] * len(input_layouts) # type: ignore[arg-type]
if isinstance(input_layouts, tuple)
else Replicate(),
),
_prepare_output=_prepare_output
if _prepare_output is not None
else _get_prepare_output(output_layouts, use_local_output),
_prepare_input=prepare_input_fn,
_prepare_output=prepare_output_fn,
)
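Taken together, the two hunks above restore the pre-#111160 defaults: when a style is constructed with its default layouts and no custom prepare functions, the old special-cased helpers (make_input_replicate_1d, make_sharded_output_tensor, make_input_shard_1d_last_dim, make_output_tensor) are selected directly instead of the generic _get_prepare_input/_get_prepare_output wrappers, which is what this commit relies on for the torch.compile path. A hedged sketch of that selection from the caller's side; the keyword names come from this diff, while the exact default layouts described in the comments are assumptions inferred from the helper names:

from torch.distributed._tensor import Replicate
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

# With the (assumed) default layouts, the restored fast-path helpers are picked:
#   ColwiseParallel(): input Replicate()  -> make_input_replicate_1d,
#                      output Shard(-1)   -> make_sharded_output_tensor
#   RowwiseParallel(): input Shard(-1)    -> make_input_shard_1d_last_dim,
#                      output Replicate() -> make_output_tensor
col = ColwiseParallel()
row = RowwiseParallel()

# Non-default layouts still route through the generic wrappers
# (_get_prepare_input / _get_prepare_output), and an explicitly passed
# _prepare_input/_prepare_output always takes precedence.
col_custom = ColwiseParallel(output_layouts=Replicate(), use_local_output=True)
row_custom = RowwiseParallel(input_layouts=Replicate())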