[Distributed][CI] Rework continuous TestCase (#153653)
1. Reworked `MultiProcContinousTest` to spawn processes during `setUpClass` instead of in `main`, so that multiple test classes can be supported in one file.
2. The child processes now run an infinite loop, monitoring test IDs passed from the main process via a task queue. In return, the child processes inform the main process of a test's completion via a completion queue.
3. Added a test template.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153653
Approved by: https://github.com/d4l3k, https://github.com/fegin, https://github.com/fduwjj
This commit is contained in:
parent c54b9f2969
commit 0d5c628a6e
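For orientation, the main-process/worker handshake described in point 2 can be sketched with plain `multiprocessing` queues. This is an illustrative stand-in, not the PyTorch implementation; the names below (`worker_loop`, the queue layout, the hard-coded test IDs) are invented for the example.

import multiprocessing as mp


def worker_loop(rank, task_queue, completion_queue):
    # Each worker loops forever: run whatever test ID the main process sends,
    # report it back on the completion queue, and exit on a None sentinel.
    while True:
        test_id = task_queue.get()
        if test_id is None:
            break
        print(f"rank {rank} running {test_id}")  # stand-in for the real test body
        completion_queue.put(test_id)


if __name__ == "__main__":
    world_size = 2
    task_queues = [mp.Queue() for _ in range(world_size)]
    completion_queues = [mp.Queue() for _ in range(world_size)]
    workers = [
        mp.Process(target=worker_loop, args=(r, task_queues[r], completion_queues[r]))
        for r in range(world_size)
    ]
    for w in workers:
        w.start()

    for test_id in ["TestTemplate.testABC", "TestTemplate.testDEF"]:
        for q in task_queues:  # main process dispatches the same test to every rank
            q.put(test_id)
        for q in completion_queues:  # ... and waits until every rank reports done
            assert q.get() == test_id

    for q in task_queues:  # None tells the workers to exit
        q.put(None)
    for w in workers:
        w.join()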
test/distributed/_test_template.py (new file, 16 lines)

@@ -0,0 +1,16 @@
# Owner(s): ["oncall: distributed"]

from torch.testing._internal.common_distributed import MultiProcContinousTest
from torch.testing._internal.common_utils import run_tests


class TestTemplate(MultiProcContinousTest):
    def testABC(self):
        print(f"rank {self.rank} of {self.world_size} testing ABC")

    def testDEF(self):
        print(f"rank {self.rank} of {self.world_size} testing DEF")


if __name__ == "__main__":
    run_tests()
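Presumably the template is launched as a single process (for example `python test/distributed/_test_template.py`); with the reworked base class, `setUpClass` spawns the worker processes, each test body runs once per rank, and `testABC`/`testDEF` each print one line per rank.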
@@ -11,7 +11,6 @@
import math
import os
import sys
import tempfile

import torch
import torch.distributed as c10d

@@ -30,9 +29,9 @@ from torch.testing._internal.common_distributed import (
    requires_nccl,
    requires_nccl_version,
    sm_is_or_higher_than,
    TEST_SKIPS,
)
from torch.testing._internal.common_utils import (
    run_tests,
    skip_but_pass_in_sandcastle_if,
    skipIfRocm,
    TEST_WITH_DEV_DBG_ASAN,
@@ -1044,24 +1043,4 @@ class ProcessGroupNCCLOpTest(MultiProcContinousTest):


if __name__ == "__main__":
    if not torch.cuda.is_available():
        sys.exit(TEST_SKIPS["no_cuda"].exit_code)

    rank = int(os.getenv("RANK", -1))
    world_size = int(os.getenv("WORLD_SIZE", -1))

    if world_size == -1:  # Not set by external launcher
        world_size = torch.cuda.device_count()

    if rank != -1:
        # Launched with torchrun or other multi-proc launchers. Directly run the test.
        ProcessGroupNCCLOpTest.run_rank(rank, world_size)
    else:
        # Launched as a single process. Spawn subprocess to run the tests.
        # Also need a rendezvous file for `init_process_group` purpose.
        rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
        torch.multiprocessing.spawn(
            ProcessGroupNCCLOpTest.run_rank,
            nprocs=world_size,
            args=(world_size, rdvz_file),
        )
    run_tests()
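After this change the per-file entry point presumably reduces to the pattern shown in the new `_test_template.py`; a minimal sketch, assuming the old RANK/WORLD_SIZE bootstrap above is removed entirely:

from torch.testing._internal.common_utils import run_tests

# Sketch only: the per-file bootstrap collapses to the common_utils entry point,
# since MultiProcContinousTest now spawns its own workers in setUpClass.
if __name__ == "__main__":
    run_tests()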
@@ -7,16 +7,13 @@

import os
import sys
import tempfile

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch.testing._internal.common_distributed import (
    MultiProcContinousTest,
    TEST_SKIPS,
)
from torch.testing._internal.common_distributed import MultiProcContinousTest
from torch.testing._internal.common_utils import (
    run_tests,
    skip_but_pass_in_sandcastle_if,
    skipIfRocm,
)
@@ -47,28 +44,20 @@ device_module = torch.get_device_module(device_type)

@requires_nvshmem()
class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
    def setUp(self) -> None:
        super().setUp()

    def _init_device(self) -> None:
        # TODO: relieve this (seems to hang if without)
        device_module.set_device(self.device)
        # NOTE: required for nvshmem allocation
        torch.empty(1, device=self.device)

    # Required by MultiProcContinousTest
    @classmethod
    def backend_str(cls) -> str:
        return "nccl"

    @property
    def world_size(self) -> int:
        return device_module.device_count()

    @property
    def device(self) -> torch.device:
        return torch.device(device_type, self.rank)

    @skipIfRocm
    def test_nvshmem_all_to_all(self) -> None:
        self._init_device()

        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)
@@ -92,6 +81,8 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):

    @skipIfRocm
    def test_nvshmem_all_to_all_vdev(self) -> None:
        self._init_device()

        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)
@@ -139,24 +130,4 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):


if __name__ == "__main__":
    if not device_module.is_available():
        sys.exit(TEST_SKIPS["no_cuda"].exit_code)

    # If launched by torchrun, these values would have been set
    rank = int(os.getenv("RANK", "-1"))
    world_size = int(os.getenv("WORLD_SIZE", "-1"))

    if rank != -1:
        # Launched with torchrun or other multi-proc launchers. Directly run the test.
        NVSHMEMSymmetricMemoryTest.run_rank(rank, world_size)
    else:
        # No external launcher, spawn N processes
        world_size = device_module.device_count()
        # Launched as a single process. Spawn subprocess to run the tests.
        # Also need a rendezvous file for `init_process_group` purpose.
        rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
        torch.multiprocessing.spawn(
            NVSHMEMSymmetricMemoryTest.run_rank,
            nprocs=world_size,
            args=(world_size, rdvz_file),
        )
    run_tests()
@@ -1,6 +1,5 @@
# mypy: ignore-errors

import abc
import faulthandler
import itertools
import logging

@@ -38,7 +37,6 @@ from torch.testing._internal.common_utils import (
    find_free_port,
    IS_SANDCASTLE,
    retry_on_connect_failures,
    run_tests,
    skip_but_pass_in_sandcastle,
    skip_but_pass_in_sandcastle_if,
    TEST_HPU,
@@ -1502,24 +1500,29 @@ class DynamoDistributedMultiProcTestCase(MultiProcessTestCase):

class MultiProcContinousTest(TestCase):
    # Class variables:
    MAIN_PROCESS_RANK = -1
    # number of test processes
    world_size: int = 2
    world_size: int = -2  # unset state
    # rank of the current process
    rank: int = -1  # unset state
    rank: int = -2  # unset state
    # Rendezvous file
    rdvz_file: Optional[str] = None
    # timeout configured per class
    timeout: timedelta = timedelta(seconds=120)

    @classmethod
    @abc.abstractmethod
    def backend_str(cls) -> str:
    def backend_str(cls) -> Optional[str]:
        """
        ProcessGroup backend str.
        To be customized by sub test classes, e.g. "nccl".
        Here we raise error.
        Otherwise we return None -- lazily decided by tensor.
        """
        raise NotImplementedError("Please implement backend_str in your test class")
        return None

    # Please override if you intend to test on specific device type
    @classmethod
    def device_type(cls) -> str:
        return torch.accelerator.current_accelerator().type

    @classmethod
    def opts(cls, high_priority_stream=False):
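With `backend_str()` now returning `Optional[str]` and defaulting to `None`, a subclass may either pin a backend or let `init_process_group` choose one for the current accelerator. A hypothetical sketch (the subclass names below are invented for illustration):

from typing import Optional

from torch.testing._internal.common_distributed import MultiProcContinousTest


class PinnedBackendTest(MultiProcContinousTest):
    # Pin the process-group backend explicitly, as the NCCL op tests do.
    @classmethod
    def backend_str(cls) -> Optional[str]:
        return "nccl"


class DefaultBackendTest(MultiProcContinousTest):
    # Inherit backend_str() -> None; the backend is then decided lazily by
    # init_process_group based on the device in use.
    pass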
@@ -1530,6 +1533,91 @@
        """
        return None

    @classmethod
    def _init_pg(cls, rank, world_size, rdvz_file):
        assert rdvz_file is not None
        store = c10d.FileStore(rdvz_file, world_size)

        # create nccl processgroup with opts
        c10d.init_process_group(
            backend=cls.backend_str(),
            world_size=world_size,
            rank=rank,
            store=store,
            pg_options=cls.opts(),
            timeout=cls.timeout,
        )
        cls.pg = c10d.distributed_c10d._get_default_group()

    @classmethod
    def _run_test_given_id(cls, test_id: str, **kwargs) -> None:
        # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank'
        test_name = test_id.split(".")[-1]
        # Get the test function from the test class
        self = cls(test_name)
        self.rank = cls.rank
        self.world_size = cls.world_size
        test_fn = getattr(self, test_name)
        # Run the test function
        test_fn(**kwargs)

    @classmethod
    def _worker_loop(cls, rank, world_size, rdvz_file, task_queue, completion_queue):
        # Sub tests are going to access these values, check first
        assert 0 <= rank < world_size
        # set class variables for the test class
        cls.rank = rank
        cls.world_size = world_size

        # Initialize the process group
        cls._init_pg(rank, world_size, rdvz_file)

        # End of bootstrap
        logger.info("Setup complete")

        # Loop forever, waiting for a test name to run
        while True:
            test_id = task_queue.get()
            logger.debug(f"Got test {test_id}")  # noqa: G004
            if test_id is None:
                break

            cls._run_test_given_id(test_id)
            completion_queue.put(test_id)

        # Termination
        logger.info("Terminating ...")
        c10d.destroy_process_group()

    @classmethod
    def _spawn_processes(cls, world_size) -> None:
        cls.processes = []
        cls.task_queues = []
        cls.completion_queues = []
        # Need a rendezvous file for `init_process_group` purpose.
        cls.rdvz_file = tempfile.NamedTemporaryFile(delete=False).name

        # CUDA multiprocessing requires spawn instead of fork, to make sure
        # child processes have their own memory space.
        if torch.multiprocessing.get_start_method(allow_none=True) != "spawn":
            torch.multiprocessing.set_start_method("spawn")

        for rank in range(int(world_size)):
            task_queue = torch.multiprocessing.Queue()
            completion_queue = torch.multiprocessing.Queue()
            process = torch.multiprocessing.Process(
                target=cls._worker_loop,
                name="process " + str(rank),
                args=(rank, world_size, cls.rdvz_file, task_queue, completion_queue),
            )
            process.start()
            cls.processes.append(process)
            cls.task_queues.append(task_queue)
            cls.completion_queues.append(completion_queue)
            logger.info(
                "Started process %s with pid %s", rank, process.pid
            )  # noqa: UP031

    @classmethod
    def setUpClass(cls):
        """
@@ -1537,30 +1625,18 @@ class MultiProcContinousTest(TestCase):
        Set up the process group.
        """
        super().setUpClass()
        if not 0 <= cls.rank < cls.world_size:
            raise RuntimeError(
                "Rank must be set and in the range of 0 to world_size. "
                f"World size: {cls.world_size} Rank: {cls.rank}"
            )
        if cls.rdvz_file:
            store = c10d.FileStore(cls.rdvz_file, cls.world_size)
        else:
            # torchrun takes care of rendezvous
            store = None
        opts = cls.opts()
        backend = cls.backend_str()
        print(f"Testing {backend=}")
        # create nccl processgroup with opts
        c10d.init_process_group(
            backend=backend,
            world_size=cls.world_size,
            rank=cls.rank,
            store=store,
            pg_options=opts,
            timeout=cls.timeout,
        )
        # Use device count as world size
        device_type = cls.device_type()
        cls.world_size = torch.get_device_module(device_type).device_count()
        if cls.world_size == 0:
            raise unittest.SkipTest(f"No {device_type} devices available")

        logger.info(
            f"Testing class {cls.__name__} on {cls.world_size} {device_type}"  # noqa: G004
        )
        cls.pg = c10d.distributed_c10d._get_default_group()
        print(f"Rank {cls.rank} setup complete")

        cls._spawn_processes(cls.world_size)

    @classmethod
    def tearDownClass(cls):
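Note that raising `unittest.SkipTest` inside `setUpClass` skips every test in the class, so on a machine with no accelerator devices the whole class is reported as skipped rather than failing during process spawn.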
@@ -1568,37 +1644,75 @@ class MultiProcContinousTest(TestCase):
        Class-scope test fixture. Run once for entire test class, after all tests finish.
        Tear down the process group.
        """
        c10d.destroy_process_group()
        super().tearDownClass()
        logger.debug(f"Joining {cls.world_size} workers")  # noqa: G004
        # Enqueue "None" to all workers to tell them to exit
        for task_queue in cls.task_queues:
            task_queue.put(None)

        # Wait for all workers to exit
        for process in cls.processes:
            process.join()

        # Clear up the rendezvous file
        if cls.rdvz_file:
            try:
                os.remove(cls.rdvz_file)
            except OSError:
                pass
        print(f"Rank {cls.rank} teardown complete")
        try:
            os.remove(cls.rdvz_file)
        except OSError:
            pass

    @classmethod
    def run_rank(
        cls,
        rank: int,
        world_size: int,
        rdvz_file: Optional[str] = None,
    ):
        logger.info(f"Class {cls.__name__} finished")  # noqa: G004
        super().tearDownClass()

    def setUp(self) -> None:
        """
        This is an entry point for each rank to run the tests in `MultiProcContinousTest`.
        In this entry point, we set the class variables for the test class.
        Then we run all tests.

        Note:
        - This helper only works for a subclass of `MultiProcContinousTest`.

        Example:
        - See `test_c10d_ops_nccl.py`.
        Test fixture. Run before each test.
        """
        # set class variables for the test class
        cls.rank = rank
        cls.world_size = world_size
        cls.rdvz_file = rdvz_file
        # Launch tests via `common_utils` infra
        run_tests()
        super().setUp()

        # I am the dispatcher
        self.rank = self.MAIN_PROCESS_RANK

        # Enqueue "current test" to all workers
        for i, task_queue in enumerate(self.task_queues):
            logger.debug(f"Sending Rank {i}: {self.id()}")  # noqa: G004
            task_queue.put(self.id())

    def _worker_run_main_wait(self, fn):
        @wraps(fn)
        def wrapper(self):
            if self.rank == self.MAIN_PROCESS_RANK:
                logger.debug(f"Waiting for workers to finish {self.id()}")  # noqa: G004
                # Wait for the workers to finish the test
                for i, completion_queue in enumerate(self.completion_queues):
                    test_id = completion_queue.get()
                    assert test_id == self.id()
                    logger.debug(
                        f"Main proc detected rank {i} finished {test_id}"  # noqa: G004
                    )
            else:
                # Worker just runs the test
                fn()

        return types.MethodType(wrapper, self)

    # The main process spawns N subprocesses that run the test.
    # Constructor patches current instance test method to
    # assume the role of the main process and join its subprocesses,
    # or run the underlying test function.
    def __init__(
        self, method_name: str = "runTest", methodName: str = "runTest"
    ) -> None:
        # methodName is the correct naming in unittest and testslide uses keyword arguments.
        # So we need to use both to 1) not break BC and, 2) support testslide.
        if methodName != "runTest":
            method_name = methodName
        super().__init__(method_name)
        try:
            fn = getattr(self, method_name)
            setattr(self, method_name, self._worker_run_main_wait(fn))
        except AttributeError as e:
            if methodName != "runTest":
                # we allow instantiation with no explicit method name
                # but not an *incorrect* or missing method name
                raise ValueError(
                    f"no such test method in {self.__class__}: {methodName}"
                ) from e
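The constructor patching above is what lets ordinary unittest discovery keep working in the main process while the test body executes only in the workers. A stripped-down, standalone sketch of the same idea (illustrative only; the queue-waiting is replaced by a print):

import types
import unittest
from functools import wraps


class DispatcherSketch(unittest.TestCase):
    MAIN_PROCESS_RANK = -1
    rank = MAIN_PROCESS_RANK  # the instance driven by unittest is the dispatcher

    def __init__(self, methodName: str = "runTest") -> None:
        super().__init__(methodName)
        fn = getattr(self, methodName)

        @wraps(fn)
        def wrapper(self):
            if self.rank == self.MAIN_PROCESS_RANK:
                # Dispatcher role: the real class waits on the completion queues here.
                print(f"dispatched {self.id()} to workers")
            else:
                # Worker role: run the actual test body.
                fn()

        setattr(self, methodName, types.MethodType(wrapper, self))

    def test_example(self):
        print("runs only inside worker processes")


if __name__ == "__main__":
    unittest.main()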