Revert "[Distributed][CI] Rework continuous TestCase (#153653)"

This reverts commit 0d5c628a6e.

Reverted https://github.com/pytorch/pytorch/pull/153653 on behalf of https://github.com/kwen2501 due to More fixes needed ([comment](https://github.com/pytorch/pytorch/pull/153653#issuecomment-2891931028))
PyTorch MergeBot 2025-05-19 18:29:26 +00:00
parent 0d5c628a6e
commit 674a85cf26
4 changed files with 121 additions and 201 deletions

View File

@@ -1,16 +0,0 @@
# Owner(s): ["oncall: distributed"]
from torch.testing._internal.common_distributed import MultiProcContinousTest
from torch.testing._internal.common_utils import run_tests


class TestTemplate(MultiProcContinousTest):
    def testABC(self):
        print(f"rank {self.rank} of {self.world_size} testing ABC")

    def testDEF(self):
        print(f"rank {self.rank} of {self.world_size} testing DEF")


if __name__ == "__main__":
    run_tests()
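This template was added by the PR now being reverted: under the reworked base class, `setUpClass` spawned one worker process per device and the main process dispatched each test id to every rank through task queues (see the removed `_worker_loop` in `common_distributed.py` below), so a file like this could be launched as a plain `python` process with no per-rank entry point.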

View File

@@ -11,6 +11,7 @@
import math
import os
import sys
import tempfile
import torch
import torch.distributed as c10d
@@ -29,9 +30,9 @@ from torch.testing._internal.common_distributed import (
    requires_nccl,
    requires_nccl_version,
    sm_is_or_higher_than,
    TEST_SKIPS,
)
from torch.testing._internal.common_utils import (
    run_tests,
    skip_but_pass_in_sandcastle_if,
    skipIfRocm,
    TEST_WITH_DEV_DBG_ASAN,
@@ -1043,4 +1044,24 @@ class ProcessGroupNCCLOpTest(MultiProcContinousTest):
if __name__ == "__main__":
    run_tests()

    if not torch.cuda.is_available():
        sys.exit(TEST_SKIPS["no_cuda"].exit_code)

    rank = int(os.getenv("RANK", -1))
    world_size = int(os.getenv("WORLD_SIZE", -1))
    if world_size == -1:  # Not set by external launcher
        world_size = torch.cuda.device_count()

    if rank != -1:
        # Launched with torchrun or other multi-proc launchers. Directly run the test.
        ProcessGroupNCCLOpTest.run_rank(rank, world_size)
    else:
        # Launched as a single process. Spawn subprocesses to run the tests.
        # Also need a rendezvous file for `init_process_group` purposes.
        rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
        torch.multiprocessing.spawn(
            ProcessGroupNCCLOpTest.run_rank,
            nprocs=world_size,
            args=(world_size, rdvz_file),
        )
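This restored entry point gives the file two launch paths: under an external launcher such as torchrun (for example, something like `torchrun --nproc-per-node=2 test_c10d_ops_nccl.py`; exact flags depend on your setup), `RANK` and `WORLD_SIZE` arrive through the environment and each rank calls `run_rank` directly, while a plain `python` launch spawns one subprocess per visible GPU and rendezvouses through a temporary file.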

View File

@@ -7,13 +7,16 @@
import os
import sys
import tempfile

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch.testing._internal.common_distributed import MultiProcContinousTest
from torch.testing._internal.common_distributed import (
    MultiProcContinousTest,
    TEST_SKIPS,
)
from torch.testing._internal.common_utils import (
    run_tests,
    skip_but_pass_in_sandcastle_if,
    skipIfRocm,
)
@@ -44,20 +47,28 @@ device_module = torch.get_device_module(device_type)
@requires_nvshmem()
class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
    def _init_device(self) -> None:
    def setUp(self) -> None:
        super().setUp()
        # TODO: lift this requirement (the test seems to hang without it)
        device_module.set_device(self.device)
        # NOTE: required for nvshmem allocation
        torch.empty(1, device=self.device)

    # Required by MultiProcContinousTest
    @classmethod
    def backend_str(cls) -> str:
        return "nccl"

    @property
    def world_size(self) -> int:
        return device_module.device_count()

    @property
    def device(self) -> torch.device:
        return torch.device(device_type, self.rank)

    @skipIfRocm
    def test_nvshmem_all_to_all(self) -> None:
        self._init_device()

        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)
@@ -81,8 +92,6 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
    @skipIfRocm
    def test_nvshmem_all_to_all_vdev(self) -> None:
        self._init_device()

        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)
@@ -130,4 +139,24 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
if __name__ == "__main__":
    run_tests()

    if not device_module.is_available():
        sys.exit(TEST_SKIPS["no_cuda"].exit_code)

    # If launched by torchrun, these values would have been set
    rank = int(os.getenv("RANK", "-1"))
    world_size = int(os.getenv("WORLD_SIZE", "-1"))

    if rank != -1:
        # Launched with torchrun or other multi-proc launchers. Directly run the test.
        NVSHMEMSymmetricMemoryTest.run_rank(rank, world_size)
    else:
        # No external launcher, spawn N processes
        world_size = device_module.device_count()
        # Launched as a single process. Spawn subprocesses to run the tests.
        # Also need a rendezvous file for `init_process_group` purposes.
        rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
        torch.multiprocessing.spawn(
            NVSHMEMSymmetricMemoryTest.run_rank,
            nprocs=world_size,
            args=(world_size, rdvz_file),
        )

View File

@@ -1,5 +1,6 @@
# mypy: ignore-errors
import abc
import faulthandler
import itertools
import logging
@@ -37,6 +38,7 @@ from torch.testing._internal.common_utils import (
    find_free_port,
    IS_SANDCASTLE,
    retry_on_connect_failures,
    run_tests,
    skip_but_pass_in_sandcastle,
    skip_but_pass_in_sandcastle_if,
    TEST_HPU,
@@ -1500,29 +1502,24 @@ class DynamoDistributedMultiProcTestCase(MultiProcessTestCase):
class MultiProcContinousTest(TestCase):
    # Class variables:
    MAIN_PROCESS_RANK = -1
    # number of test processes
    world_size: int = -2  # unset state
    world_size: int = 2
    # rank of the current process
    rank: int = -2  # unset state
    rank: int = -1  # unset state
    # Rendezvous file
    rdvz_file: Optional[str] = None
    # timeout configured per class
    timeout: timedelta = timedelta(seconds=120)

    @classmethod
    def backend_str(cls) -> Optional[str]:
    @abc.abstractmethod
    def backend_str(cls) -> str:
        """
        ProcessGroup backend str.
        To be customized by sub test classes, e.g. "nccl".
        Otherwise we return None -- the backend is then lazily decided from the tensor device.
        Here we raise an error.
        """
        return None

    # Please override if you intend to test on a specific device type
    @classmethod
    def device_type(cls) -> str:
        return torch.accelerator.current_accelerator().type
        raise NotImplementedError("Please implement backend_str in your test class")
    @classmethod
    def opts(cls, high_priority_stream=False):
@@ -1533,91 +1530,6 @@ class MultiProcContinousTest(TestCase):
"""
return None
    @classmethod
    def _init_pg(cls, rank, world_size, rdvz_file):
        assert rdvz_file is not None
        store = c10d.FileStore(rdvz_file, world_size)
        # create nccl processgroup with opts
        c10d.init_process_group(
            backend=cls.backend_str(),
            world_size=world_size,
            rank=rank,
            store=store,
            pg_options=cls.opts(),
            timeout=cls.timeout,
        )
        cls.pg = c10d.distributed_c10d._get_default_group()
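For context, `_init_pg` above rendezvouses through a `FileStore` rather than the more common TCP or env-var store. A minimal, self-contained sketch of that pattern, using gloo on CPU so it runs without GPUs (the helper names are illustrative, not part of this diff):

# Standalone sketch of the FileStore rendezvous used by _init_pg above.
import tempfile

import torch
import torch.distributed as c10d


def _demo_worker(rank: int, world_size: int, rdvz_file: str) -> None:
    # Every rank opens the same file; the store brokers the rendezvous.
    store = c10d.FileStore(rdvz_file, world_size)
    c10d.init_process_group(
        backend="gloo", world_size=world_size, rank=rank, store=store
    )
    t = torch.ones(1)
    c10d.all_reduce(t)  # sums the ones across ranks
    assert t.item() == world_size
    c10d.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
    torch.multiprocessing.spawn(
        _demo_worker, nprocs=world_size, args=(world_size, rdvz_file)
    )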
    @classmethod
    def _run_test_given_id(cls, test_id: str, **kwargs) -> None:
        # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank'
        test_name = test_id.split(".")[-1]
        # Get the test function from the test class
        self = cls(test_name)
        self.rank = cls.rank
        self.world_size = cls.world_size
        test_fn = getattr(self, test_name)
        # Run the test function
        test_fn(**kwargs)

    @classmethod
    def _worker_loop(cls, rank, world_size, rdvz_file, task_queue, completion_queue):
        # Sub-tests will access these values; validate them first
        assert 0 <= rank < world_size
        # set class variables for the test class
        cls.rank = rank
        cls.world_size = world_size
        # Initialize the process group
        cls._init_pg(rank, world_size, rdvz_file)
        # End of bootstrap
        logger.info("Setup complete")
        # Loop forever, waiting for a test name to run
        while True:
            test_id = task_queue.get()
            logger.debug(f"Got test {test_id}")  # noqa: G004
            if test_id is None:
                break
            cls._run_test_given_id(test_id)
            completion_queue.put(test_id)
        # Termination
        logger.info("Terminating ...")
        c10d.destroy_process_group()

    @classmethod
    def _spawn_processes(cls, world_size) -> None:
        cls.processes = []
        cls.task_queues = []
        cls.completion_queues = []
        # Need a rendezvous file for `init_process_group` purposes.
        cls.rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
        # CUDA multiprocessing requires spawn instead of fork, to make sure
        # child processes have their own memory space.
        if torch.multiprocessing.get_start_method(allow_none=True) != "spawn":
            torch.multiprocessing.set_start_method("spawn")
        for rank in range(int(world_size)):
            task_queue = torch.multiprocessing.Queue()
            completion_queue = torch.multiprocessing.Queue()
            process = torch.multiprocessing.Process(
                target=cls._worker_loop,
                name="process " + str(rank),
                args=(rank, world_size, cls.rdvz_file, task_queue, completion_queue),
            )
            process.start()
            cls.processes.append(process)
            cls.task_queues.append(task_queue)
            cls.completion_queues.append(completion_queue)
            logger.info(
                "Started process %s with pid %s", rank, process.pid
            )  # noqa: UP031
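The `_spawn_processes` / `_worker_loop` pair above implements a long-lived worker pool: each worker blocks on its task queue, runs whatever test id arrives, acknowledges it on a completion queue, and exits on a `None` sentinel. A distilled, runnable sketch of that handshake (names are illustrative):

# Minimal task/completion queue handshake, mirroring the dispatcher above.
import torch.multiprocessing as mp


def _demo_worker(rank, task_queue, completion_queue):
    while True:
        test_id = task_queue.get()
        if test_id is None:  # sentinel: the dispatcher asks us to exit
            break
        # ... the real _worker_loop runs the named test method here ...
        completion_queue.put(test_id)


if __name__ == "__main__":
    mp.set_start_method("spawn")  # CUDA-safe start method, as above
    task_q, done_q = mp.Queue(), mp.Queue()
    p = mp.Process(target=_demo_worker, args=(0, task_q, done_q))
    p.start()
    task_q.put("test_abc")
    assert done_q.get() == "test_abc"  # dispatcher blocks until the worker is done
    task_q.put(None)  # tell the worker to exit
    p.join()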
    @classmethod
    def setUpClass(cls):
        """
@@ -1625,18 +1537,30 @@ class MultiProcContinousTest(TestCase):
        Set up the process group.
        """
        super().setUpClass()
        # Use device count as world size
        device_type = cls.device_type()
        cls.world_size = torch.get_device_module(device_type).device_count()
        if cls.world_size == 0:
            raise unittest.SkipTest(f"No {device_type} devices available")
        logger.info(
            f"Testing class {cls.__name__} on {cls.world_size} {device_type}"  # noqa: G004
        if not 0 <= cls.rank < cls.world_size:
            raise RuntimeError(
                "Rank must be set and in the range of 0 to world_size. "
                f"World size: {cls.world_size} Rank: {cls.rank}"
            )
        if cls.rdvz_file:
            store = c10d.FileStore(cls.rdvz_file, cls.world_size)
        else:
            # torchrun takes care of rendezvous
            store = None
        opts = cls.opts()
        backend = cls.backend_str()
        print(f"Testing {backend=}")
        # create nccl processgroup with opts
        c10d.init_process_group(
            backend=backend,
            world_size=cls.world_size,
            rank=cls.rank,
            store=store,
            pg_options=opts,
            timeout=cls.timeout,
        )
        cls._spawn_processes(cls.world_size)
        cls.pg = c10d.distributed_c10d._get_default_group()
        print(f"Rank {cls.rank} setup complete")
    @classmethod
    def tearDownClass(cls):
@@ -1644,75 +1568,37 @@ class MultiProcContinousTest(TestCase):
        Class-scope test fixture. Run once for the entire test class, after all tests finish.
        Tear down the process group.
        """
        logger.debug(f"Joining {cls.world_size} workers")  # noqa: G004
        # Enqueue "None" to all workers to tell them to exit
        for task_queue in cls.task_queues:
            task_queue.put(None)
        # Wait for all workers to exit
        for process in cls.processes:
            process.join()
        # Clean up the rendezvous file
        try:
            os.remove(cls.rdvz_file)
        except OSError:
            pass
        logger.info(f"Class {cls.__name__} finished")  # noqa: G004
        c10d.destroy_process_group()
        super().tearDownClass()
        # Clean up the rendezvous file
        if cls.rdvz_file:
            try:
                os.remove(cls.rdvz_file)
            except OSError:
                pass
        print(f"Rank {cls.rank} teardown complete")
    def setUp(self) -> None:
    @classmethod
    def run_rank(
        cls,
        rank: int,
        world_size: int,
        rdvz_file: Optional[str] = None,
    ):
"""
Test fixture. Run before each test.
This is an entry point for each rank to run the tests in `MultiProcContinousTest`.
In this entry point, we set the class variables for the test class.
Then we run all tests.
Note:
- This helper only works for a subclass of `MultiProcContinousTest`.
Example:
- See `test_c10d_ops_nccl.py`.
"""
super().setUp()
# I am the dispatcher
self.rank = self.MAIN_PROCESS_RANK
# Enqueue "current test" to all workers
for i, task_queue in enumerate(self.task_queues):
logger.debug(f"Sending Rank {i}: {self.id()}") # noqa: G004
task_queue.put(self.id())
    def _worker_run_main_wait(self, fn):
        @wraps(fn)
        def wrapper(self):
            if self.rank == self.MAIN_PROCESS_RANK:
                logger.debug(f"Waiting for workers to finish {self.id()}")  # noqa: G004
                # Wait for the workers to finish the test
                for i, completion_queue in enumerate(self.completion_queues):
                    test_id = completion_queue.get()
                    assert test_id == self.id()
                    logger.debug(
                        f"Main proc detected rank {i} finished {test_id}"  # noqa: G004
                    )
            else:
                # Worker just runs the test
                fn()

        return types.MethodType(wrapper, self)
    # The main process spawns N subprocesses that run the test.
    # The constructor patches the current instance's test method so that it
    # either assumes the role of the main process and joins its subprocesses,
    # or runs the underlying test function.
    def __init__(
        self, method_name: str = "runTest", methodName: str = "runTest"
    ) -> None:
        # methodName is the correct name in unittest, and testslide uses keyword arguments.
        # So we accept both to 1) not break BC and 2) support testslide.
        if methodName != "runTest":
            method_name = methodName
        super().__init__(method_name)
        try:
            fn = getattr(self, method_name)
            setattr(self, method_name, self._worker_run_main_wait(fn))
        except AttributeError as e:
            if methodName != "runTest":
                # we allow instantiation with no explicit method name
                # but not an *incorrect* or missing method name
                raise ValueError(
                    f"no such test method in {self.__class__}: {methodName}"
                ) from e
        # set class variables for the test class
        cls.rank = rank
        cls.world_size = world_size
        cls.rdvz_file = rdvz_file
        # Launch tests via `common_utils` infra
        run_tests()
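Putting the restored pieces together: a subclass overrides `backend_str` and reuses the dual entry point shown in the test files above. A hypothetical end-to-end sketch (assumes a CUDA build with NCCL; the class and test names are illustrative, not part of this diff):

# Hypothetical subclass using the restored MultiProcContinousTest.
import os
import tempfile

import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import MultiProcContinousTest


class MyOpTest(MultiProcContinousTest):
    @classmethod
    def backend_str(cls) -> str:
        # Mandatory in the restored base class (abstract method).
        return "nccl"

    def test_allreduce(self):
        t = torch.ones(1, device=f"cuda:{self.rank}")
        dist.all_reduce(t)
        assert t.item() == self.world_size


if __name__ == "__main__":
    rank = int(os.getenv("RANK", "-1"))
    world_size = int(os.getenv("WORLD_SIZE", "-1"))
    if rank != -1:
        # Launched by torchrun: this process is one rank; run directly.
        MyOpTest.run_rank(rank, world_size)
    else:
        # Plain `python` launch: spawn one process per GPU and rendezvous
        # through a temporary file, as in the test files above.
        world_size = torch.cuda.device_count()
        rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
        torch.multiprocessing.spawn(
            MyOpTest.run_rank, nprocs=world_size, args=(world_size, rdvz_file)
        )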