mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Distributed][CI] Rework continuous TestCase (#153653)
1. Reworked `MultiProcContinousTest` to spawn processes during `setUpClass` instead of in `main`, so that multiple test classes can be supported in one file. 2. Each child process now runs an infinite loop, waiting for test IDs sent by the main process through a task queue. In return, each child process notifies the main process that a test has completed through a completion queue. 3. Added a test template (`test/distributed/_test_template.py`). Pull Request resolved: https://github.com/pytorch/pytorch/pull/153653 Approved by: https://github.com/d4l3k, https://github.com/fegin, https://github.com/fduwjj
This commit is contained in:
parent
03e102dbe8
commit
9d922b55ef
16
test/distributed/_test_template.py
Normal file
16
test/distributed/_test_template.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
from torch.testing._internal.common_distributed import MultiProcContinousTest
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
|
||||
|
||||
class TestTemplate(MultiProcContinousTest):
|
||||
def testABC(self):
|
||||
print(f"rank {self.rank} of {self.world_size} testing ABC")
|
||||
|
||||
def testDEF(self):
|
||||
print(f"rank {self.rank} of {self.world_size} testing DEF")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
@ -7,13 +7,16 @@ from torch.distributed.pipelining import pipe_split, SplitPoint
|
|||
|
||||
|
||||
class ExampleCode(torch.nn.Module):
|
||||
def __init__(self, d_hid):
|
||||
def __init__(self, d_hid, splits=2):
|
||||
assert splits <= 4
|
||||
super().__init__()
|
||||
self.splits = splits
|
||||
self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
|
||||
self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
|
||||
self.cval = torch.nn.Buffer(torch.randn((d_hid,), requires_grad=False))
|
||||
self.lin0 = torch.nn.Linear(d_hid, d_hid)
|
||||
self.lin1 = torch.nn.Linear(d_hid, d_hid)
|
||||
self.lin2 = torch.nn.Linear(d_hid, d_hid)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.mm(x, self.mm_param0)
|
||||
|
|
@ -24,8 +27,14 @@ class ExampleCode(torch.nn.Module):
|
|||
pipe_split()
|
||||
x = torch.relu(x) + a_constant
|
||||
x = torch.mm(x, self.mm_param1)
|
||||
x = self.lin1(x)
|
||||
x = torch.relu(x)
|
||||
if self.splits > 2:
|
||||
pipe_split()
|
||||
x = self.lin1(x)
|
||||
x = torch.relu(x)
|
||||
if self.splits > 3:
|
||||
pipe_split()
|
||||
x = self.lin2(x)
|
||||
x = torch.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
|
|
@ -33,12 +42,16 @@ class ModelWithKwargs(torch.nn.Module):
|
|||
DEFAULT_DHID = 512
|
||||
DEFAULT_BATCH_SIZE = 256
|
||||
|
||||
def __init__(self, d_hid: int = DEFAULT_DHID):
|
||||
def __init__(self, d_hid: int = DEFAULT_DHID, splits=2):
|
||||
assert splits <= 4
|
||||
super().__init__()
|
||||
self.splits = splits
|
||||
self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
|
||||
self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
|
||||
self.lin0 = torch.nn.Linear(d_hid, d_hid)
|
||||
self.lin1 = torch.nn.Linear(d_hid, d_hid)
|
||||
self.lin2 = torch.nn.Linear(d_hid, d_hid)
|
||||
self.lin3 = torch.nn.Linear(d_hid, d_hid)
|
||||
|
||||
def forward(self, x, y=torch.zeros(DEFAULT_BATCH_SIZE, DEFAULT_DHID)):
|
||||
x = torch.mm(x, self.mm_param0)
|
||||
|
|
@ -49,6 +62,14 @@ class ModelWithKwargs(torch.nn.Module):
|
|||
x = torch.mm(x, self.mm_param1)
|
||||
x = self.lin1(x)
|
||||
x = torch.relu(x)
|
||||
if self.splits > 2:
|
||||
pipe_split()
|
||||
x = self.lin2(x)
|
||||
x = torch.relu(x)
|
||||
if self.splits > 3:
|
||||
pipe_split()
|
||||
x = self.lin3(x)
|
||||
x = torch.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,8 +2,6 @@
|
|||
# Owner(s): ["oncall: distributed"]
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
from model_registry import ModelWithKwargs, MultiMLP, MultiMLPWithDw
|
||||
|
|
@ -37,6 +35,7 @@ from torch.testing._internal.common_utils import (
|
|||
check_leaked_tensors,
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
skip_but_pass_in_sandcastle_if,
|
||||
)
|
||||
|
||||
|
|
@ -48,6 +47,8 @@ batch_size = 256
|
|||
|
||||
torch.manual_seed(0)
|
||||
|
||||
device_type = "cuda"
|
||||
|
||||
|
||||
class ScheduleTest(MultiProcContinousTest):
|
||||
@classmethod
|
||||
|
|
@ -55,15 +56,9 @@ class ScheduleTest(MultiProcContinousTest):
|
|||
# Testing with NCCL backend
|
||||
return "nccl"
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
"""
|
||||
Class-scope test fixture. Run once for entire test class, before any test starts.
|
||||
Set up the device.
|
||||
"""
|
||||
super().setUpClass()
|
||||
dev_id = cls.rank % torch.cuda.device_count()
|
||||
cls.device = torch.device(f"cuda:{dev_id}")
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return torch.device(device_type, self.rank)
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
|
|
@ -77,7 +72,7 @@ class ScheduleTest(MultiProcContinousTest):
|
|||
x = torch.randn(batch_size, d_hid, device=self.device)
|
||||
x_clone = x.clone()
|
||||
|
||||
num_microbatches = 4
|
||||
num_microbatches = 2 * self.world_size
|
||||
x_mb = x.chunk(num_microbatches)[0]
|
||||
|
||||
# Create a pipeline
|
||||
|
|
@ -159,6 +154,12 @@ class ScheduleTest(MultiProcContinousTest):
|
|||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
|
||||
def test_kwargs_with_tracer(self, ScheduleClass):
|
||||
# Model has two stages only, thus limiting group size to 2
|
||||
group_size = 2
|
||||
group = dist.new_group(list(range(group_size)))
|
||||
if self.rank >= group_size:
|
||||
return
|
||||
|
||||
mod = ModelWithKwargs(d_hid)
|
||||
mod.to(self.device)
|
||||
|
||||
|
|
@ -180,6 +181,7 @@ class ScheduleTest(MultiProcContinousTest):
|
|||
stage = pipe.build_stage(
|
||||
self.rank,
|
||||
self.device,
|
||||
group=group,
|
||||
)
|
||||
|
||||
# Attach to a schedule
|
||||
|
|
@ -188,16 +190,16 @@ class ScheduleTest(MultiProcContinousTest):
|
|||
# Run
|
||||
if self.rank == 0:
|
||||
schedule.step(x, y=y)
|
||||
elif self.rank == self.world_size - 1:
|
||||
elif self.rank == group_size - 1:
|
||||
losses = []
|
||||
out = schedule.step(target=target, losses=losses)
|
||||
else:
|
||||
schedule.step()
|
||||
|
||||
dist.barrier()
|
||||
# dist.barrier()
|
||||
|
||||
# Last rank checks result
|
||||
if self.rank == self.world_size - 1:
|
||||
if self.rank == group_size - 1:
|
||||
ref_out = mod(x, y=y)
|
||||
ref_loss = loss_fn(ref_out, target)
|
||||
pipe_loss = sum(losses)
|
||||
|
|
@ -207,9 +209,8 @@ class ScheduleTest(MultiProcContinousTest):
|
|||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
|
||||
@parametrize("ModelClass", [MultiMLP])
|
||||
def test_grad_with_tracer(self, ScheduleClass, ModelClass):
|
||||
mod = ModelClass(d_hid)
|
||||
def test_grad_with_tracer(self, ScheduleClass):
|
||||
mod = MultiMLP(d_hid, n_layers=self.world_size)
|
||||
mod.to(self.device)
|
||||
|
||||
ref_mod = copy.deepcopy(mod)
|
||||
|
|
@ -229,7 +230,7 @@ class ScheduleTest(MultiProcContinousTest):
|
|||
ref_loss.backward()
|
||||
|
||||
# Create a pipeline
|
||||
chunks = 4
|
||||
chunks = 2 * self.world_size
|
||||
x_mb = x.chunk(chunks)[0]
|
||||
split_spec = mod.split_spec if hasattr(mod, "split_spec") else None
|
||||
pipe = pipeline(
|
||||
|
|
@ -307,7 +308,7 @@ class ScheduleTest(MultiProcContinousTest):
|
|||
# Get a submodule, e.g. `layers.0` or `layers.1`
|
||||
submod_name = f"layers.{self.rank}"
|
||||
stage_module = full_mod.get_submodule(submod_name)
|
||||
chunks = 4
|
||||
chunks = 2 * self.world_size
|
||||
|
||||
if shape_inference:
|
||||
input_args = None
|
||||
|
|
@ -410,7 +411,7 @@ class ScheduleTest(MultiProcContinousTest):
|
|||
num_microbatches = (
|
||||
ScheduleClass.num_microbatches
|
||||
if hasattr(ScheduleClass, "num_microbatches")
|
||||
else 8
|
||||
else 2 * self.world_size
|
||||
)
|
||||
stages = [
|
||||
PipelineStage(
|
||||
|
|
@ -518,13 +519,15 @@ class ScheduleTest(MultiProcContinousTest):
|
|||
raise
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
not torch.cuda.device_count() == 2, "This test requires exactly 2 GPUs"
|
||||
)
|
||||
@parametrize("ScheduleClass", [ScheduleWithW, ScheduleInterleavedZeroBubble])
|
||||
def test_schedule_with_native_zero_bubble(self, ScheduleClass):
|
||||
print(ScheduleClass)
|
||||
if ScheduleClass is ScheduleInterleavedZeroBubble:
|
||||
n_stages = 4
|
||||
num_microbatches = 8
|
||||
num_microbatches = 2 * n_stages
|
||||
rank_stages = {
|
||||
0: [0, 2],
|
||||
1: [1, 3],
|
||||
|
|
@ -612,7 +615,9 @@ class ScheduleTest(MultiProcContinousTest):
|
|||
raise
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
not torch.cuda.device_count() == 2, "This test requires exactly 2 GPUs"
|
||||
)
|
||||
@parametrize(
|
||||
"ScheduleClass",
|
||||
[
|
||||
|
|
@ -717,7 +722,9 @@ class ScheduleTest(MultiProcContinousTest):
|
|||
raise
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
not torch.cuda.device_count() == 2, "This test requires exactly 2 GPUs"
|
||||
)
|
||||
@parametrize(
|
||||
"schedule_class", [ScheduleVShaped, ScheduleUnbalanced, ScheduleZBVZeroBubble]
|
||||
)
|
||||
|
|
@ -822,7 +829,9 @@ class ScheduleTest(MultiProcContinousTest):
|
|||
raise
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
not torch.cuda.device_count() == 2, "This test requires exactly 2 GPUs"
|
||||
)
|
||||
@parametrize("ScheduleClass", [ScheduleInterleavedZeroBubble])
|
||||
def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass):
|
||||
stages_per_rank = 2
|
||||
|
|
@ -942,30 +951,4 @@ instantiate_parametrized_tests(ScheduleTest)
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Check if GPU and NCCL are available
|
||||
if not (
|
||||
dist.is_available()
|
||||
and dist.is_nccl_available()
|
||||
and torch.cuda.device_count() > 1
|
||||
):
|
||||
print(
|
||||
"c10d NCCL not available or not enough GPUs, skipping tests",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(0)
|
||||
|
||||
rank = int(os.getenv("RANK", -1))
|
||||
world_size = int(os.getenv("WORLD_SIZE", 2))
|
||||
|
||||
if rank != -1:
|
||||
# Launched with torchrun or other multi-proc launchers. Directly run the test.
|
||||
ScheduleTest.run_rank(rank, world_size)
|
||||
else:
|
||||
# Launched as a single process. Spawn subprocess to run the tests.
|
||||
# Also need a rendezvous file for `init_process_group` purpose.
|
||||
rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
|
||||
torch.multiprocessing.spawn(
|
||||
ScheduleTest.run_rank,
|
||||
nprocs=world_size,
|
||||
args=(world_size, rdvz_file),
|
||||
)
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
from model_registry import ExampleCode, ModelWithKwargs, MultiMLP
|
||||
|
||||
|
|
@ -18,11 +17,14 @@ from torch.distributed.pipelining._utils import PipeliningShapeError
|
|||
from torch.testing._internal.common_cuda import TEST_MULTIGPU
|
||||
from torch.testing._internal.common_distributed import (
|
||||
MultiProcContinousTest,
|
||||
MultiProcessTestCase,
|
||||
requires_nccl,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
skip_but_pass_in_sandcastle,
|
||||
skip_but_pass_in_sandcastle_if,
|
||||
)
|
||||
from torch.utils._pytree import tree_map_only
|
||||
|
|
@ -32,6 +34,8 @@ d_hid = 512
|
|||
batch_size = 256
|
||||
chunks = 4
|
||||
|
||||
device_type = "cuda"
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
|
|
@ -66,20 +70,18 @@ class StageTest(MultiProcContinousTest):
|
|||
return "nccl"
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
"""
|
||||
Class-scope test fixture. Run once for entire test class, before any test starts.
|
||||
Set up the device.
|
||||
"""
|
||||
super().setUpClass()
|
||||
dev_id = cls.rank % torch.cuda.device_count()
|
||||
cls.device = torch.device(f"cuda:{dev_id}")
|
||||
def device_type(cls) -> str:
|
||||
return device_type
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return torch.device(device_type, self.rank)
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
@parametrize("ModelClass", [ExampleCode, MultiMLP])
|
||||
def test_tracer(self, ModelClass):
|
||||
mod = ModelClass(d_hid)
|
||||
mod = ModelClass(d_hid, self.world_size)
|
||||
mod.to(self.device)
|
||||
|
||||
x = torch.randn(batch_size, d_hid, device=self.device)
|
||||
|
|
@ -119,32 +121,11 @@ class StageTest(MultiProcContinousTest):
|
|||
old_keys = mod.state_dict().keys()
|
||||
assert all(k in old_keys for k in submod_keys)
|
||||
|
||||
if self.rank == 0:
|
||||
# intended to run this code on all ranks, but the problem is if rank0 throws,
|
||||
# it won't perform the send that unblocks rank 1.
|
||||
|
||||
# TODO(whc) can't test this until fixing args/kwargs issue
|
||||
# with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
|
||||
# _run_step(torch.randn(batch_size + 1, d_hid, device=self.device))
|
||||
|
||||
with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
|
||||
_run_step(x.to(torch.int32))
|
||||
|
||||
# output of stage's mlp layer will be flattened by this hook, the stage should err
|
||||
handle = stage.submod.register_forward_hook(get_flatten_hook())
|
||||
with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
|
||||
_run_step(x)
|
||||
handle.remove()
|
||||
|
||||
stage.submod.register_forward_hook(get_dtype_change_hook(torch.bfloat16))
|
||||
with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
|
||||
_run_step(x)
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
@parametrize("ModelClass", [ModelWithKwargs])
|
||||
def test_tracer_kwargs(self, ModelClass):
|
||||
mod = ModelClass(d_hid)
|
||||
mod = ModelClass(d_hid, self.world_size)
|
||||
mod.to(self.device)
|
||||
|
||||
x = torch.randn(batch_size, d_hid, device=self.device)
|
||||
|
|
@ -221,23 +202,6 @@ class StageTest(MultiProcContinousTest):
|
|||
ref_out = full_mod(x)
|
||||
torch.testing.assert_close(out, ref_out)
|
||||
|
||||
if self.rank == 0:
|
||||
with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
|
||||
_run_step(torch.randn(batch_size + 1, d_hid, device=self.device))
|
||||
|
||||
with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
|
||||
_run_step(x.to(torch.int32))
|
||||
|
||||
# output of stage's mlp layer will be flattened by this hook, the stage should err
|
||||
handle = stage_mod.register_forward_hook(get_flatten_hook())
|
||||
with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
|
||||
_run_step(x)
|
||||
handle.remove()
|
||||
|
||||
stage_mod.register_forward_hook(get_dtype_change_hook(torch.bfloat16))
|
||||
with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
|
||||
_run_step(x)
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
def test_custom_dw_with_fb_schedule(self):
|
||||
|
|
@ -298,28 +262,6 @@ class StageTest(MultiProcContinousTest):
|
|||
ref_out = full_mod(x)
|
||||
torch.testing.assert_close(out, ref_out)
|
||||
|
||||
if self.rank == 0:
|
||||
with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
|
||||
_run_step(torch.randn(batch_size + 1, d_hid, device=self.device))
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
def test_custom_dw_errors(self):
|
||||
"""Tests expected errors are raised"""
|
||||
full_mod = MultiMLP(d_hid, n_layers=self.world_size)
|
||||
full_mod.to(self.device)
|
||||
stage_mod = full_mod.get_submodule(f"layers.{self.rank}")
|
||||
|
||||
stage_with_dw_builder = PipelineStage(
|
||||
stage_mod,
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.device,
|
||||
dw_builder=lambda: None,
|
||||
)
|
||||
with self.assertRaisesRegex(AssertionError, "backward_one_chunk"):
|
||||
stage_with_dw_builder.backward_weight_one_chunk(bwd_chunk_id=0)
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
def test_output_chunks_memory_usage(self):
|
||||
|
|
@ -381,31 +323,105 @@ class StageTest(MultiProcContinousTest):
|
|||
|
||||
instantiate_parametrized_tests(StageTest)
|
||||
|
||||
|
||||
class StageNegativeTest(MultiProcessTestCase):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return torch.get_device_module(device_type).device_count()
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return torch.device(device_type, self.rank)
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self._spawn_processes()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
try:
|
||||
os.remove(self.file_name)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def init_pg(self):
|
||||
store = dist.FileStore(self.file_name, self.world_size)
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
store=store,
|
||||
rank=self.rank,
|
||||
world_size=self.world_size,
|
||||
device_id=self.device,
|
||||
)
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle("Flaky in CI")
|
||||
def test_shape_prop_mismatch(self):
|
||||
"""Tests shape prop errors are raised"""
|
||||
self.init_pg()
|
||||
|
||||
full_mod = MultiMLP(d_hid, n_layers=self.world_size)
|
||||
full_mod.to(self.device)
|
||||
stage_mod = full_mod.get_submodule(f"layers.{self.rank}")
|
||||
|
||||
x = torch.randn(batch_size, d_hid, device=self.device)
|
||||
|
||||
stage = PipelineStage(
|
||||
stage_mod,
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.device,
|
||||
)
|
||||
|
||||
# Attach to a schedule
|
||||
schedule = ScheduleGPipe(stage, chunks)
|
||||
|
||||
# Run
|
||||
def _run_step(x):
|
||||
if self.rank == 0:
|
||||
return schedule.step(x)
|
||||
else:
|
||||
return schedule.step()
|
||||
|
||||
_run_step(x)
|
||||
|
||||
if self.rank == 0:
|
||||
with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
|
||||
_run_step(torch.randn(batch_size + 1, d_hid, device=self.device))
|
||||
|
||||
with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
|
||||
_run_step(x.to(torch.int32))
|
||||
|
||||
# output of stage's mlp layer will be flattened by this hook, the stage should err
|
||||
handle = stage_mod.register_forward_hook(get_flatten_hook())
|
||||
with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
|
||||
_run_step(x)
|
||||
handle.remove()
|
||||
|
||||
stage_mod.register_forward_hook(get_dtype_change_hook(torch.bfloat16))
|
||||
with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
|
||||
_run_step(x)
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
def test_custom_dw_errors(self):
|
||||
"""Tests expected errors are raised"""
|
||||
self.init_pg()
|
||||
|
||||
full_mod = MultiMLP(d_hid, n_layers=self.world_size)
|
||||
full_mod.to(self.device)
|
||||
stage_mod = full_mod.get_submodule(f"layers.{self.rank}")
|
||||
|
||||
stage_with_dw_builder = PipelineStage(
|
||||
stage_mod,
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.device,
|
||||
dw_builder=lambda: None,
|
||||
)
|
||||
with self.assertRaisesRegex(AssertionError, "backward_one_chunk"):
|
||||
stage_with_dw_builder.backward_weight_one_chunk(bwd_chunk_id=0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Check if GPU and NCCL are available
|
||||
if not (
|
||||
dist.is_available()
|
||||
and dist.is_nccl_available()
|
||||
and torch.cuda.device_count() > 1
|
||||
):
|
||||
print(
|
||||
"c10d NCCL not available or not enough GPUs, skipping tests",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(0)
|
||||
|
||||
rank = int(os.getenv("RANK", -1))
|
||||
world_size = int(os.getenv("WORLD_SIZE", 2))
|
||||
|
||||
if rank != -1:
|
||||
# Launched with torchrun or other multi-proc launchers. Directly run the test.
|
||||
StageTest.run_rank(rank, world_size)
|
||||
else:
|
||||
# Launched as a single process. Spawn subprocess to run the tests.
|
||||
# Also need a rendezvous file for `init_process_group` purpose.
|
||||
rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
|
||||
torch.multiprocessing.spawn(
|
||||
StageTest.run_rank,
|
||||
nprocs=world_size,
|
||||
args=(world_size, rdvz_file),
|
||||
)
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@
|
|||
import math
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
import torch.distributed as c10d
|
||||
|
|
@ -30,9 +29,9 @@ from torch.testing._internal.common_distributed import (
|
|||
requires_nccl,
|
||||
requires_nccl_version,
|
||||
sm_is_or_higher_than,
|
||||
TEST_SKIPS,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
run_tests,
|
||||
skip_but_pass_in_sandcastle_if,
|
||||
skipIfRocm,
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
|
|
@ -1044,24 +1043,4 @@ class ProcessGroupNCCLOpTest(MultiProcContinousTest):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not torch.cuda.is_available():
|
||||
sys.exit(TEST_SKIPS["no_cuda"].exit_code)
|
||||
|
||||
rank = int(os.getenv("RANK", -1))
|
||||
world_size = int(os.getenv("WORLD_SIZE", -1))
|
||||
|
||||
if world_size == -1: # Not set by external launcher
|
||||
world_size = torch.cuda.device_count()
|
||||
|
||||
if rank != -1:
|
||||
# Launched with torchrun or other multi-proc launchers. Directly run the test.
|
||||
ProcessGroupNCCLOpTest.run_rank(rank, world_size)
|
||||
else:
|
||||
# Launched as a single process. Spawn subprocess to run the tests.
|
||||
# Also need a rendezvous file for `init_process_group` purpose.
|
||||
rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
|
||||
torch.multiprocessing.spawn(
|
||||
ProcessGroupNCCLOpTest.run_rank,
|
||||
nprocs=world_size,
|
||||
args=(world_size, rdvz_file),
|
||||
)
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -1,11 +1,7 @@
|
|||
# Owner(s): ["oncall: distributed"]
|
||||
import copy
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
|
|
@ -30,11 +26,15 @@ from torch.testing._internal.common_distributed import (
|
|||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
skip_but_pass_in_sandcastle_if,
|
||||
TEST_WITH_ROCM,
|
||||
)
|
||||
|
||||
|
||||
device_type = "cuda"
|
||||
|
||||
|
||||
# MLP Layer
|
||||
class MLPModule(torch.nn.Module):
|
||||
def __init__(self, d_hid: int):
|
||||
|
|
@ -92,29 +92,14 @@ def loss_fn(y, target, scale=1e-4):
|
|||
|
||||
|
||||
class ComposabilityTest(MultiProcContinousTest):
|
||||
world_size = 4
|
||||
|
||||
@classmethod
|
||||
def backend_str(cls) -> str:
|
||||
# Testing with NCCL backend
|
||||
return "nccl"
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
"""
|
||||
Class-scope test fixture. Run once for entire test class, before any test starts.
|
||||
Set up the device.
|
||||
"""
|
||||
super().setUpClass()
|
||||
dev_id = cls.rank % torch.cuda.device_count()
|
||||
cls.device = torch.device(f"cuda:{dev_id}")
|
||||
torch.cuda.set_device(cls.device)
|
||||
|
||||
def _build_mesh(self, mesh_shape=(2, 2), mesh_dim_names=("dp", "pp")):
|
||||
device_mesh = init_device_mesh(
|
||||
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
|
||||
)
|
||||
return device_mesh
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return torch.device(device_type, self.rank)
|
||||
|
||||
def _rand_microbatches(self, dp_mesh, num_microbatches, dim, dtype=torch.float32):
|
||||
full = [
|
||||
|
|
@ -216,7 +201,12 @@ class ComposabilityTest(MultiProcContinousTest):
|
|||
# https://github.com/pytorch/pytorch/issues/144530
|
||||
return
|
||||
|
||||
device_mesh = self._build_mesh((2, 2), ("dp", "pp"))
|
||||
torch.get_device_module(device_type).set_device(self.device)
|
||||
mesh_shape = (self.world_size // 2, 2)
|
||||
mesh_dim_names = ("dp", "pp")
|
||||
device_mesh = init_device_mesh(
|
||||
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
|
||||
)
|
||||
pp_group = device_mesh["pp"].get_group()
|
||||
dp_mesh = device_mesh["dp"]
|
||||
|
||||
|
|
@ -292,7 +282,12 @@ class ComposabilityTest(MultiProcContinousTest):
|
|||
if TEST_WITH_ROCM:
|
||||
return
|
||||
|
||||
device_mesh = self._build_mesh((2, 2), ("dp", "pp"))
|
||||
torch.get_device_module(device_type).set_device(self.device)
|
||||
mesh_shape = (self.world_size // 2, 2)
|
||||
mesh_dim_names = ("dp", "pp")
|
||||
device_mesh = init_device_mesh(
|
||||
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
|
||||
)
|
||||
pp_group = device_mesh["pp"].get_group()
|
||||
dp_mesh = device_mesh["dp"]
|
||||
|
||||
|
|
@ -376,35 +371,12 @@ class ComposabilityTest(MultiProcContinousTest):
|
|||
name = ".".join(parts)
|
||||
ref_p = ref_parameters[name]
|
||||
self.assertTrue(isinstance(p.grad, DTensor))
|
||||
torch.testing.assert_close(p.grad.full_tensor(), ref_p.grad)
|
||||
torch.testing.assert_close(
|
||||
p.grad.full_tensor(), ref_p.grad, atol=5e-5, rtol=2e-2
|
||||
)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(ComposabilityTest)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Check if GPU and NCCL are available
|
||||
if not (
|
||||
dist.is_available()
|
||||
and dist.is_nccl_available()
|
||||
and torch.cuda.device_count() > 3
|
||||
):
|
||||
print(
|
||||
"c10d NCCL not available or not enough GPUs, skipping tests",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(0)
|
||||
|
||||
rank = int(os.getenv("RANK", -1))
|
||||
world_size = int(os.getenv("WORLD_SIZE", 4))
|
||||
|
||||
if rank != -1:
|
||||
# Launched with torchrun or other multi-proc launchers. Directly run the test.
|
||||
ComposabilityTest.run_rank(rank, world_size)
|
||||
else:
|
||||
# Launched as a single process. Spawn subprocess to run the tests.
|
||||
# Also need a rendezvous file for `init_process_group` purpose.
|
||||
rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
|
||||
torch.multiprocessing.spawn(
|
||||
ComposabilityTest.run_rank,
|
||||
nprocs=world_size,
|
||||
args=(world_size, rdvz_file),
|
||||
)
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -7,16 +7,13 @@
|
|||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.distributed._symmetric_memory as symm_mem
|
||||
from torch.testing._internal.common_distributed import (
|
||||
MultiProcContinousTest,
|
||||
TEST_SKIPS,
|
||||
)
|
||||
from torch.testing._internal.common_distributed import MultiProcContinousTest
|
||||
from torch.testing._internal.common_utils import (
|
||||
run_tests,
|
||||
skip_but_pass_in_sandcastle_if,
|
||||
skipIfRocm,
|
||||
)
|
||||
|
|
@ -47,28 +44,20 @@ device_module = torch.get_device_module(device_type)
|
|||
|
||||
@requires_nvshmem()
|
||||
class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
def _init_device(self) -> None:
|
||||
# TODO: relieve this (seems to hang if without)
|
||||
device_module.set_device(self.device)
|
||||
# NOTE: required for nvshmem allocation
|
||||
torch.empty(1, device=self.device)
|
||||
|
||||
# Required by MultiProcContinousTest
|
||||
@classmethod
|
||||
def backend_str(cls) -> str:
|
||||
return "nccl"
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return device_module.device_count()
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return torch.device(device_type, self.rank)
|
||||
|
||||
@skipIfRocm
|
||||
def test_nvshmem_all_to_all(self) -> None:
|
||||
self._init_device()
|
||||
|
||||
group_name = dist.group.WORLD.group_name
|
||||
symm_mem.enable_symm_mem_for_group(group_name)
|
||||
|
||||
|
|
@ -92,6 +81,8 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
|
|||
|
||||
@skipIfRocm
|
||||
def test_nvshmem_all_to_all_vdev(self) -> None:
|
||||
self._init_device()
|
||||
|
||||
group_name = dist.group.WORLD.group_name
|
||||
symm_mem.enable_symm_mem_for_group(group_name)
|
||||
|
||||
|
|
@ -139,24 +130,4 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not device_module.is_available():
|
||||
sys.exit(TEST_SKIPS["no_cuda"].exit_code)
|
||||
|
||||
# If launched by torchrun, these values would have been set
|
||||
rank = int(os.getenv("RANK", "-1"))
|
||||
world_size = int(os.getenv("WORLD_SIZE", "-1"))
|
||||
|
||||
if rank != -1:
|
||||
# Launched with torchrun or other multi-proc launchers. Directly run the test.
|
||||
NVSHMEMSymmetricMemoryTest.run_rank(rank, world_size)
|
||||
else:
|
||||
# No external launcher, spawn N processes
|
||||
world_size = device_module.device_count()
|
||||
# Launched as a single process. Spawn subprocess to run the tests.
|
||||
# Also need a rendezvous file for `init_process_group` purpose.
|
||||
rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
|
||||
torch.multiprocessing.spawn(
|
||||
NVSHMEMSymmetricMemoryTest.run_rank,
|
||||
nprocs=world_size,
|
||||
args=(world_size, rdvz_file),
|
||||
)
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
import abc
|
||||
import faulthandler
|
||||
import itertools
|
||||
import logging
|
||||
|
|
@ -38,7 +37,6 @@ from torch.testing._internal.common_utils import (
|
|||
find_free_port,
|
||||
IS_SANDCASTLE,
|
||||
retry_on_connect_failures,
|
||||
run_tests,
|
||||
skip_but_pass_in_sandcastle,
|
||||
skip_but_pass_in_sandcastle_if,
|
||||
TEST_HPU,
|
||||
|
|
@ -689,6 +687,11 @@ class MultiProcessTestCase(TestCase):
|
|||
self.processes.append(process)
|
||||
|
||||
def _spawn_processes(self) -> None:
|
||||
try:
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
proc = torch.multiprocessing.get_context("spawn").Process
|
||||
self._start_processes(proc)
|
||||
|
||||
|
|
@ -1503,24 +1506,34 @@ class DynamoDistributedMultiProcTestCase(MultiProcessTestCase):
|
|||
|
||||
class MultiProcContinousTest(TestCase):
|
||||
# Class variables:
|
||||
MAIN_PROCESS_RANK = -1
|
||||
# number of test processes
|
||||
world_size: int = 2
|
||||
world_size: int = -2 # unset state
|
||||
# rank of the current process
|
||||
rank: int = -1 # unset state
|
||||
rank: int = -2 # unset state
|
||||
# Rendezvous file
|
||||
rdvz_file: Optional[str] = None
|
||||
# timeout configured per class
|
||||
timeout: timedelta = timedelta(seconds=120)
|
||||
# Poison pill for rest of tests if one of them fails
|
||||
poison_pill: bool = False
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def backend_str(cls) -> str:
|
||||
def backend_str(cls) -> Optional[str]:
|
||||
"""
|
||||
ProcessGroup backend str.
|
||||
To be customized by sub test classes, e.g. "nccl".
|
||||
Here we raise error.
|
||||
Otherwise we return None -- lazily decided by tensor.
|
||||
"""
|
||||
raise NotImplementedError("Please implement backend_str in your test class")
|
||||
return None
|
||||
|
||||
# Please override if you intend to test on specific device type
|
||||
@classmethod
|
||||
def device_type(cls) -> str:
|
||||
curr_device = torch.accelerator.current_accelerator()
|
||||
if curr_device is None:
|
||||
return "cpu"
|
||||
return curr_device.type
|
||||
|
||||
@classmethod
|
||||
def opts(cls, high_priority_stream=False):
|
||||
|
|
@ -1531,6 +1544,101 @@ class MultiProcContinousTest(TestCase):
|
|||
"""
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _init_pg(cls, rank, world_size, rdvz_file):
    """Create the default process group for the calling process.

    Args:
        rank: rank of the calling process.
        world_size: number of ranks joining the group.
        rdvz_file: path of the file backing the rendezvous ``FileStore``;
            must not be None.
    """
    assert rdvz_file is not None
    # File-backed store used by the ranks to rendezvous
    rendezvous_store = c10d.FileStore(rdvz_file, world_size)

    # Backend and options come from the class-level hooks, so subclasses
    # decide which kind of process group gets created.
    backend = cls.backend_str()
    pg_options = cls.opts()
    c10d.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        store=rendezvous_store,
        pg_options=pg_options,
        timeout=cls.timeout,
    )
    # Keep a handle on the default group for the tests to use
    cls.pg = c10d.distributed_c10d._get_default_group()
|
||||
|
||||
@classmethod
|
||||
def _run_test_given_id(cls, test_id: str, **kwargs) -> None:
    """Instantiate the test case named by ``test_id`` and run its test method.

    Args:
        test_id: dotted test id, e.g.
            '__main__.TestDistributed.TestAdditive.test_get_rank'.
        **kwargs: forwarded to the test method.
    """
    # The method name is the last dotted component of the id
    method_name = test_id.rpartition(".")[2]
    # Build a test-case instance for that method and copy over the
    # rank / world size recorded on the class for this process.
    test_case = cls(method_name)
    test_case.rank = cls.rank
    test_case.world_size = cls.world_size
    # Look the method up on the instance and invoke it
    getattr(test_case, method_name)(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def _worker_loop(cls, rank, world_size, rdvz_file, task_queue, completion_queue):
    """Entry point of a worker process: set up once, then run tests on demand.

    The worker records its rank / world size on the class, initializes the
    process group, and then loops on ``task_queue``. Each item is a test id
    to run; ``None`` asks the worker to exit. For every test the worker
    reports back on ``completion_queue``: the test id on success, or an
    exception object on failure.

    Args:
        rank: rank of this worker, must satisfy ``0 <= rank < world_size``.
        world_size: total number of workers.
        rdvz_file: rendezvous file path passed on to ``_init_pg``.
        task_queue: queue the dispatcher uses to send test ids.
        completion_queue: queue used to report results to the dispatcher.
    """
    # Local imports: only needed on the failure path of this block
    import pickle
    import traceback

    # Sub tests are going to access these values, check first
    assert 0 <= rank < world_size
    # set class variables for the test class
    cls.rank = rank
    cls.world_size = world_size

    # Initialize the process group
    cls._init_pg(rank, world_size, rdvz_file)

    # End of bootstrap
    logger.info("Setup complete")

    # Loop forever, waiting for a test name to run
    while True:
        test_id = task_queue.get()
        logger.debug(f"Got test {test_id}")  # noqa: G004
        # None means exit
        if test_id is None:
            break

        # Run the test
        try:
            cls._run_test_given_id(test_id)
            completion_queue.put(test_id)
        except BaseException as ex:
            # The traceback does not survive pickling, so record it on the
            # worker side before the exception is shipped to the dispatcher.
            tb_str = traceback.format_exc()
            logger.warning(
                "Rank %s hit an exception in %s:\n%s", rank, test_id, tb_str
            )
            # The queue pickles items in a background thread; an exception
            # that cannot be pickled would be dropped there and the
            # dispatcher would wait forever. Verify the round trip up front
            # and fall back to a plain RuntimeError carrying the traceback.
            try:
                pickle.loads(pickle.dumps(ex))
                payload = ex
            except Exception:
                payload = RuntimeError(
                    f"Exception in worker rank {rank} while running {test_id}:\n{tb_str}"
                )
            # Send the exception back to the dispatcher
            completion_queue.put(payload)

    # Termination
    logger.info("Terminating ...")
    c10d.destroy_process_group()
|
||||
|
||||
@classmethod
|
||||
def _spawn_processes(cls, world_size) -> None:
    """Start one long-lived worker process per rank.

    For every rank a task queue (dispatcher -> worker) and a completion
    queue (worker -> dispatcher) are created and the worker is started on
    ``_worker_loop``. Processes and queues are stored on the class in rank
    order as ``processes``, ``task_queues`` and ``completion_queues``; the
    rendezvous file path is stored as ``rdvz_file``.

    Args:
        world_size: number of worker processes to start.
    """
    cls.processes = []
    cls.task_queues = []
    cls.completion_queues = []
    # Need a rendezvous file for `init_process_group` purpose. Only the path
    # is used, so close the handle right away; `delete=False` keeps the file.
    with tempfile.NamedTemporaryFile(delete=False) as rdvz:
        cls.rdvz_file = rdvz.name

    # CUDA multiprocessing requires spawn instead of fork, to make sure
    # child processes have their own memory space.
    try:
        torch.multiprocessing.set_start_method("spawn")
    except RuntimeError:
        # The start method has already been set
        pass

    # Use an explicit spawn context so that workers and their queues use
    # "spawn" even when the global start method was already set otherwise.
    spawn_ctx = torch.multiprocessing.get_context("spawn")

    for rank in range(int(world_size)):
        task_queue = spawn_ctx.Queue()
        completion_queue = spawn_ctx.Queue()
        process = spawn_ctx.Process(
            target=cls._worker_loop,
            name="process " + str(rank),
            daemon=True,  # so that child processes will exit if parent decides to terminate
            args=(rank, world_size, cls.rdvz_file, task_queue, completion_queue),
        )
        process.start()
        cls.processes.append(process)
        cls.task_queues.append(task_queue)
        cls.completion_queues.append(completion_queue)
        logger.info(
            "Started process %s with pid %s", rank, process.pid
        )  # noqa: UP031
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
"""
|
||||
|
|
@ -1538,30 +1646,18 @@ class MultiProcContinousTest(TestCase):
|
|||
Set up the process group.
|
||||
"""
|
||||
super().setUpClass()
|
||||
if not 0 <= cls.rank < cls.world_size:
|
||||
raise RuntimeError(
|
||||
"Rank must be set and in the range of 0 to world_size. "
|
||||
f"World size: {cls.world_size} Rank: {cls.rank}"
|
||||
)
|
||||
if cls.rdvz_file:
|
||||
store = c10d.FileStore(cls.rdvz_file, cls.world_size)
|
||||
else:
|
||||
# torchrun takes care of rendezvous
|
||||
store = None
|
||||
opts = cls.opts()
|
||||
backend = cls.backend_str()
|
||||
print(f"Testing {backend=}")
|
||||
# create nccl processgroup with opts
|
||||
c10d.init_process_group(
|
||||
backend=backend,
|
||||
world_size=cls.world_size,
|
||||
rank=cls.rank,
|
||||
store=store,
|
||||
pg_options=opts,
|
||||
timeout=cls.timeout,
|
||||
|
||||
# Use device count as world size
|
||||
device_type = cls.device_type()
|
||||
cls.world_size = torch.get_device_module(device_type).device_count()
|
||||
if cls.world_size == 0:
|
||||
raise unittest.SkipTest(f"No {device_type} devices available")
|
||||
|
||||
logger.info(
|
||||
f"Testing class {cls.__name__} on {cls.world_size} {device_type}" # noqa: G004
|
||||
)
|
||||
cls.pg = c10d.distributed_c10d._get_default_group()
|
||||
print(f"Rank {cls.rank} setup complete")
|
||||
|
||||
cls._spawn_processes(cls.world_size)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
|
|
@ -1569,37 +1665,91 @@ class MultiProcContinousTest(TestCase):
|
|||
Class-scope test fixture. Run once for entire test class, after all tests finish.
|
||||
Tear down the process group.
|
||||
"""
|
||||
c10d.destroy_process_group()
|
||||
super().tearDownClass()
|
||||
logger.debug(f"Joining {cls.world_size} workers") # noqa: G004
|
||||
# Enqueue "None" to all workers to tell them to exit
|
||||
for task_queue in cls.task_queues:
|
||||
task_queue.put(None)
|
||||
|
||||
# Wait for all workers to exit
|
||||
for process in cls.processes:
|
||||
process.join()
|
||||
|
||||
# Clear up the rendezvous file
|
||||
if cls.rdvz_file:
|
||||
try:
|
||||
os.remove(cls.rdvz_file)
|
||||
except OSError:
|
||||
pass
|
||||
print(f"Rank {cls.rank} teardown complete")
|
||||
try:
|
||||
os.remove(cls.rdvz_file)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def run_rank(
|
||||
cls,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
rdvz_file: Optional[str] = None,
|
||||
):
|
||||
logger.info(f"Class {cls.__name__} finished") # noqa: G004
|
||||
super().tearDownClass()
|
||||
|
||||
def setUp(self) -> None:
|
||||
"""
|
||||
This is an entry point for each rank to run the tests in `MultiProcContinousTest`.
|
||||
In this entry point, we set the class variables for the test class.
|
||||
Then we run all tests.
|
||||
|
||||
Note:
|
||||
- This helper only works for a subclass of `MultiProcContinousTest`.
|
||||
|
||||
Example:
|
||||
- See `test_c10d_ops_nccl.py`.
|
||||
Test fixture. Run before each test.
|
||||
"""
|
||||
# set class variables for the test class
|
||||
cls.rank = rank
|
||||
cls.world_size = world_size
|
||||
cls.rdvz_file = rdvz_file
|
||||
# Launch tests via `common_utils` infra
|
||||
run_tests()
|
||||
super().setUp()
|
||||
|
||||
# I am the dispatcher
|
||||
self.rank = self.MAIN_PROCESS_RANK
|
||||
|
||||
# If this test class hits an exception in one test, skip the rest of tests
|
||||
if self.__class__.poison_pill:
|
||||
raise unittest.SkipTest(f"Previous test failed, skipping {self.id()}")
|
||||
|
||||
# Enqueue "current test" to all workers
|
||||
for i, task_queue in enumerate(self.task_queues):
|
||||
logger.debug(f"Sending Rank {i}: {self.id()}") # noqa: G004
|
||||
task_queue.put(self.id())
|
||||
|
||||
def _worker_run_main_wait(self, fn):
    """Wrap test method ``fn`` so its behavior depends on the process role.

    In the main (dispatcher) process the wrapper does not run ``fn``; it
    waits for one result per worker on the completion queues and re-raises
    the first exception a worker reported. In a worker process the wrapper
    simply calls ``fn``.

    Returns:
        The wrapper, bound to this instance.
    """

    @wraps(fn)
    def wrapper(self):
        if self.rank != self.MAIN_PROCESS_RANK:
            # Worker just runs the test
            fn()
            return

        logger.debug(f"Waiting for workers to finish {self.id()}")  # noqa: G004
        # Wait for the workers to finish the test
        for i, completion_queue in enumerate(self.completion_queues):
            rv = completion_queue.get()
            if not isinstance(rv, BaseException):
                # Success: the worker echoes back the id of the test it ran
                assert rv == self.id()
                logger.debug(
                    f"Main proc detected rank {i} finished {self.id()}"  # noqa: G004
                )
                continue

            # Hit an exception, re-raise it in the main process.
            logger.warning(
                f"Detected failure from Rank {i} in: {self.id()}, "  # noqa: G004
                f"skipping rest of tests in Test class: {self.__class__.__name__}"  # noqa: G004
            )
            # Poison rest of tests (because ProcessGroup may be not
            # re-usable now)
            self.__class__.poison_pill = True
            raise rv

    return types.MethodType(wrapper, self)
|
||||
|
||||
# The main process spawns N subprocesses that run the test.
|
||||
# Constructor patches current instance test method to
|
||||
# assume the role of the main process and join its subprocesses,
|
||||
# or run the underlying test function.
|
||||
def __init__(
    self, method_name: str = "runTest", methodName: str = "runTest"
) -> None:
    # `methodName` is unittest's own spelling and testslide passes it by
    # keyword; `method_name` is kept so existing callers do not break.
    # An explicitly given `methodName` takes precedence.
    if methodName != "runTest":
        method_name = methodName
    super().__init__(method_name)
    try:
        # Replace the test method on this instance with its role-aware wrapper
        test_method = getattr(self, method_name)
        setattr(self, method_name, self._worker_run_main_wait(test_method))
    except AttributeError as err:
        if methodName == "runTest":
            # we allow instantiation with no explicit method name
            return
        # but not an *incorrect* or missing method name
        raise ValueError(
            f"no such test method in {self.__class__}: {methodName}"
        ) from err
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user