[Distributed][CI] Rework continuous TestCase (#153653)

1. Reworked `MultiProcContinousTest` to spawn processes during `setUpClass` instead of in `main`, so that multiple test classes can be supported in one file.

2. Each child process now runs an infinite loop, waiting for test IDs passed from the main process via a task queue. In return, the child processes report completion of each test to the main process via a completion queue (see the sketch after this list).

3. Added a test template.
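
For illustration, the dispatch protocol from point 2 reduces to the following minimal sketch, written with plain `multiprocessing` (names and the test ID are illustrative; the actual implementation below uses `torch.multiprocessing` queues):

    import multiprocessing as mp

    def worker(task_queue, completion_queue):
        while True:
            test_id = task_queue.get()        # block until main sends a test ID
            if test_id is None:               # sentinel: time to shut down
                break
            # ... look up and run the test named by test_id ...
            completion_queue.put(test_id)     # report completion back to main

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")         # same start method the real class uses
        tasks, done = ctx.Queue(), ctx.Queue()
        p = ctx.Process(target=worker, args=(tasks, done))
        p.start()
        tasks.put("TestTemplate.testABC")     # main process dispatches one test
        assert done.get() == "TestTemplate.testABC"
        tasks.put(None)                       # ask the worker to exit
        p.join()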

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153653
Approved by: https://github.com/d4l3k, https://github.com/fegin, https://github.com/fduwjj
Author: Ke Wen
Date: 2025-05-18 23:18:53 -07:00
Committed by: PyTorch MergeBot
Parent: c54b9f2969
Commit: 0d5c628a6e
4 changed files with 201 additions and 121 deletions

File: (new) test template

@@ -0,0 +1,16 @@
# Owner(s): ["oncall: distributed"]
from torch.testing._internal.common_distributed import MultiProcContinousTest
from torch.testing._internal.common_utils import run_tests
class TestTemplate(MultiProcContinousTest):
def testABC(self):
print(f"rank {self.rank} of {self.world_size} testing ABC")
def testDEF(self):
print(f"rank {self.rank} of {self.world_size} testing DEF")
if __name__ == "__main__":
run_tests()
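
Since the workers are now spawned in `setUpClass`, this template needs no launcher-specific `__main__` logic; running the file as a plain Python script is enough (the path below is hypothetical, as the file name is not shown in this view):

    python test/distributed/test_template.py

`run_tests()` drives only the main process, which dispatches each test method to every rank in turn.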

File: test/distributed/test_c10d_ops_nccl.py

@@ -11,7 +11,6 @@
import math
import os
import sys
import tempfile
import torch
import torch.distributed as c10d
@@ -30,9 +29,9 @@ from torch.testing._internal.common_distributed import (
requires_nccl,
requires_nccl_version,
sm_is_or_higher_than,
TEST_SKIPS,
)
from torch.testing._internal.common_utils import (
run_tests,
skip_but_pass_in_sandcastle_if,
skipIfRocm,
TEST_WITH_DEV_DBG_ASAN,
@@ -1044,24 +1043,4 @@ class ProcessGroupNCCLOpTest(MultiProcContinousTest):
if __name__ == "__main__":
if not torch.cuda.is_available():
sys.exit(TEST_SKIPS["no_cuda"].exit_code)
rank = int(os.getenv("RANK", -1))
world_size = int(os.getenv("WORLD_SIZE", -1))
if world_size == -1: # Not set by external launcher
world_size = torch.cuda.device_count()
if rank != -1:
# Launched with torchrun or other multi-proc launchers. Directly run the test.
ProcessGroupNCCLOpTest.run_rank(rank, world_size)
else:
# Launched as a single process. Spawn subprocess to run the tests.
# Also need a rendezvous file for `init_process_group` purpose.
rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
torch.multiprocessing.spawn(
ProcessGroupNCCLOpTest.run_rank,
nprocs=world_size,
args=(world_size, rdvz_file),
)
run_tests()
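
With process spawning moved into `setUpClass`, the launcher branching deleted above becomes unnecessary; the `__main__` block that remains in this file is just the availability check plus the standard entry point:

    if __name__ == "__main__":
        if not torch.cuda.is_available():
            sys.exit(TEST_SKIPS["no_cuda"].exit_code)
        run_tests()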

File: test/distributed/test_nvshmem.py

@@ -7,16 +7,13 @@
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch.testing._internal.common_distributed import (
MultiProcContinousTest,
TEST_SKIPS,
)
from torch.testing._internal.common_distributed import MultiProcContinousTest
from torch.testing._internal.common_utils import (
run_tests,
skip_but_pass_in_sandcastle_if,
skipIfRocm,
)
@@ -47,28 +44,20 @@ device_module = torch.get_device_module(device_type)
@requires_nvshmem()
class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
def setUp(self) -> None:
super().setUp()
def _init_device(self) -> None:
# TODO: relieve this (seems to hang without it)
device_module.set_device(self.device)
# NOTE: required for nvshmem allocation
torch.empty(1, device=self.device)
# Required by MultiProcContinousTest
@classmethod
def backend_str(cls) -> str:
return "nccl"
@property
def world_size(self) -> int:
return device_module.device_count()
@property
def device(self) -> torch.device:
return torch.device(device_type, self.rank)
@skipIfRocm
def test_nvshmem_all_to_all(self) -> None:
self._init_device()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
@@ -92,6 +81,8 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
@skipIfRocm
def test_nvshmem_all_to_all_vdev(self) -> None:
self._init_device()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
@@ -139,24 +130,4 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
if __name__ == "__main__":
if not device_module.is_available():
sys.exit(TEST_SKIPS["no_cuda"].exit_code)
# If launched by torchrun, these values would have been set
rank = int(os.getenv("RANK", "-1"))
world_size = int(os.getenv("WORLD_SIZE", "-1"))
if rank != -1:
# Launched with torchrun or other multi-proc launchers. Directly run the test.
NVSHMEMSymmetricMemoryTest.run_rank(rank, world_size)
else:
# No external launcher, spawn N processes
world_size = device_module.device_count()
# Launched as a single process. Spawn subprocess to run the tests.
# Also need a rendezvous file for `init_process_group` purpose.
rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
torch.multiprocessing.spawn(
NVSHMEMSymmetricMemoryTest.run_rank,
nprocs=world_size,
args=(world_size, rdvz_file),
)
run_tests()
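
After this change, the boilerplate a subclass must provide shrinks accordingly; a minimal sketch (illustrative names, reflecting the reworked base class below):

    class MyDistTest(MultiProcContinousTest):
        @classmethod
        def backend_str(cls):
            # Now optional: returning None lets the backend be decided lazily.
            return "nccl"

        def test_something(self):
            # self.rank and self.world_size are populated by the base class.
            ...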

File: torch/testing/_internal/common_distributed.py

@@ -1,6 +1,5 @@
# mypy: ignore-errors
import abc
import faulthandler
import itertools
import logging
@@ -38,7 +37,6 @@ from torch.testing._internal.common_utils import (
find_free_port,
IS_SANDCASTLE,
retry_on_connect_failures,
run_tests,
skip_but_pass_in_sandcastle,
skip_but_pass_in_sandcastle_if,
TEST_HPU,
@@ -1502,24 +1500,29 @@ class DynamoDistributedMultiProcTestCase(MultiProcessTestCase):
class MultiProcContinousTest(TestCase):
# Class variables:
MAIN_PROCESS_RANK = -1
# number of test processes
world_size: int = 2
world_size: int = -2 # unset state
# rank of the current process
rank: int = -1 # unset state
rank: int = -2 # unset state
# Rendezvous file
rdvz_file: Optional[str] = None
# timeout configured per class
timeout: timedelta = timedelta(seconds=120)
@classmethod
@abc.abstractmethod
def backend_str(cls) -> str:
def backend_str(cls) -> Optional[str]:
"""
ProcessGroup backend str.
To be customized by sub test classes, e.g. "nccl".
Here we raise error.
Otherwise we return None -- lazily decided by tensor.
"""
raise NotImplementedError("Please implement backend_str in your test class")
return None
# Please override if you intend to test on specific device type
@classmethod
def device_type(cls) -> str:
return torch.accelerator.current_accelerator().type
@classmethod
def opts(cls, high_priority_stream=False):
@@ -1530,6 +1533,91 @@ class MultiProcContinousTest(TestCase):
"""
return None
@classmethod
def _init_pg(cls, rank, world_size, rdvz_file):
assert rdvz_file is not None
store = c10d.FileStore(rdvz_file, world_size)
# create nccl processgroup with opts
c10d.init_process_group(
backend=cls.backend_str(),
world_size=world_size,
rank=rank,
store=store,
pg_options=cls.opts(),
timeout=cls.timeout,
)
cls.pg = c10d.distributed_c10d._get_default_group()
@classmethod
def _run_test_given_id(cls, test_id: str, **kwargs) -> None:
# self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank'
test_name = test_id.split(".")[-1]
# Get the test function from the test class
self = cls(test_name)
self.rank = cls.rank
self.world_size = cls.world_size
test_fn = getattr(self, test_name)
# Run the test function
test_fn(**kwargs)
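# For reference (a sketch; values illustrative): the test_id handed over the
# queue is the standard unittest id, so the lookup above mirrors the stock
# unittest pattern of building a single-test instance:
#
#   test_id = "__main__.TestTemplate.testABC"
#   test_name = test_id.split(".")[-1]     # -> "testABC"
#   case = TestTemplate(test_name)         # one-test instance
#   case.testABC()                         # bound test method, ready to call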
@classmethod
def _worker_loop(cls, rank, world_size, rdvz_file, task_queue, completion_queue):
# Sub tests are going to access these values, check first
assert 0 <= rank < world_size
# set class variables for the test class
cls.rank = rank
cls.world_size = world_size
# Initialize the process group
cls._init_pg(rank, world_size, rdvz_file)
# End of bootstrap
logger.info("Setup complete")
# Loop forever, waiting for a test name to run
while True:
test_id = task_queue.get()
logger.debug(f"Got test {test_id}") # noqa: G004
if test_id is None:
break
cls._run_test_given_id(test_id)
completion_queue.put(test_id)
# Termination
logger.info("Terminating ...")
c10d.destroy_process_group()
@classmethod
def _spawn_processes(cls, world_size) -> None:
cls.processes = []
cls.task_queues = []
cls.completion_queues = []
# Need a rendezvous file for `init_process_group` purpose.
cls.rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
# CUDA multiprocessing requires spawn instead of fork, to make sure
# child processes have their own memory space.
if torch.multiprocessing.get_start_method(allow_none=True) != "spawn":
torch.multiprocessing.set_start_method("spawn")
for rank in range(int(world_size)):
task_queue = torch.multiprocessing.Queue()
completion_queue = torch.multiprocessing.Queue()
process = torch.multiprocessing.Process(
target=cls._worker_loop,
name="process " + str(rank),
args=(rank, world_size, cls.rdvz_file, task_queue, completion_queue),
)
process.start()
cls.processes.append(process)
cls.task_queues.append(task_queue)
cls.completion_queues.append(completion_queue)
logger.info(
"Started process %s with pid %s", rank, process.pid
) # noqa: UP031
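# Note (a sketch, not part of this commit): a local spawn context would avoid
# mutating the global start method, at the cost of threading the context
# through the queue/process constructors:
#
#   ctx = torch.multiprocessing.get_context("spawn")
#   task_queue = ctx.Queue()
#   process = ctx.Process(target=cls._worker_loop, args=(...))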
@classmethod
def setUpClass(cls):
"""
@@ -1537,30 +1625,18 @@ class MultiProcContinousTest(TestCase):
Set up the process group.
"""
super().setUpClass()
if not 0 <= cls.rank < cls.world_size:
raise RuntimeError(
"Rank must be set and in the range of 0 to world_size. "
f"World size: {cls.world_size} Rank: {cls.rank}"
)
if cls.rdvz_file:
store = c10d.FileStore(cls.rdvz_file, cls.world_size)
else:
# torchrun takes care of rendezvous
store = None
opts = cls.opts()
backend = cls.backend_str()
print(f"Testing {backend=}")
# create nccl processgroup with opts
c10d.init_process_group(
backend=backend,
world_size=cls.world_size,
rank=cls.rank,
store=store,
pg_options=opts,
timeout=cls.timeout,
# Use device count as world size
device_type = cls.device_type()
cls.world_size = torch.get_device_module(device_type).device_count()
if cls.world_size == 0:
raise unittest.SkipTest(f"No {device_type} devices available")
logger.info(
f"Testing class {cls.__name__} on {cls.world_size} {device_type}" # noqa: G004
)
cls.pg = c10d.distributed_c10d._get_default_group()
print(f"Rank {cls.rank} setup complete")
cls._spawn_processes(cls.world_size)
@classmethod
def tearDownClass(cls):
@@ -1568,37 +1644,75 @@ class MultiProcContinousTest(TestCase):
Class-scope test fixture. Run once for entire test class, after all tests finish.
Tear down the process group.
"""
c10d.destroy_process_group()
super().tearDownClass()
logger.debug(f"Joining {cls.world_size} workers") # noqa: G004
# Enqueue "None" to all workers to tell them to exit
for task_queue in cls.task_queues:
task_queue.put(None)
# Wait for all workers to exit
for process in cls.processes:
process.join()
# Clear up the rendezvous file
if cls.rdvz_file:
try:
os.remove(cls.rdvz_file)
except OSError:
pass
print(f"Rank {cls.rank} teardown complete")
try:
os.remove(cls.rdvz_file)
except OSError:
pass
@classmethod
def run_rank(
cls,
rank: int,
world_size: int,
rdvz_file: Optional[str] = None,
):
logger.info(f"Class {cls.__name__} finished") # noqa: G004
super().tearDownClass()
def setUp(self) -> None:
"""
This is an entry point for each rank to run the tests in `MultiProcContinousTest`.
In this entry point, we set the class variables for the test class.
Then we run all tests.
Note:
- This helper only works for a subclass of `MultiProcContinousTest`.
Example:
- See `test_c10d_ops_nccl.py`.
Test fixture. Run before each test.
"""
# set class variables for the test class
cls.rank = rank
cls.world_size = world_size
cls.rdvz_file = rdvz_file
# Launch tests via `common_utils` infra
run_tests()
super().setUp()
# I am the dispatcher
self.rank = self.MAIN_PROCESS_RANK
# Enqueue "current test" to all workers
for i, task_queue in enumerate(self.task_queues):
logger.debug(f"Sending Rank {i}: {self.id()}") # noqa: G004
task_queue.put(self.id())
def _worker_run_main_wait(self, fn):
@wraps(fn)
def wrapper(self):
if self.rank == self.MAIN_PROCESS_RANK:
logger.debug(f"Waiting for workers to finish {self.id()}") # noqa: G004
# Wait for the workers to finish the test
for i, completion_queue in enumerate(self.completion_queues):
test_id = completion_queue.get()
assert test_id == self.id()
logger.debug(
f"Main proc detected rank {i} finished {test_id}" # noqa: G004
)
else:
# Worker just runs the test
fn()
return types.MethodType(wrapper, self)
# The main process spawns N subprocesses that run the test.
# Constructor patches current instance test method to
# assume the role of the main process and join its subprocesses,
# or run the underlying test function.
def __init__(
self, method_name: str = "runTest", methodName: str = "runTest"
) -> None:
# methodName is the correct naming in unittest and testslide uses keyword arguments.
# So we need to use both to 1) not break BC and, 2) support testslide.
if methodName != "runTest":
method_name = methodName
super().__init__(method_name)
try:
fn = getattr(self, method_name)
setattr(self, method_name, self._worker_run_main_wait(fn))
except AttributeError as e:
if methodName != "runTest":
# we allow instantiation with no explicit method name
# but not an *incorrect* or missing method name
raise ValueError(
f"no such test method in {self.__class__}: {methodName}"
) from e
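
The constructor patch above can be exercised in isolation; here is a minimal, self-contained sketch of the same technique (illustrative class and method names):

    import types
    import unittest
    from functools import wraps

    class PatchedCase(unittest.TestCase):
        def __init__(self, methodName: str = "runTest") -> None:
            super().__init__(methodName)
            fn = getattr(self, methodName)

            @wraps(fn)
            def wrapper(self):
                # Dispatcher-style work would happen here before/after the test.
                print("dispatching", self.id())
                fn()  # run the original bound test method

            # Rebind the wrapper as a method on this instance only.
            setattr(self, methodName, types.MethodType(wrapper, self))

        def test_demo(self):
            print("running", self.id())

    if __name__ == "__main__":
        unittest.main()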