[Distributed][CI] Rework continuous TestCase (#153653)

1. Reworked `MultiProcContinousTest` to spawn its worker processes during `setUpClass` instead of in `main`, so that multiple test classes can be supported in one file.

2. The child processes now run an infinite loop, waiting for test IDs passed from the main process via a task queue. In return, the child processes report completion of each test to the main process via a completion queue (see the sketch after this list).

3. Added a test template.
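
The dispatcher/worker queue protocol introduced here looks roughly like the sketch below. This is illustrative only, not the committed implementation (which lives in `MultiProcContinousTest._worker_loop`, `setUpClass`, and `_worker_run_main_wait` in `common_distributed.py`); the test ID string and `world_size` value are placeholders.

import torch.multiprocessing as mp

def _worker_loop(rank, task_queue, completion_queue):
    # Child process: loop forever, waiting for test IDs from the dispatcher.
    while True:
        test_id = task_queue.get()
        if test_id is None:  # poison pill: exit the loop
            break
        try:
            # ... look up and run the test named by `test_id` on this rank ...
            completion_queue.put(test_id)  # report success
        except BaseException as ex:
            completion_queue.put(ex)  # send the failure back to the dispatcher

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    world_size = 2  # placeholder; the real code uses the visible device count
    task_queues = [ctx.Queue() for _ in range(world_size)]
    completion_queues = [ctx.Queue() for _ in range(world_size)]
    workers = [
        ctx.Process(
            target=_worker_loop,
            args=(r, task_queues[r], completion_queues[r]),
            daemon=True,
        )
        for r in range(world_size)
    ]
    for w in workers:
        w.start()
    # Dispatch one test ID to every worker, then wait for all of them.
    for q in task_queues:
        q.put("TestTemplate.testABC")
    for q in completion_queues:
        result = q.get()
        if isinstance(result, BaseException):
            raise result
    # Tell the workers to exit, then join them.
    for q in task_queues:
        q.put(None)
    for w in workers:
        w.join()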

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153653
Approved by: https://github.com/d4l3k, https://github.com/fegin, https://github.com/fduwjj
Authored by Ke Wen on 2025-05-23 23:01:28 -07:00; committed by PyTorch MergeBot
parent 03e102dbe8
commit 9d922b55ef
8 changed files with 437 additions and 329 deletions

View File

@ -0,0 +1,16 @@
# Owner(s): ["oncall: distributed"]
from torch.testing._internal.common_distributed import MultiProcContinousTest
from torch.testing._internal.common_utils import run_tests
class TestTemplate(MultiProcContinousTest):
def testABC(self):
print(f"rank {self.rank} of {self.world_size} testing ABC")
def testDEF(self):
print(f"rank {self.rank} of {self.world_size} testing DEF")
if __name__ == "__main__":
run_tests()
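
Because the worker pool is now created per class in `setUpClass` (and torn down in `tearDownClass`), additional test classes can share the same file. A minimal illustration, with a hypothetical class and test name, assuming the same imports as the template above:

class AnotherTestTemplate(MultiProcContinousTest):
    # Gets its own pool of worker processes, spawned in setUpClass.
    def testGHI(self):
        print(f"rank {self.rank} of {self.world_size} testing GHI")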

View File

@ -7,13 +7,16 @@ from torch.distributed.pipelining import pipe_split, SplitPoint
class ExampleCode(torch.nn.Module):
def __init__(self, d_hid):
def __init__(self, d_hid, splits=2):
assert splits <= 4
super().__init__()
self.splits = splits
self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
self.cval = torch.nn.Buffer(torch.randn((d_hid,), requires_grad=False))
self.lin0 = torch.nn.Linear(d_hid, d_hid)
self.lin1 = torch.nn.Linear(d_hid, d_hid)
self.lin2 = torch.nn.Linear(d_hid, d_hid)
def forward(self, x):
x = torch.mm(x, self.mm_param0)
@ -24,8 +27,14 @@ class ExampleCode(torch.nn.Module):
pipe_split()
x = torch.relu(x) + a_constant
x = torch.mm(x, self.mm_param1)
x = self.lin1(x)
x = torch.relu(x)
if self.splits > 2:
pipe_split()
x = self.lin1(x)
x = torch.relu(x)
if self.splits > 3:
pipe_split()
x = self.lin2(x)
x = torch.relu(x)
return x
@ -33,12 +42,16 @@ class ModelWithKwargs(torch.nn.Module):
DEFAULT_DHID = 512
DEFAULT_BATCH_SIZE = 256
def __init__(self, d_hid: int = DEFAULT_DHID):
def __init__(self, d_hid: int = DEFAULT_DHID, splits=2):
assert splits <= 4
super().__init__()
self.splits = splits
self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
self.lin0 = torch.nn.Linear(d_hid, d_hid)
self.lin1 = torch.nn.Linear(d_hid, d_hid)
self.lin2 = torch.nn.Linear(d_hid, d_hid)
self.lin3 = torch.nn.Linear(d_hid, d_hid)
def forward(self, x, y=torch.zeros(DEFAULT_BATCH_SIZE, DEFAULT_DHID)):
x = torch.mm(x, self.mm_param0)
@ -49,6 +62,14 @@ class ModelWithKwargs(torch.nn.Module):
x = torch.mm(x, self.mm_param1)
x = self.lin1(x)
x = torch.relu(x)
if self.splits > 2:
pipe_split()
x = self.lin2(x)
x = torch.relu(x)
if self.splits > 3:
pipe_split()
x = self.lin3(x)
x = torch.relu(x)
return x

View File

@ -2,8 +2,6 @@
# Owner(s): ["oncall: distributed"]
import copy
import logging
import os
import sys
import tempfile
from model_registry import ModelWithKwargs, MultiMLP, MultiMLPWithDw
@ -37,6 +35,7 @@ from torch.testing._internal.common_utils import (
check_leaked_tensors,
instantiate_parametrized_tests,
parametrize,
run_tests,
skip_but_pass_in_sandcastle_if,
)
@ -48,6 +47,8 @@ batch_size = 256
torch.manual_seed(0)
device_type = "cuda"
class ScheduleTest(MultiProcContinousTest):
@classmethod
@ -55,15 +56,9 @@ class ScheduleTest(MultiProcContinousTest):
# Testing with NCCL backend
return "nccl"
@classmethod
def setUpClass(cls):
"""
Class-scope test fixture. Run once for entire test class, before any test starts.
Set up the device.
"""
super().setUpClass()
dev_id = cls.rank % torch.cuda.device_count()
cls.device = torch.device(f"cuda:{dev_id}")
@property
def device(self) -> torch.device:
return torch.device(device_type, self.rank)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@ -77,7 +72,7 @@ class ScheduleTest(MultiProcContinousTest):
x = torch.randn(batch_size, d_hid, device=self.device)
x_clone = x.clone()
num_microbatches = 4
num_microbatches = 2 * self.world_size
x_mb = x.chunk(num_microbatches)[0]
# Create a pipeline
@ -159,6 +154,12 @@ class ScheduleTest(MultiProcContinousTest):
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
def test_kwargs_with_tracer(self, ScheduleClass):
# Model has two stages only, thus limiting group size to 2
group_size = 2
group = dist.new_group(list(range(group_size)))
if self.rank >= group_size:
return
mod = ModelWithKwargs(d_hid)
mod.to(self.device)
@ -180,6 +181,7 @@ class ScheduleTest(MultiProcContinousTest):
stage = pipe.build_stage(
self.rank,
self.device,
group=group,
)
# Attach to a schedule
@ -188,16 +190,16 @@ class ScheduleTest(MultiProcContinousTest):
# Run
if self.rank == 0:
schedule.step(x, y=y)
elif self.rank == self.world_size - 1:
elif self.rank == group_size - 1:
losses = []
out = schedule.step(target=target, losses=losses)
else:
schedule.step()
dist.barrier()
# dist.barrier()
# Last rank checks result
if self.rank == self.world_size - 1:
if self.rank == group_size - 1:
ref_out = mod(x, y=y)
ref_loss = loss_fn(ref_out, target)
pipe_loss = sum(losses)
@ -207,9 +209,8 @@ class ScheduleTest(MultiProcContinousTest):
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@parametrize("ModelClass", [MultiMLP])
def test_grad_with_tracer(self, ScheduleClass, ModelClass):
mod = ModelClass(d_hid)
def test_grad_with_tracer(self, ScheduleClass):
mod = MultiMLP(d_hid, n_layers=self.world_size)
mod.to(self.device)
ref_mod = copy.deepcopy(mod)
@ -229,7 +230,7 @@ class ScheduleTest(MultiProcContinousTest):
ref_loss.backward()
# Create a pipeline
chunks = 4
chunks = 2 * self.world_size
x_mb = x.chunk(chunks)[0]
split_spec = mod.split_spec if hasattr(mod, "split_spec") else None
pipe = pipeline(
@ -307,7 +308,7 @@ class ScheduleTest(MultiProcContinousTest):
# Get a submodule, e.g. `layers.0` or `layers.1`
submod_name = f"layers.{self.rank}"
stage_module = full_mod.get_submodule(submod_name)
chunks = 4
chunks = 2 * self.world_size
if shape_inference:
input_args = None
@ -410,7 +411,7 @@ class ScheduleTest(MultiProcContinousTest):
num_microbatches = (
ScheduleClass.num_microbatches
if hasattr(ScheduleClass, "num_microbatches")
else 8
else 2 * self.world_size
)
stages = [
PipelineStage(
@ -518,13 +519,15 @@ class ScheduleTest(MultiProcContinousTest):
raise
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@skip_but_pass_in_sandcastle_if(
not torch.cuda.device_count() == 2, "This test requires exactly 2 GPUs"
)
@parametrize("ScheduleClass", [ScheduleWithW, ScheduleInterleavedZeroBubble])
def test_schedule_with_native_zero_bubble(self, ScheduleClass):
print(ScheduleClass)
if ScheduleClass is ScheduleInterleavedZeroBubble:
n_stages = 4
num_microbatches = 8
num_microbatches = 2 * n_stages
rank_stages = {
0: [0, 2],
1: [1, 3],
@ -612,7 +615,9 @@ class ScheduleTest(MultiProcContinousTest):
raise
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@skip_but_pass_in_sandcastle_if(
not torch.cuda.device_count() == 2, "This test requires exactly 2 GPUs"
)
@parametrize(
"ScheduleClass",
[
@ -717,7 +722,9 @@ class ScheduleTest(MultiProcContinousTest):
raise
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@skip_but_pass_in_sandcastle_if(
not torch.cuda.device_count() == 2, "This test requires exactly 2 GPUs"
)
@parametrize(
"schedule_class", [ScheduleVShaped, ScheduleUnbalanced, ScheduleZBVZeroBubble]
)
@ -822,7 +829,9 @@ class ScheduleTest(MultiProcContinousTest):
raise
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@skip_but_pass_in_sandcastle_if(
not torch.cuda.device_count() == 2, "This test requires exactly 2 GPUs"
)
@parametrize("ScheduleClass", [ScheduleInterleavedZeroBubble])
def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass):
stages_per_rank = 2
@ -942,30 +951,4 @@ instantiate_parametrized_tests(ScheduleTest)
if __name__ == "__main__":
# Check if GPU and NCCL are available
if not (
dist.is_available()
and dist.is_nccl_available()
and torch.cuda.device_count() > 1
):
print(
"c10d NCCL not available or not enough GPUs, skipping tests",
file=sys.stderr,
)
sys.exit(0)
rank = int(os.getenv("RANK", -1))
world_size = int(os.getenv("WORLD_SIZE", 2))
if rank != -1:
# Launched with torchrun or other multi-proc launchers. Directly run the test.
ScheduleTest.run_rank(rank, world_size)
else:
# Launched as a single process. Spawn subprocess to run the tests.
# Also need a rendezvous file for `init_process_group` purpose.
rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
torch.multiprocessing.spawn(
ScheduleTest.run_rank,
nprocs=world_size,
args=(world_size, rdvz_file),
)
run_tests()

View File

@ -1,8 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import os
import sys
import tempfile
from model_registry import ExampleCode, ModelWithKwargs, MultiMLP
@ -18,11 +17,14 @@ from torch.distributed.pipelining._utils import PipeliningShapeError
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
MultiProcContinousTest,
MultiProcessTestCase,
requires_nccl,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
skip_but_pass_in_sandcastle,
skip_but_pass_in_sandcastle_if,
)
from torch.utils._pytree import tree_map_only
@ -32,6 +34,8 @@ d_hid = 512
batch_size = 256
chunks = 4
device_type = "cuda"
torch.manual_seed(0)
@ -66,20 +70,18 @@ class StageTest(MultiProcContinousTest):
return "nccl"
@classmethod
def setUpClass(cls):
"""
Class-scope test fixture. Run once for entire test class, before any test starts.
Set up the device.
"""
super().setUpClass()
dev_id = cls.rank % torch.cuda.device_count()
cls.device = torch.device(f"cuda:{dev_id}")
def device_type(cls) -> str:
return device_type
@property
def device(self) -> torch.device:
return torch.device(device_type, self.rank)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("ModelClass", [ExampleCode, MultiMLP])
def test_tracer(self, ModelClass):
mod = ModelClass(d_hid)
mod = ModelClass(d_hid, self.world_size)
mod.to(self.device)
x = torch.randn(batch_size, d_hid, device=self.device)
@ -119,32 +121,11 @@ class StageTest(MultiProcContinousTest):
old_keys = mod.state_dict().keys()
assert all(k in old_keys for k in submod_keys)
if self.rank == 0:
# intended to run this code on all ranks, but the problem is if rank0 throws,
# it won't perform the send that unblocks rank 1.
# TODO(whc) can't test this until fixing args/kwargs issue
# with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
# _run_step(torch.randn(batch_size + 1, d_hid, device=self.device))
with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
_run_step(x.to(torch.int32))
# output of stage's mlp layer will be flattened by this hook, the stage should err
handle = stage.submod.register_forward_hook(get_flatten_hook())
with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
_run_step(x)
handle.remove()
stage.submod.register_forward_hook(get_dtype_change_hook(torch.bfloat16))
with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
_run_step(x)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("ModelClass", [ModelWithKwargs])
def test_tracer_kwargs(self, ModelClass):
mod = ModelClass(d_hid)
mod = ModelClass(d_hid, self.world_size)
mod.to(self.device)
x = torch.randn(batch_size, d_hid, device=self.device)
@ -221,23 +202,6 @@ class StageTest(MultiProcContinousTest):
ref_out = full_mod(x)
torch.testing.assert_close(out, ref_out)
if self.rank == 0:
with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
_run_step(torch.randn(batch_size + 1, d_hid, device=self.device))
with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
_run_step(x.to(torch.int32))
# output of stage's mlp layer will be flattened by this hook, the stage should err
handle = stage_mod.register_forward_hook(get_flatten_hook())
with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
_run_step(x)
handle.remove()
stage_mod.register_forward_hook(get_dtype_change_hook(torch.bfloat16))
with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
_run_step(x)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_custom_dw_with_fb_schedule(self):
@ -298,28 +262,6 @@ class StageTest(MultiProcContinousTest):
ref_out = full_mod(x)
torch.testing.assert_close(out, ref_out)
if self.rank == 0:
with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
_run_step(torch.randn(batch_size + 1, d_hid, device=self.device))
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_custom_dw_errors(self):
"""Tests expected errors are raised"""
full_mod = MultiMLP(d_hid, n_layers=self.world_size)
full_mod.to(self.device)
stage_mod = full_mod.get_submodule(f"layers.{self.rank}")
stage_with_dw_builder = PipelineStage(
stage_mod,
self.rank,
self.world_size,
self.device,
dw_builder=lambda: None,
)
with self.assertRaisesRegex(AssertionError, "backward_one_chunk"):
stage_with_dw_builder.backward_weight_one_chunk(bwd_chunk_id=0)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_output_chunks_memory_usage(self):
@ -381,31 +323,105 @@ class StageTest(MultiProcContinousTest):
instantiate_parametrized_tests(StageTest)
class StageNegativeTest(MultiProcessTestCase):
@property
def world_size(self) -> int:
return torch.get_device_module(device_type).device_count()
@property
def device(self) -> torch.device:
return torch.device(device_type, self.rank)
def setUp(self):
super().setUp()
self._spawn_processes()
def tearDown(self):
super().tearDown()
try:
os.remove(self.file_name)
except OSError:
pass
def init_pg(self):
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
backend="nccl",
store=store,
rank=self.rank,
world_size=self.world_size,
device_id=self.device,
)
@requires_nccl()
@skip_but_pass_in_sandcastle("Flaky in CI")
def test_shape_prop_mismatch(self):
"""Tests shape prop errors are raised"""
self.init_pg()
full_mod = MultiMLP(d_hid, n_layers=self.world_size)
full_mod.to(self.device)
stage_mod = full_mod.get_submodule(f"layers.{self.rank}")
x = torch.randn(batch_size, d_hid, device=self.device)
stage = PipelineStage(
stage_mod,
self.rank,
self.world_size,
self.device,
)
# Attach to a schedule
schedule = ScheduleGPipe(stage, chunks)
# Run
def _run_step(x):
if self.rank == 0:
return schedule.step(x)
else:
return schedule.step()
_run_step(x)
if self.rank == 0:
with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
_run_step(torch.randn(batch_size + 1, d_hid, device=self.device))
with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
_run_step(x.to(torch.int32))
# output of stage's mlp layer will be flattened by this hook, the stage should err
handle = stage_mod.register_forward_hook(get_flatten_hook())
with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
_run_step(x)
handle.remove()
stage_mod.register_forward_hook(get_dtype_change_hook(torch.bfloat16))
with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
_run_step(x)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_custom_dw_errors(self):
"""Tests expected errors are raised"""
self.init_pg()
full_mod = MultiMLP(d_hid, n_layers=self.world_size)
full_mod.to(self.device)
stage_mod = full_mod.get_submodule(f"layers.{self.rank}")
stage_with_dw_builder = PipelineStage(
stage_mod,
self.rank,
self.world_size,
self.device,
dw_builder=lambda: None,
)
with self.assertRaisesRegex(AssertionError, "backward_one_chunk"):
stage_with_dw_builder.backward_weight_one_chunk(bwd_chunk_id=0)
if __name__ == "__main__":
# Check if GPU and NCCL are available
if not (
dist.is_available()
and dist.is_nccl_available()
and torch.cuda.device_count() > 1
):
print(
"c10d NCCL not available or not enough GPUs, skipping tests",
file=sys.stderr,
)
sys.exit(0)
rank = int(os.getenv("RANK", -1))
world_size = int(os.getenv("WORLD_SIZE", 2))
if rank != -1:
# Launched with torchrun or other multi-proc launchers. Directly run the test.
StageTest.run_rank(rank, world_size)
else:
# Launched as a single process. Spawn subprocess to run the tests.
# Also need a rendezvous file for `init_process_group` purpose.
rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
torch.multiprocessing.spawn(
StageTest.run_rank,
nprocs=world_size,
args=(world_size, rdvz_file),
)
run_tests()

View File

@ -11,7 +11,6 @@
import math
import os
import sys
import tempfile
import torch
import torch.distributed as c10d
@ -30,9 +29,9 @@ from torch.testing._internal.common_distributed import (
requires_nccl,
requires_nccl_version,
sm_is_or_higher_than,
TEST_SKIPS,
)
from torch.testing._internal.common_utils import (
run_tests,
skip_but_pass_in_sandcastle_if,
skipIfRocm,
TEST_WITH_DEV_DBG_ASAN,
@ -1044,24 +1043,4 @@ class ProcessGroupNCCLOpTest(MultiProcContinousTest):
if __name__ == "__main__":
if not torch.cuda.is_available():
sys.exit(TEST_SKIPS["no_cuda"].exit_code)
rank = int(os.getenv("RANK", -1))
world_size = int(os.getenv("WORLD_SIZE", -1))
if world_size == -1: # Not set by external launcher
world_size = torch.cuda.device_count()
if rank != -1:
# Launched with torchrun or other multi-proc launchers. Directly run the test.
ProcessGroupNCCLOpTest.run_rank(rank, world_size)
else:
# Launched as a single process. Spawn subprocess to run the tests.
# Also need a rendezvous file for `init_process_group` purpose.
rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
torch.multiprocessing.spawn(
ProcessGroupNCCLOpTest.run_rank,
nprocs=world_size,
args=(world_size, rdvz_file),
)
run_tests()

View File

@ -1,11 +1,7 @@
# Owner(s): ["oncall: distributed"]
import copy
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
@ -30,11 +26,15 @@ from torch.testing._internal.common_distributed import (
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
skip_but_pass_in_sandcastle_if,
TEST_WITH_ROCM,
)
device_type = "cuda"
# MLP Layer
class MLPModule(torch.nn.Module):
def __init__(self, d_hid: int):
@ -92,29 +92,14 @@ def loss_fn(y, target, scale=1e-4):
class ComposabilityTest(MultiProcContinousTest):
world_size = 4
@classmethod
def backend_str(cls) -> str:
# Testing with NCCL backend
return "nccl"
@classmethod
def setUpClass(cls):
"""
Class-scope test fixture. Run once for entire test class, before any test starts.
Set up the device.
"""
super().setUpClass()
dev_id = cls.rank % torch.cuda.device_count()
cls.device = torch.device(f"cuda:{dev_id}")
torch.cuda.set_device(cls.device)
def _build_mesh(self, mesh_shape=(2, 2), mesh_dim_names=("dp", "pp")):
device_mesh = init_device_mesh(
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
)
return device_mesh
@property
def device(self) -> torch.device:
return torch.device(device_type, self.rank)
def _rand_microbatches(self, dp_mesh, num_microbatches, dim, dtype=torch.float32):
full = [
@ -216,7 +201,12 @@ class ComposabilityTest(MultiProcContinousTest):
# https://github.com/pytorch/pytorch/issues/144530
return
device_mesh = self._build_mesh((2, 2), ("dp", "pp"))
torch.get_device_module(device_type).set_device(self.device)
mesh_shape = (self.world_size // 2, 2)
mesh_dim_names = ("dp", "pp")
device_mesh = init_device_mesh(
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
)
pp_group = device_mesh["pp"].get_group()
dp_mesh = device_mesh["dp"]
@ -292,7 +282,12 @@ class ComposabilityTest(MultiProcContinousTest):
if TEST_WITH_ROCM:
return
device_mesh = self._build_mesh((2, 2), ("dp", "pp"))
torch.get_device_module(device_type).set_device(self.device)
mesh_shape = (self.world_size // 2, 2)
mesh_dim_names = ("dp", "pp")
device_mesh = init_device_mesh(
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
)
pp_group = device_mesh["pp"].get_group()
dp_mesh = device_mesh["dp"]
@ -376,35 +371,12 @@ class ComposabilityTest(MultiProcContinousTest):
name = ".".join(parts)
ref_p = ref_parameters[name]
self.assertTrue(isinstance(p.grad, DTensor))
torch.testing.assert_close(p.grad.full_tensor(), ref_p.grad)
torch.testing.assert_close(
p.grad.full_tensor(), ref_p.grad, atol=5e-5, rtol=2e-2
)
instantiate_parametrized_tests(ComposabilityTest)
if __name__ == "__main__":
# Check if GPU and NCCL are available
if not (
dist.is_available()
and dist.is_nccl_available()
and torch.cuda.device_count() > 3
):
print(
"c10d NCCL not available or not enough GPUs, skipping tests",
file=sys.stderr,
)
sys.exit(0)
rank = int(os.getenv("RANK", -1))
world_size = int(os.getenv("WORLD_SIZE", 4))
if rank != -1:
# Launched with torchrun or other multi-proc launchers. Directly run the test.
ComposabilityTest.run_rank(rank, world_size)
else:
# Launched as a single process. Spawn subprocess to run the tests.
# Also need a rendezvous file for `init_process_group` purpose.
rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
torch.multiprocessing.spawn(
ComposabilityTest.run_rank,
nprocs=world_size,
args=(world_size, rdvz_file),
)
run_tests()

View File

@ -7,16 +7,13 @@
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch.testing._internal.common_distributed import (
MultiProcContinousTest,
TEST_SKIPS,
)
from torch.testing._internal.common_distributed import MultiProcContinousTest
from torch.testing._internal.common_utils import (
run_tests,
skip_but_pass_in_sandcastle_if,
skipIfRocm,
)
@ -47,28 +44,20 @@ device_module = torch.get_device_module(device_type)
@requires_nvshmem()
class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
def setUp(self) -> None:
super().setUp()
def _init_device(self) -> None:
# TODO: relieve this (seems to hang if without)
device_module.set_device(self.device)
# NOTE: required for nvshmem allocation
torch.empty(1, device=self.device)
# Required by MultiProcContinousTest
@classmethod
def backend_str(cls) -> str:
return "nccl"
@property
def world_size(self) -> int:
return device_module.device_count()
@property
def device(self) -> torch.device:
return torch.device(device_type, self.rank)
@skipIfRocm
def test_nvshmem_all_to_all(self) -> None:
self._init_device()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
@ -92,6 +81,8 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
@skipIfRocm
def test_nvshmem_all_to_all_vdev(self) -> None:
self._init_device()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
@ -139,24 +130,4 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
if __name__ == "__main__":
if not device_module.is_available():
sys.exit(TEST_SKIPS["no_cuda"].exit_code)
# If launched by torchrun, these values would have been set
rank = int(os.getenv("RANK", "-1"))
world_size = int(os.getenv("WORLD_SIZE", "-1"))
if rank != -1:
# Launched with torchrun or other multi-proc launchers. Directly run the test.
NVSHMEMSymmetricMemoryTest.run_rank(rank, world_size)
else:
# No external launcher, spawn N processes
world_size = device_module.device_count()
# Launched as a single process. Spawn subprocess to run the tests.
# Also need a rendezvous file for `init_process_group` purpose.
rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
torch.multiprocessing.spawn(
NVSHMEMSymmetricMemoryTest.run_rank,
nprocs=world_size,
args=(world_size, rdvz_file),
)
run_tests()

View File

@ -1,6 +1,5 @@
# mypy: ignore-errors
import abc
import faulthandler
import itertools
import logging
@ -38,7 +37,6 @@ from torch.testing._internal.common_utils import (
find_free_port,
IS_SANDCASTLE,
retry_on_connect_failures,
run_tests,
skip_but_pass_in_sandcastle,
skip_but_pass_in_sandcastle_if,
TEST_HPU,
@ -689,6 +687,11 @@ class MultiProcessTestCase(TestCase):
self.processes.append(process)
def _spawn_processes(self) -> None:
try:
torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
pass
proc = torch.multiprocessing.get_context("spawn").Process
self._start_processes(proc)
@ -1503,24 +1506,34 @@ class DynamoDistributedMultiProcTestCase(MultiProcessTestCase):
class MultiProcContinousTest(TestCase):
# Class variables:
MAIN_PROCESS_RANK = -1
# number of test processes
world_size: int = 2
world_size: int = -2 # unset state
# rank of the current process
rank: int = -1 # unset state
rank: int = -2 # unset state
# Rendezvous file
rdvz_file: Optional[str] = None
# timeout configured per class
timeout: timedelta = timedelta(seconds=120)
# Poison pill for rest of tests if one of them fails
poison_pill: bool = False
@classmethod
@abc.abstractmethod
def backend_str(cls) -> str:
def backend_str(cls) -> Optional[str]:
"""
ProcessGroup backend str.
To be customized by sub test classes, e.g. "nccl".
Here we raise error.
Otherwise we return None -- lazily decided by tensor.
"""
raise NotImplementedError("Please implement backend_str in your test class")
return None
# Please override if you intend to test on specific device type
@classmethod
def device_type(cls) -> str:
curr_device = torch.accelerator.current_accelerator()
if curr_device is None:
return "cpu"
return curr_device.type
@classmethod
def opts(cls, high_priority_stream=False):
@ -1531,6 +1544,101 @@ class MultiProcContinousTest(TestCase):
"""
return None
@classmethod
def _init_pg(cls, rank, world_size, rdvz_file):
assert rdvz_file is not None
store = c10d.FileStore(rdvz_file, world_size)
# create nccl processgroup with opts
c10d.init_process_group(
backend=cls.backend_str(),
world_size=world_size,
rank=rank,
store=store,
pg_options=cls.opts(),
timeout=cls.timeout,
)
cls.pg = c10d.distributed_c10d._get_default_group()
@classmethod
def _run_test_given_id(cls, test_id: str, **kwargs) -> None:
# self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank'
test_name = test_id.split(".")[-1]
# Get the test function from the test class
self = cls(test_name)
self.rank = cls.rank
self.world_size = cls.world_size
test_fn = getattr(self, test_name)
# Run the test function
test_fn(**kwargs)
@classmethod
def _worker_loop(cls, rank, world_size, rdvz_file, task_queue, completion_queue):
# Sub tests are going to access these values, check first
assert 0 <= rank < world_size
# set class variables for the test class
cls.rank = rank
cls.world_size = world_size
# Initialize the process group
cls._init_pg(rank, world_size, rdvz_file)
# End of bootstrap
logger.info("Setup complete")
# Loop forever, waiting for a test name to run
while True:
test_id = task_queue.get()
logger.debug(f"Got test {test_id}") # noqa: G004
# None means exit
if test_id is None:
break
# Run the test
try:
cls._run_test_given_id(test_id)
completion_queue.put(test_id)
except BaseException as ex:
# Send the exception back to the dispatcher
completion_queue.put(ex)
# Termination
logger.info("Terminating ...")
c10d.destroy_process_group()
@classmethod
def _spawn_processes(cls, world_size) -> None:
cls.processes = []
cls.task_queues = []
cls.completion_queues = []
# Need a rendezvous file for `init_process_group` purpose.
cls.rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
# CUDA multiprocessing requires spawn instead of fork, to make sure
# child processes have their own memory space.
try:
torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
# The start method has already been set
pass
for rank in range(int(world_size)):
task_queue = torch.multiprocessing.Queue()
completion_queue = torch.multiprocessing.Queue()
process = torch.multiprocessing.Process(
target=cls._worker_loop,
name="process " + str(rank),
daemon=True, # so that child processes will exit if parent decides to terminate
args=(rank, world_size, cls.rdvz_file, task_queue, completion_queue),
)
process.start()
cls.processes.append(process)
cls.task_queues.append(task_queue)
cls.completion_queues.append(completion_queue)
logger.info(
"Started process %s with pid %s", rank, process.pid
) # noqa: UP031
@classmethod
def setUpClass(cls):
"""
@ -1538,30 +1646,18 @@ class MultiProcContinousTest(TestCase):
Set up the process group.
"""
super().setUpClass()
if not 0 <= cls.rank < cls.world_size:
raise RuntimeError(
"Rank must be set and in the range of 0 to world_size. "
f"World size: {cls.world_size} Rank: {cls.rank}"
)
if cls.rdvz_file:
store = c10d.FileStore(cls.rdvz_file, cls.world_size)
else:
# torchrun takes care of rendezvous
store = None
opts = cls.opts()
backend = cls.backend_str()
print(f"Testing {backend=}")
# create nccl processgroup with opts
c10d.init_process_group(
backend=backend,
world_size=cls.world_size,
rank=cls.rank,
store=store,
pg_options=opts,
timeout=cls.timeout,
# Use device count as world size
device_type = cls.device_type()
cls.world_size = torch.get_device_module(device_type).device_count()
if cls.world_size == 0:
raise unittest.SkipTest(f"No {device_type} devices available")
logger.info(
f"Testing class {cls.__name__} on {cls.world_size} {device_type}" # noqa: G004
)
cls.pg = c10d.distributed_c10d._get_default_group()
print(f"Rank {cls.rank} setup complete")
cls._spawn_processes(cls.world_size)
@classmethod
def tearDownClass(cls):
@ -1569,37 +1665,91 @@ class MultiProcContinousTest(TestCase):
Class-scope test fixture. Run once for entire test class, after all tests finish.
Tear down the process group.
"""
c10d.destroy_process_group()
super().tearDownClass()
logger.debug(f"Joining {cls.world_size} workers") # noqa: G004
# Enqueue "None" to all workers to tell them to exit
for task_queue in cls.task_queues:
task_queue.put(None)
# Wait for all workers to exit
for process in cls.processes:
process.join()
# Clear up the rendezvous file
if cls.rdvz_file:
try:
os.remove(cls.rdvz_file)
except OSError:
pass
print(f"Rank {cls.rank} teardown complete")
try:
os.remove(cls.rdvz_file)
except OSError:
pass
@classmethod
def run_rank(
cls,
rank: int,
world_size: int,
rdvz_file: Optional[str] = None,
):
logger.info(f"Class {cls.__name__} finished") # noqa: G004
super().tearDownClass()
def setUp(self) -> None:
"""
This is an entry point for each rank to run the tests in `MultiProcContinousTest`.
In this entry point, we set the class variables for the test class.
Then we run all tests.
Note:
- This helper only works for a subclass of `MultiProcContinousTest`.
Example:
- See `test_c10d_ops_nccl.py`.
Test fixture. Run before each test.
"""
# set class variables for the test class
cls.rank = rank
cls.world_size = world_size
cls.rdvz_file = rdvz_file
# Launch tests via `common_utils` infra
run_tests()
super().setUp()
# I am the dispatcher
self.rank = self.MAIN_PROCESS_RANK
# If this test class hits an exception in one test, skip the rest of tests
if self.__class__.poison_pill:
raise unittest.SkipTest(f"Previous test failed, skipping {self.id()}")
# Enqueue "current test" to all workers
for i, task_queue in enumerate(self.task_queues):
logger.debug(f"Sending Rank {i}: {self.id()}") # noqa: G004
task_queue.put(self.id())
def _worker_run_main_wait(self, fn):
@wraps(fn)
def wrapper(self):
if self.rank == self.MAIN_PROCESS_RANK:
logger.debug(f"Waiting for workers to finish {self.id()}") # noqa: G004
# Wait for the workers to finish the test
for i, completion_queue in enumerate(self.completion_queues):
rv = completion_queue.get()
if isinstance(rv, BaseException):
# Hit an exception, re-raise it in the main process.
logger.warning(
f"Detected failure from Rank {i} in: {self.id()}, " # noqa: G004
f"skipping rest of tests in Test class: {self.__class__.__name__}" # noqa: G004
)
# Poison rest of tests (because ProcessGroup may be not
# re-usable now)
self.__class__.poison_pill = True
raise rv
# Success
assert rv == self.id()
logger.debug(
f"Main proc detected rank {i} finished {self.id()}" # noqa: G004
)
else:
# Worker just runs the test
fn()
return types.MethodType(wrapper, self)
# The main process spawns N subprocesses that run the test.
# Constructor patches current instance test method to
# assume the role of the main process and join its subprocesses,
# or run the underlying test function.
def __init__(
self, method_name: str = "runTest", methodName: str = "runTest"
) -> None:
# methodName is the correct naming in unittest and testslide uses keyword arguments.
# So we need to use both to 1) not break BC and, 2) support testslide.
if methodName != "runTest":
method_name = methodName
super().__init__(method_name)
try:
fn = getattr(self, method_name)
setattr(self, method_name, self._worker_run_main_wait(fn))
except AttributeError as e:
if methodName != "runTest":
# we allow instantiation with no explicit method name
# but not an *incorrect* or missing method name
raise ValueError(
f"no such test method in {self.__class__}: {methodName}"
) from e