[tp] fix torch compile regression (#111521)
The most recent refactor of TP (https://github.com/pytorch/pytorch/pull/111160) breaks the torch.compile path, so revert the behavior by: 1. using the old default prepare_input/output; 2. adding colwise/rowwise parallel tests instead.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111521
Approved by: https://github.com/fduwjj
This commit is contained in: parent 894b9957c8, commit 03e28bde2e
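For context, the compile path exercised by the new test amounts to sharding a small MLP column-wise/row-wise and compiling the parallelized module with fullgraph=True. A minimal standalone sketch follows; ToyMLP is a hypothetical stand-in for the SimpleModel used in the test, and an initialized process group with world_size ranks is assumed:

import torch
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyMLP(nn.Module):
    # Hypothetical stand-in for the SimpleModel used in the test below.
    def __init__(self, device):
        super().__init__()
        self.net1 = nn.Linear(10, 16, device=device)
        self.net2 = nn.Linear(16, 10, device=device)

    def forward(self, x):
        return self.net2(torch.relu(self.net1(x)))


def build_compiled_tp_model(device_type: str, world_size: int):
    # Assumes torch.distributed is already initialized with world_size ranks.
    mesh = DeviceMesh(device_type, torch.arange(world_size))
    model = parallelize_module(
        ToyMLP(device_type),
        mesh,
        parallelize_plan={
            "net1": ColwiseParallel(),  # shard net1 weight column-wise
            "net2": RowwiseParallel(),  # shard net2 weight row-wise
        },
    )
    # The regression fixed here: this graph should compile without breaks.
    return torch.compile(model, backend="aot_eager", fullgraph=True)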
@@ -10,7 +10,12 @@ import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    PairwiseParallel,
+    parallelize_module,
+    RowwiseParallel,
+)
 from torch.distributed.tensor.parallel.fsdp import enable_2d_with_fsdp
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_utils import run_tests
@@ -180,8 +185,17 @@ class TestDTensorCompileE2E(DTensorTestBase):
     def test_tp_compile_fullgraph(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

-        model = MLPModule(self.device_type)
-        model = parallelize_module(model, mesh, PairwiseParallel())
+        model = SimpleModel(self.device_type)
+        model = parallelize_module(
+            model,
+            mesh,
+            parallelize_plan={
+                "mlp_0.net1": ColwiseParallel(),
+                "mlp_0.net2": RowwiseParallel(),
+                "mlp_1.net1": ColwiseParallel(),
+                "mlp_1.net2": RowwiseParallel(),
+            },
+        )
         inp = torch.rand(20, 10, device=self.device_type)
         out = model(inp)
         compiled_mod = torch.compile(model, backend="aot_eager", fullgraph=True)
@@ -33,7 +33,6 @@ class TensorParallelStyleTest(DTensorTestBase):
         input_local_tensor,
         expected_local_tensor,
         func,
-        tensor_input_only=False,
         error_msgs="device_mesh is not passed nor can be inferred",
     ) -> None:
         with self.assertRaisesRegex(RuntimeError, error_msgs):
@@ -43,13 +42,12 @@ class TensorParallelStyleTest(DTensorTestBase):
         # test 1: replicate local tensor
         dtensor = func(input_local_tensor, device_mesh)
         self.assertEqual(expected_local_tensor, dtensor.to_local())
-        if not tensor_input_only:
-            # test 2: replicate DTensor
-            dtensor = func(dtensor)
-            self.assertEqual(expected_local_tensor, dtensor.to_local())
-            # test 3: replicate DTensor with DeviceMesh passed
-            dtensor = func(dtensor, device_mesh)
-            self.assertEqual(expected_local_tensor, dtensor.to_local())
+        # test 2: replicate DTensor
+        dtensor = func(dtensor)
+        self.assertEqual(expected_local_tensor, dtensor.to_local())
+        # test 3: replicate DTensor with DeviceMesh passed
+        dtensor = func(dtensor, device_mesh)
+        self.assertEqual(expected_local_tensor, dtensor.to_local())

     @with_comms
     def test_make_input_replicate_1d(self):
@@ -189,9 +187,9 @@ class TensorParallelStyleTest(DTensorTestBase):
         dtensor = distribute_tensor(tensor, device_mesh, [Shard(0)])
         output = [dtensor]
         with self.assertRaisesRegex(
-            RuntimeError,
-            "Tensor parallel module expects DTensor or tensor when layout specified but received"
-            f" {type(output)}!",
+            AssertionError,
+            "Expect output of Tensor Parallel to be a DTensor, but found"
+            f" {type(output)}.",
         ):
             func(output, device_mesh)

@@ -209,7 +207,6 @@ class TensorParallelStyleTest(DTensorTestBase):
             tensor,
             tensor,
             rs._prepare_input,
-            error_msgs="No device mesh is currently active!",
         )
         # TODO: change output test
         output, dtensor, device_mesh = self._test_prepare_output(
@@ -235,7 +232,6 @@ class TensorParallelStyleTest(DTensorTestBase):
             tensor,
             tensor,
             cs._prepare_input,
-            error_msgs="No device mesh is currently active!",
         )
         output, dtensor, device_mesh = self._test_prepare_output(
             cs._prepare_output, [Shard(-1)]
@@ -494,19 +494,32 @@ class RowwiseParallel(ParallelStyle):
                 "RowwiseParallel only supports single input/output."
             )

+        if _prepare_input is not None:
+            prepare_input_fn = _prepare_input
+        elif input_layouts == Shard(-1):
+            prepare_input_fn = make_input_shard_1d_last_dim
+        else:
+            prepare_input_fn = _get_prepare_input(
+                input_layouts,
+                Shard(-1),
+            )
+
+        if _prepare_output is not None:
+            prepare_output_fn = _prepare_output
+        elif output_layouts == Replicate():
+            prepare_output_fn = make_output_tensor
+        else:
+            prepare_output_fn = _get_prepare_output(
+                output_layouts,
+                use_local_output,
+            )
+
         super().__init__(
             input_layouts=input_layouts,
             output_layouts=output_layouts,
             use_local_output=use_local_output,
-            _prepare_input=_prepare_input
-            if _prepare_input is not None
-            else _get_prepare_input(
-                input_layouts,
-                Shard(-1),
-            ),
-            _prepare_output=_prepare_output
-            if _prepare_output is not None
-            else _get_prepare_output(output_layouts, use_local_output),
+            _prepare_input=prepare_input_fn,
+            _prepare_output=prepare_output_fn,
         )

@@ -568,21 +581,32 @@ class ColwiseParallel(ParallelStyle):
                 "ColwiseParallel only supports single input/output."
             )

+        if _prepare_input is not None:
+            prepare_input_fn = _prepare_input
+        elif input_layouts == Replicate():
+            prepare_input_fn = make_input_replicate_1d
+        else:
+            prepare_input_fn = _get_prepare_input(
+                input_layouts,
+                Replicate(),
+            )
+
+        if _prepare_output is not None:
+            prepare_output_fn = _prepare_output
+        elif output_layouts == Shard(-1):
+            prepare_output_fn = make_sharded_output_tensor
+        else:
+            prepare_output_fn = _get_prepare_output(
+                output_layouts,
+                use_local_output,
+            )
+
         super().__init__(
             input_layouts=input_layouts,
             output_layouts=output_layouts,
             use_local_output=use_local_output,
-            _prepare_input=_prepare_input
-            if _prepare_input is not None
-            else _get_prepare_input(
-                input_layouts,
-                [Replicate()] * len(input_layouts)  # type: ignore[arg-type]
-                if isinstance(input_layouts, tuple)
-                else Replicate(),
-            ),
-            _prepare_output=_prepare_output
-            if _prepare_output is not None
-            else _get_prepare_output(output_layouts, use_local_output),
+            _prepare_input=prepare_input_fn,
+            _prepare_output=prepare_output_fn,
         )

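In effect, when no custom _prepare_input/_prepare_output hooks are given and the layouts match the branches above, both styles fall back to the pre-refactor helper functions, which is what keeps the torch.compile path traceable again. A rough sketch of constructing the styles with those layouts (assuming the constructors accept them as keyword arguments, as the diff suggests):

from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

# With these layouts and no custom hooks, the branches above select the old
# defaults: make_input_replicate_1d / make_sharded_output_tensor for colwise,
# make_input_shard_1d_last_dim / make_output_tensor for rowwise.
colwise = ColwiseParallel(input_layouts=Replicate(), output_layouts=Shard(-1))
rowwise = RowwiseParallel(input_layouts=Shard(-1), output_layouts=Replicate())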