mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: MultiProcessTestCase will be useful for both c10d and rpc tests. So, this diff extracts that class and some common decorators to a separate file. Pull Request resolved: https://github.com/pytorch/pytorch/pull/23660 Reviewed By: pietern Differential Revision: D16602865 Pulled By: mrshenli fbshipit-source-id: 85ad47dfb8ba187b7debeb3edeea5df08ef690c7
166 lines
4.7 KiB
Python
166 lines
4.7 KiB
Python
import multiprocessing
|
|
import sys
|
|
import tempfile
|
|
import time
|
|
import unittest
|
|
|
|
from collections import namedtuple
|
|
from functools import wraps
|
|
|
|
import torch
|
|
import torch.distributed as c10d
|
|
|
|
from common_utils import TestCase
|
|
|
|
|
|
# Pairs the exit code a worker process uses to signal "skipped" with the
# human-readable reason reported for that skip.
TestSkip = namedtuple('TestSkip', ['exit_code', 'message'])

# Known skip conditions, keyed by a short identifier. The exit codes are the
# values the skip decorators pass to ``sys.exit``.
TEST_SKIPS = {
    "multi-gpu": TestSkip(exit_code=75, message="Need at least 2 CUDA devices"),
    "nccl": TestSkip(exit_code=76, message="c10d not compiled with NCCL support"),
    "known_issues": TestSkip(exit_code=77, message="Test skipped due to known issues"),
}
|
|
|
|
|
|
def skip_if_not_multigpu(func):
    """Run ``func`` only when at least two CUDA devices are present.

    When CUDA is unavailable or fewer than two devices are visible, the
    wrapper exits the current process with the ``'multi-gpu'`` exit code
    from ``TEST_SKIPS`` instead of calling ``func``.
    """
    required_gpus = 2

    @wraps(func)
    def guarded(*args, **kwargs):
        enough_gpus = (
            torch.cuda.is_available()
            and torch.cuda.device_count() >= required_gpus
        )
        if not enough_gpus:
            sys.exit(TEST_SKIPS['multi-gpu'].exit_code)
        return func(*args, **kwargs)

    return guarded
|
|
|
|
|
|
def skip_if_lt_x_gpu(x):
    """Build a decorator that runs a test only with at least ``x`` CUDA devices.

    When CUDA is unavailable or fewer than ``x`` devices are visible, the
    wrapped function is not called; the current process exits with the
    ``'multi-gpu'`` exit code from ``TEST_SKIPS``.
    """
    def decorator(func):
        @wraps(func)
        def guarded(*args, **kwargs):
            enough_gpus = (
                torch.cuda.is_available()
                and torch.cuda.device_count() >= x
            )
            if not enough_gpus:
                sys.exit(TEST_SKIPS['multi-gpu'].exit_code)
            return func(*args, **kwargs)

        return guarded

    return decorator
|
|
|
|
|
|
def skip_for_known_issues(func):
    """Replace ``func`` with a stub that always signals a known-issue skip.

    The wrapped function is never called; the stub exits the current
    process with the ``'known_issues'`` exit code from ``TEST_SKIPS``.
    """
    @wraps(func)
    def exit_as_skipped(*args, **kwargs):
        known_issue = TEST_SKIPS['known_issues']
        sys.exit(known_issue.exit_code)

    return exit_as_skipped
|
|
|
|
|
|
def requires_gloo():
    """Return a unittest decorator that skips unless ``c10d.is_gloo_available()``."""
    gloo_available = c10d.is_gloo_available()
    reason = "c10d was not compiled with the Gloo backend"
    return unittest.skipUnless(gloo_available, reason)
|
|
|
|
|
|
def requires_nccl():
    """Return a unittest decorator that skips unless ``c10d.is_nccl_available()``."""
    nccl_available = c10d.is_nccl_available()
    reason = "c10d was not compiled with the NCCL backend"
    return unittest.skipUnless(nccl_available, reason)
|
|
|
|
|
|
def requires_mpi():
    """Return a unittest decorator that skips unless ``c10d.is_mpi_available()``."""
    mpi_available = c10d.is_mpi_available()
    reason = "c10d was not compiled with the MPI backend"
    return unittest.skipUnless(mpi_available, reason)
|
|
|
|
|
|
# Timeout used for any test without an entry in TIMEOUT_OVERRIDE.
TIMEOUT_DEFAULT = 30
# Per-test overrides, keyed by the last dot-separated component of the test id.
TIMEOUT_OVERRIDE = {}


def get_timeout(test_id):
    """Return the timeout configured for ``test_id``.

    Only the text after the last ``'.'`` of ``test_id`` is used as the
    lookup key in ``TIMEOUT_OVERRIDE``; ``TIMEOUT_DEFAULT`` is returned
    when no override is registered for it.
    """
    test_name = test_id.rpartition('.')[2]
    if test_name in TIMEOUT_OVERRIDE:
        return TIMEOUT_OVERRIDE[test_name]
    return TIMEOUT_DEFAULT
|
|
|
|
|
|
class MultiProcessTestCase(TestCase):
    """Test case base class that runs each ``test*`` method in subprocesses.

    ``setUp`` spawns ``world_size`` subprocesses. Each one sets ``self.rank``
    to its own rank, executes the current test method and exits with code 0
    (or with whatever code the test passes to ``sys.exit``). The main
    process keeps rank ``MAIN_PROCESS_RANK``; for it, every test method is
    replaced by a routine that joins the subprocesses and checks their exit
    codes.
    """

    # Rank held by the parent process that spawns and joins the workers.
    MAIN_PROCESS_RANK = -1

    @property
    def world_size(self):
        """Number of subprocesses spawned for each test."""
        return 4

    @staticmethod
    def join_or_run(fn):
        """Wrap ``fn`` so the main process joins workers and workers run ``fn``."""
        @wraps(fn)
        def wrapper(self):
            if self.rank == self.MAIN_PROCESS_RANK:
                self._join_processes(fn)
            else:
                fn(self)
        return wrapper

    # The main process spawns N subprocesses that run the test.
    # This function overwrites every test function so that it either
    # assumes the role of the main process and joins its subprocesses,
    # or runs the underlying test function.
    @classmethod
    def setUpClass(cls):
        for attr in dir(cls):
            if attr.startswith('test'):
                fn = getattr(cls, attr)
                setattr(cls, attr, cls.join_or_run(fn))

    def setUp(self):
        """Mark this process as the main one, create a temp file, spawn workers."""
        super(MultiProcessTestCase, self).setUp()
        self.rank = self.MAIN_PROCESS_RANK
        # delete=False: the file is not removed when the handle is closed.
        self.file = tempfile.NamedTemporaryFile(delete=False)
        self.processes = [self._spawn_process(rank) for rank in range(int(self.world_size))]

    def tearDown(self):
        """Terminate all workers and release the temp file handle."""
        super(MultiProcessTestCase, self).tearDown()
        for p in self.processes:
            p.terminate()
        # Close the descriptor opened in setUp so it is not leaked for every
        # test; the file itself stays on disk because of delete=False.
        self.file.close()

    def _spawn_process(self, rank):
        """Start and return a subprocess that executes ``_run(rank)``."""
        name = 'process ' + str(rank)
        process = multiprocessing.Process(target=self._run, name=name, args=(rank,))
        process.start()
        return process

    def _run(self, rank):
        """Worker entry point: run the current test as ``rank``, then exit 0."""
        self.rank = rank

        # self.id() == e.g. '__main__.TestDistributed.test_get_rank'
        # The test method name is the last component. The module part may
        # itself contain dots, so a fixed index would pick the wrong piece.
        test_name = self.id().split(".")[-1]
        getattr(self, test_name)()
        sys.exit(0)

    def _join_processes(self, fn):
        """Wait for all workers within the test's timeout, then check them.

        ``fn`` is accepted for the ``join_or_run`` wrapper but not used here.
        """
        timeout = get_timeout(self.id())
        start_time = time.time()
        # One shared deadline: the timeout bounds the whole test rather than
        # being granted afresh to each process in turn.
        deadline = start_time + timeout
        for p in self.processes:
            p.join(max(0.0, deadline - time.time()))
        elapsed_time = time.time() - start_time
        self._check_return_codes(elapsed_time)

    def _check_return_codes(self, elapsed_time):
        """
        Checks that the return codes of all spawned processes match, and skips
        tests if they returned a return code indicating a skipping condition.

        Raises RuntimeError if any process has no exit code yet, and
        unittest.SkipTest if the common exit code is one of TEST_SKIPS.
        """
        first_process = self.processes[0]
        for i, p in enumerate(self.processes):
            if p.exitcode is None:
                raise RuntimeError('Process {} terminated or timed out after {} seconds'.format(i, elapsed_time))
            self.assertEqual(p.exitcode, first_process.exitcode)
        for skip in TEST_SKIPS.values():
            if first_process.exitcode == skip.exit_code:
                raise unittest.SkipTest(skip.message)
        self.assertEqual(first_process.exitcode, 0)

    @property
    def is_master(self):
        """True when this process has rank 0."""
        return self.rank == 0
|