Revert "[Distributed][CI] Rework continuous TestCase (#153653)"
This reverts commit 0d5c628a6e.
Reverted https://github.com/pytorch/pytorch/pull/153653 on behalf of https://github.com/kwen2501 due to More fixes needed ([comment](https://github.com/pytorch/pytorch/pull/153653#issuecomment-2891931028))
commit 674a85cf26
parent 0d5c628a6e
@@ -1,16 +0,0 @@
-# Owner(s): ["oncall: distributed"]
-
-from torch.testing._internal.common_distributed import MultiProcContinousTest
-from torch.testing._internal.common_utils import run_tests
-
-
-class TestTemplate(MultiProcContinousTest):
-    def testABC(self):
-        print(f"rank {self.rank} of {self.world_size} testing ABC")
-
-    def testDEF(self):
-        print(f"rank {self.rank} of {self.world_size} testing DEF")
-
-
-if __name__ == "__main__":
-    run_tests()
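The hunk above removes the template test that ran under a plain run_tests() entry point. Against the restored MultiProcContinousTest (see the common_distributed hunks further down), an equivalent template has to declare its backend and be launched through run_rank. A minimal sketch under that assumption; the class name, the "gloo" backend, and the two-process world size are illustrative choices, not part of this commit:

# Sketch only: assumes the post-revert MultiProcContinousTest shown later in this diff.
import tempfile

import torch
from torch.testing._internal.common_distributed import MultiProcContinousTest


class TemplateTest(MultiProcContinousTest):
    # Required by the restored base class: backend_str is abstract after the revert.
    @classmethod
    def backend_str(cls) -> str:
        return "gloo"  # CPU-friendly choice for illustration; NCCL tests return "nccl"

    def testABC(self):
        print(f"rank {self.rank} of {self.world_size} testing ABC")


if __name__ == "__main__":
    # Spawn one worker per rank; run_rank sets the class variables and calls run_tests().
    world_size = 2
    rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
    torch.multiprocessing.spawn(
        TemplateTest.run_rank, nprocs=world_size, args=(world_size, rdvz_file)
    )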
@@ -11,6 +11,7 @@
 import math
 import os
 import sys
+import tempfile
 
 import torch
 import torch.distributed as c10d
@@ -29,9 +30,9 @@ from torch.testing._internal.common_distributed import (
     requires_nccl,
     requires_nccl_version,
     sm_is_or_higher_than,
+    TEST_SKIPS,
 )
 from torch.testing._internal.common_utils import (
-    run_tests,
     skip_but_pass_in_sandcastle_if,
     skipIfRocm,
     TEST_WITH_DEV_DBG_ASAN,
@@ -1043,4 +1044,24 @@ class ProcessGroupNCCLOpTest(MultiProcContinousTest):
 
 
 if __name__ == "__main__":
-    run_tests()
+    if not torch.cuda.is_available():
+        sys.exit(TEST_SKIPS["no_cuda"].exit_code)
+
+    rank = int(os.getenv("RANK", -1))
+    world_size = int(os.getenv("WORLD_SIZE", -1))
+
+    if world_size == -1:  # Not set by external launcher
+        world_size = torch.cuda.device_count()
+
+    if rank != -1:
+        # Launched with torchrun or other multi-proc launchers. Directly run the test.
+        ProcessGroupNCCLOpTest.run_rank(rank, world_size)
+    else:
+        # Launched as a single process. Spawn subprocess to run the tests.
+        # Also need a rendezvous file for `init_process_group` purpose.
+        rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
+        torch.multiprocessing.spawn(
+            ProcessGroupNCCLOpTest.run_rank,
+            nprocs=world_size,
+            args=(world_size, rdvz_file),
+        )
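The restored __main__ block above supports two launch modes: under torchrun (or any launcher that sets RANK and WORLD_SIZE) each process calls run_rank directly, while a bare python invocation spawns one process per visible GPU and hands them a temporary file for FileStore rendezvous. A sketch that isolates this pattern as a reusable helper; the launch function and its test_cls parameter are illustrative, not something this commit adds:

import os
import tempfile

import torch


def launch(test_cls) -> None:
    # If RANK is set, an external launcher (e.g. torchrun) already created this process.
    rank = int(os.getenv("RANK", "-1"))
    world_size = int(os.getenv("WORLD_SIZE", "-1"))
    if world_size == -1:  # not set by an external launcher
        world_size = torch.cuda.device_count()

    if rank != -1:
        # One process per rank already exists: run the tests in this process.
        test_cls.run_rank(rank, world_size)
    else:
        # Single-process invocation: spawn workers and give them a rendezvous file.
        rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
        torch.multiprocessing.spawn(
            test_cls.run_rank, nprocs=world_size, args=(world_size, rdvz_file)
        )


# e.g. launch(ProcessGroupNCCLOpTest) mirrors the restored __main__ block above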
@@ -7,13 +7,16 @@
 
 import os
 import sys
+import tempfile
 
 import torch
 import torch.distributed as dist
 import torch.distributed._symmetric_memory as symm_mem
-from torch.testing._internal.common_distributed import MultiProcContinousTest
+from torch.testing._internal.common_distributed import (
+    MultiProcContinousTest,
+    TEST_SKIPS,
+)
 from torch.testing._internal.common_utils import (
-    run_tests,
     skip_but_pass_in_sandcastle_if,
     skipIfRocm,
 )
@@ -44,20 +47,28 @@ device_module = torch.get_device_module(device_type)
 
 @requires_nvshmem()
 class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
-    def _init_device(self) -> None:
+    def setUp(self) -> None:
+        super().setUp()
         # TODO: relieve this (seems to hang if without)
         device_module.set_device(self.device)
         # NOTE: required for nvshmem allocation
         torch.empty(1, device=self.device)
 
+    # Required by MultiProcContinousTest
+    @classmethod
+    def backend_str(cls) -> str:
+        return "nccl"
+
+    @property
+    def world_size(self) -> int:
+        return device_module.device_count()
+
     @property
     def device(self) -> torch.device:
         return torch.device(device_type, self.rank)
 
     @skipIfRocm
     def test_nvshmem_all_to_all(self) -> None:
-        self._init_device()
-
         group_name = dist.group.WORLD.group_name
         symm_mem.enable_symm_mem_for_group(group_name)
 
@@ -81,8 +92,6 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
 
     @skipIfRocm
     def test_nvshmem_all_to_all_vdev(self) -> None:
-        self._init_device()
-
         group_name = dist.group.WORLD.group_name
         symm_mem.enable_symm_mem_for_group(group_name)
 
@@ -130,4 +139,24 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
 
 
 if __name__ == "__main__":
-    run_tests()
+    if not device_module.is_available():
+        sys.exit(TEST_SKIPS["no_cuda"].exit_code)
+
+    # If launched by torchrun, these values would have been set
+    rank = int(os.getenv("RANK", "-1"))
+    world_size = int(os.getenv("WORLD_SIZE", "-1"))
+
+    if rank != -1:
+        # Launched with torchrun or other multi-proc launchers. Directly run the test.
+        NVSHMEMSymmetricMemoryTest.run_rank(rank, world_size)
+    else:
+        # No external launcher, spawn N processes
+        world_size = device_module.device_count()
+        # Launched as a single process. Spawn subprocess to run the tests.
+        # Also need a rendezvous file for `init_process_group` purpose.
+        rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
+        torch.multiprocessing.spawn(
+            NVSHMEMSymmetricMemoryTest.run_rank,
+            nprocs=world_size,
+            args=(world_size, rdvz_file),
+        )
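Both restored entry points begin with the same guard: a missing accelerator turns into the designated exit code from TEST_SKIPS, so CI records a skip rather than a failure. A minimal sketch of that guard; hard-coding "cuda" here is an assumption for illustration (the test above resolves the device type from the current accelerator):

import sys

import torch
from torch.testing._internal.common_distributed import TEST_SKIPS

device_type = "cuda"  # illustrative; the test above derives this dynamically
device_module = torch.get_device_module(device_type)  # torch.cuda in this case

if not device_module.is_available():
    # Exit with the well-known skip code instead of raising, so the harness marks a skip.
    sys.exit(TEST_SKIPS["no_cuda"].exit_code)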
@@ -1,5 +1,6 @@
 # mypy: ignore-errors
 
+import abc
 import faulthandler
 import itertools
 import logging
@@ -37,6 +38,7 @@ from torch.testing._internal.common_utils import (
     find_free_port,
     IS_SANDCASTLE,
     retry_on_connect_failures,
+    run_tests,
     skip_but_pass_in_sandcastle,
     skip_but_pass_in_sandcastle_if,
     TEST_HPU,
@@ -1500,29 +1502,24 @@ class DynamoDistributedMultiProcTestCase(MultiProcessTestCase):
 
 class MultiProcContinousTest(TestCase):
     # Class variables:
-    MAIN_PROCESS_RANK = -1
     # number of test processes
-    world_size: int = -2  # unset state
+    world_size: int = 2
     # rank of the current process
-    rank: int = -2  # unset state
+    rank: int = -1  # unset state
     # Rendezvous file
     rdvz_file: Optional[str] = None
     # timeout configured per class
     timeout: timedelta = timedelta(seconds=120)
 
     @classmethod
-    def backend_str(cls) -> Optional[str]:
+    @abc.abstractmethod
+    def backend_str(cls) -> str:
         """
         ProcessGroup backend str.
         To be customized by sub test classes, e.g. "nccl".
-        Otherwise we return None -- lazily decided by tensor.
+        Here we raise error.
         """
-        return None
+        raise NotImplementedError("Please implement backend_str in your test class")
 
-    # Please override if you intend to test on specific device type
-    @classmethod
-    def device_type(cls) -> str:
-        return torch.accelerator.current_accelerator().type
-
     @classmethod
     def opts(cls, high_priority_stream=False):
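The hunk above flips the backend_str contract: the reworked class returned Optional[str] with a None fallback (backend decided lazily by the tensors), while the restored class marks it abstract and raises NotImplementedError. In practice, after this revert every subclass must provide an override along these lines; the class name and "gloo" are illustrative:

from torch.testing._internal.common_distributed import MultiProcContinousTest


class MyDistributedTest(MultiProcContinousTest):
    @classmethod
    def backend_str(cls) -> str:
        # Mandatory after the revert; NCCL-based tests return "nccl" instead.
        return "gloo"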
@@ -1533,91 +1530,6 @@ class MultiProcContinousTest(TestCase):
         """
         return None
 
-    @classmethod
-    def _init_pg(cls, rank, world_size, rdvz_file):
-        assert rdvz_file is not None
-        store = c10d.FileStore(rdvz_file, world_size)
-
-        # create nccl processgroup with opts
-        c10d.init_process_group(
-            backend=cls.backend_str(),
-            world_size=world_size,
-            rank=rank,
-            store=store,
-            pg_options=cls.opts(),
-            timeout=cls.timeout,
-        )
-        cls.pg = c10d.distributed_c10d._get_default_group()
-
-    @classmethod
-    def _run_test_given_id(cls, test_id: str, **kwargs) -> None:
-        # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank'
-        test_name = test_id.split(".")[-1]
-        # Get the test function from the test class
-        self = cls(test_name)
-        self.rank = cls.rank
-        self.world_size = cls.world_size
-        test_fn = getattr(self, test_name)
-        # Run the test function
-        test_fn(**kwargs)
-
-    @classmethod
-    def _worker_loop(cls, rank, world_size, rdvz_file, task_queue, completion_queue):
-        # Sub tests are going to access these values, check first
-        assert 0 <= rank < world_size
-        # set class variables for the test class
-        cls.rank = rank
-        cls.world_size = world_size
-
-        # Initialize the process group
-        cls._init_pg(rank, world_size, rdvz_file)
-
-        # End of bootstrap
-        logger.info("Setup complete")
-
-        # Loop forever, waiting for a test name to run
-        while True:
-            test_id = task_queue.get()
-            logger.debug(f"Got test {test_id}")  # noqa: G004
-            if test_id is None:
-                break
-
-            cls._run_test_given_id(test_id)
-            completion_queue.put(test_id)
-
-        # Termination
-        logger.info("Terminating ...")
-        c10d.destroy_process_group()
-
-    @classmethod
-    def _spawn_processes(cls, world_size) -> None:
-        cls.processes = []
-        cls.task_queues = []
-        cls.completion_queues = []
-        # Need a rendezvous file for `init_process_group` purpose.
-        cls.rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
-
-        # CUDA multiprocessing requires spawn instead of fork, to make sure
-        # child processes have their own memory space.
-        if torch.multiprocessing.get_start_method(allow_none=True) != "spawn":
-            torch.multiprocessing.set_start_method("spawn")
-
-        for rank in range(int(world_size)):
-            task_queue = torch.multiprocessing.Queue()
-            completion_queue = torch.multiprocessing.Queue()
-            process = torch.multiprocessing.Process(
-                target=cls._worker_loop,
-                name="process " + str(rank),
-                args=(rank, world_size, cls.rdvz_file, task_queue, completion_queue),
-            )
-            process.start()
-            cls.processes.append(process)
-            cls.task_queues.append(task_queue)
-            cls.completion_queues.append(completion_queue)
-            logger.info(
-                "Started process %s with pid %s", rank, process.pid
-            )  # noqa: UP031
-
     @classmethod
     def setUpClass(cls):
         """
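The block removed above is the core of the reworked design: setUpClass spawned long-lived workers, each worker initialized its process group once and then looped, pulling test ids from a per-rank task queue and reporting completions back, while the main process acted only as a dispatcher. A stripped-down illustration of that model with hypothetical names (plain multiprocessing, not the PyTorch test infra), to make the control flow easier to follow:

import multiprocessing as mp


def worker_loop(rank, task_queue, completion_queue):
    # In the removed code this is also where the process group was initialized, once per worker.
    while True:
        test_id = task_queue.get()
        if test_id is None:  # sentinel: tear down and exit
            break
        # ... locate and run the named test here ...
        completion_queue.put(test_id)  # tell the dispatcher this rank is done


if __name__ == "__main__":
    world_size = 2
    tasks, done, procs = [], [], []
    for rank in range(world_size):
        t, d = mp.Queue(), mp.Queue()
        p = mp.Process(target=worker_loop, args=(rank, t, d))
        p.start()
        tasks.append(t)
        done.append(d)
        procs.append(p)

    for test_id in ["testABC", "testDEF"]:
        for t in tasks:  # dispatch the same test to every rank
            t.put(test_id)
        for d in done:  # block until every rank reports completion
            assert d.get() == test_id

    for t in tasks:  # sentinel shuts the workers down
        t.put(None)
    for p in procs:
        p.join()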
@@ -1625,18 +1537,30 @@ class MultiProcContinousTest(TestCase):
         Set up the process group.
         """
         super().setUpClass()
-
-        # Use device count as world size
-        device_type = cls.device_type()
-        cls.world_size = torch.get_device_module(device_type).device_count()
-        if cls.world_size == 0:
-            raise unittest.SkipTest(f"No {device_type} devices available")
-
-        logger.info(
-            f"Testing class {cls.__name__} on {cls.world_size} {device_type}"  # noqa: G004
+        if not 0 <= cls.rank < cls.world_size:
+            raise RuntimeError(
+                "Rank must be set and in the range of 0 to world_size. "
+                f"World size: {cls.world_size} Rank: {cls.rank}"
+            )
+        if cls.rdvz_file:
+            store = c10d.FileStore(cls.rdvz_file, cls.world_size)
+        else:
+            # torchrun takes care of rendezvous
+            store = None
+        opts = cls.opts()
+        backend = cls.backend_str()
+        print(f"Testing {backend=}")
+        # create nccl processgroup with opts
+        c10d.init_process_group(
+            backend=backend,
+            world_size=cls.world_size,
+            rank=cls.rank,
+            store=store,
+            pg_options=opts,
+            timeout=cls.timeout,
         )
-
-        cls._spawn_processes(cls.world_size)
+        cls.pg = c10d.distributed_c10d._get_default_group()
+        print(f"Rank {cls.rank} setup complete")
 
     @classmethod
     def tearDownClass(cls):
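The restored setUpClass above runs inside every rank's own process and performs the rendezvous itself: a FileStore when the class spawned its own workers, or None when torchrun already placed the connection details in the environment. A sketch reduced to that core; the standalone function and the "gloo" backend are illustrative:

from typing import Optional

import torch.distributed as c10d


def init_pg_for_rank(rank: int, world_size: int, rdvz_file: Optional[str]) -> None:
    # FileStore rendezvous for self-spawned workers; under torchrun the env vars
    # (MASTER_ADDR, MASTER_PORT, ...) carry this information and store stays None.
    store = c10d.FileStore(rdvz_file, world_size) if rdvz_file else None
    c10d.init_process_group(
        backend="gloo",
        world_size=world_size,
        rank=rank,
        store=store,
    )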
@@ -1644,75 +1568,37 @@ class MultiProcContinousTest(TestCase):
         Class-scope test fixture. Run once for entire test class, after all tests finish.
         Tear down the process group.
         """
-        logger.debug(f"Joining {cls.world_size} workers")  # noqa: G004
-        # Enqueue "None" to all workers to tell them to exit
-        for task_queue in cls.task_queues:
-            task_queue.put(None)
-
-        # Wait for all workers to exit
-        for process in cls.processes:
-            process.join()
-
-        # Clear up the rendezvous file
-        try:
-            os.remove(cls.rdvz_file)
-        except OSError:
-            pass
-
-        logger.info(f"Class {cls.__name__} finished")  # noqa: G004
+        c10d.destroy_process_group()
         super().tearDownClass()
+        # Clear up the rendezvous file
+        if cls.rdvz_file:
+            try:
+                os.remove(cls.rdvz_file)
+            except OSError:
+                pass
+        print(f"Rank {cls.rank} teardown complete")
 
-    def setUp(self) -> None:
+    @classmethod
+    def run_rank(
+        cls,
+        rank: int,
+        world_size: int,
+        rdvz_file: Optional[str] = None,
+    ):
         """
-        Test fixture. Run before each test.
+        This is an entry point for each rank to run the tests in `MultiProcContinousTest`.
+        In this entry point, we set the class variables for the test class.
+        Then we run all tests.
+
+        Note:
+        - This helper only works for a subclass of `MultiProcContinousTest`.
+
+        Example:
+        - See `test_c10d_ops_nccl.py`.
         """
-        super().setUp()
-
-        # I am the dispatcher
-        self.rank = self.MAIN_PROCESS_RANK
-
-        # Enqueue "current test" to all workers
-        for i, task_queue in enumerate(self.task_queues):
-            logger.debug(f"Sending Rank {i}: {self.id()}")  # noqa: G004
-            task_queue.put(self.id())
-
-    def _worker_run_main_wait(self, fn):
-        @wraps(fn)
-        def wrapper(self):
-            if self.rank == self.MAIN_PROCESS_RANK:
-                logger.debug(f"Waiting for workers to finish {self.id()}")  # noqa: G004
-                # Wait for the workers to finish the test
-                for i, completion_queue in enumerate(self.completion_queues):
-                    test_id = completion_queue.get()
-                    assert test_id == self.id()
-                    logger.debug(
-                        f"Main proc detected rank {i} finished {test_id}"  # noqa: G004
-                    )
-            else:
-                # Worker just runs the test
-                fn()
-
-        return types.MethodType(wrapper, self)
-
-    # The main process spawns N subprocesses that run the test.
-    # Constructor patches current instance test method to
-    # assume the role of the main process and join its subprocesses,
-    # or run the underlying test function.
-    def __init__(
-        self, method_name: str = "runTest", methodName: str = "runTest"
-    ) -> None:
-        # methodName is the correct naming in unittest and testslide uses keyword arguments.
-        # So we need to use both to 1) not break BC and, 2) support testslide.
-        if methodName != "runTest":
-            method_name = methodName
-        super().__init__(method_name)
-        try:
-            fn = getattr(self, method_name)
-            setattr(self, method_name, self._worker_run_main_wait(fn))
-        except AttributeError as e:
-            if methodName != "runTest":
-                # we allow instantiation with no explicit method name
-                # but not an *incorrect* or missing method name
-                raise ValueError(
-                    f"no such test method in {self.__class__}: {methodName}"
-                ) from e
+        # set class variables for the test class
+        cls.rank = rank
+        cls.world_size = world_size
+        cls.rdvz_file = rdvz_file
+        # Launch tests via `common_utils` infra
+        run_tests()