mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/25012 Resubmitting https://github.com/pytorch/pytorch/pull/22907 with build fix. This change adds the following functionality: 1) WorkNCCL isCompleted, isSuccess methods check for NCCL errors and set the appropriate exception. 2) Added a watchdog thread to ProcessGroupNCCL which checks for errors in the cached communicators and removes them from the cache. 3) Use ncclCommAbort in NCCLComm destructor since ncclCommDestroy can block forever waiting for work. 4) Added a simulate_nccl_errors.py script to simulate NCCL errors. https://github.com/pytorch/pytorch/issues/17882 Pull Request resolved: https://github.com/pytorch/pytorch/pull/22907 Test Plan: 1) Run the simulate_nccl_errors.py to verify NCCL errors are caught. Differential Revision: D16958078 fbshipit-source-id: 662b0b8b8ee250e2b6d15bdfc9306d71c4f66219
170 lines
4.8 KiB
Python
170 lines
4.8 KiB
Python
from __future__ import absolute_import, division, print_function, unicode_literals
|
|
|
|
import multiprocessing
|
|
import sys
|
|
import tempfile
|
|
import time
|
|
import unittest
|
|
|
|
from collections import namedtuple
|
|
from functools import wraps
|
|
|
|
import torch
|
|
import torch.distributed as c10d
|
|
|
|
from common_utils import TestCase
|
|
|
|
|
|
# Pairs the exit code a test subprocess terminates with and the
# human-readable reason reported when the test is skipped.
TestSkip = namedtuple('TestSkip', ['exit_code', 'message'])

# Known skip conditions, keyed by name. A subprocess exits with the
# matching exit_code; the parent compares exit codes against this table.
TEST_SKIPS = {
    "multi-gpu": TestSkip(
        exit_code=75,
        message="Need at least 2 CUDA devices",
    ),
    "nccl": TestSkip(
        exit_code=76,
        message="c10d not compiled with NCCL support",
    ),
    "known_issues": TestSkip(
        exit_code=77,
        message="Test skipped due to known issues",
    ),
}
|
|
|
|
|
|
def skip_if_not_multigpu(func):
    """Decorate a multi-GPU test so it only runs with 2 or more CUDA devices.

    When CUDA is unavailable or fewer than 2 devices are visible, the
    wrapper exits the current process with the 'multi-gpu' skip exit code
    instead of calling ``func``.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        has_two_gpus = torch.cuda.is_available() and torch.cuda.device_count() >= 2
        if not has_two_gpus:
            # Signal the skip through the process exit code.
            sys.exit(TEST_SKIPS['multi-gpu'].exit_code)
        return func(*args, **kwargs)

    return wrapper
|
|
|
|
|
|
def skip_if_lt_x_gpu(x):
    """Build a decorator that runs a test only with at least ``x`` CUDA devices.

    When CUDA is unavailable or fewer than ``x`` devices are visible, the
    decorated test exits the current process with the 'multi-gpu' skip
    exit code instead of running.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            enough_gpus = torch.cuda.is_available() and torch.cuda.device_count() >= x
            if not enough_gpus:
                # Signal the skip through the process exit code.
                sys.exit(TEST_SKIPS['multi-gpu'].exit_code)
            return func(*args, **kwargs)
        return wrapper

    return decorator
|
|
|
|
|
|
def skip_for_known_issues(func):
    """Decorate a c10d test that must be skipped because of known issues.

    The wrapped test never calls ``func``; it always exits the current
    process with the 'known_issues' skip exit code.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        known_issues_skip = TEST_SKIPS['known_issues']
        sys.exit(known_issues_skip.exit_code)

    return wrapper
|
|
|
|
|
|
def requires_gloo():
    """Return a unittest decorator that skips unless c10d reports Gloo support."""
    gloo_available = c10d.is_gloo_available()
    skip_reason = "c10d was not compiled with the Gloo backend"
    return unittest.skipUnless(gloo_available, skip_reason)
|
|
|
|
|
|
def requires_nccl():
    """Return a unittest decorator that skips unless c10d reports NCCL support."""
    nccl_available = c10d.is_nccl_available()
    skip_reason = "c10d was not compiled with the NCCL backend"
    return unittest.skipUnless(nccl_available, skip_reason)
|
|
|
|
|
|
def requires_mpi():
    """Return a unittest decorator that skips unless c10d reports MPI support."""
    mpi_available = c10d.is_mpi_available()
    skip_reason = "c10d was not compiled with the MPI backend"
    return unittest.skipUnless(mpi_available, skip_reason)
|
|
|
|
|
|
# Timeout, in seconds, used for a test that has no entry in TIMEOUT_OVERRIDE.
TIMEOUT_DEFAULT = 30
# Per-test timeouts, keyed by the last dotted component of the test id
# (the bare test method name).
TIMEOUT_OVERRIDE = dict()
|
|
|
|
|
|
def get_timeout(test_id):
    """Return the timeout configured for the test identified by ``test_id``.

    Only the text after the last '.' in ``test_id`` is used as the lookup
    key; tests without an override get ``TIMEOUT_DEFAULT``.
    """
    test_name = test_id.rpartition('.')[2]
    if test_name in TIMEOUT_OVERRIDE:
        return TIMEOUT_OVERRIDE[test_name]
    return TIMEOUT_DEFAULT
|
|
|
|
|
|
class MultiProcessTestCase(TestCase):
    """Test case base class that runs each test method in several subprocesses.

    ``setUp`` spawns ``world_size`` subprocesses. Each subprocess sets
    ``self.rank`` to its own rank and re-invokes the current test method.
    The main process (``self.rank == MAIN_PROCESS_RANK``) does not run the
    test body; it joins the subprocesses and checks their exit codes.
    """

    # Rank value that identifies the parent (non-worker) process.
    MAIN_PROCESS_RANK = -1

    @property
    def world_size(self):
        """Number of subprocesses spawned for every test."""
        return 4

    @staticmethod
    def join_or_run(fn):
        """Wrap test function ``fn`` to dispatch on the caller's rank.

        In the main process the wrapper joins the spawned subprocesses;
        in a subprocess it runs ``fn`` itself.
        """
        @wraps(fn)
        def wrapper(self):
            if self.rank == self.MAIN_PROCESS_RANK:
                self._join_processes(fn)
            else:
                fn(self)
        return wrapper

    # The main process spawns N subprocesses that run the test.
    # This function overwrites every test function to either
    # assume the role of the main process and join its subprocesses,
    # or run the underlying test function.
    @classmethod
    def setUpClass(cls):
        for attr in dir(cls):
            if attr.startswith('test'):
                fn = getattr(cls, attr)
                setattr(cls, attr, cls.join_or_run(fn))

    def setUp(self):
        """Mark this instance as the main process and spawn the subprocesses."""
        super(MultiProcessTestCase, self).setUp()
        # Test functions whose subprocess exit codes must NOT be checked.
        self.skip_return_code_checks = []
        self.rank = self.MAIN_PROCESS_RANK
        # delete=False keeps the file on disk after the handle is closed;
        # tearDown is responsible for cleaning it up.
        self.file = tempfile.NamedTemporaryFile(delete=False)
        self.processes = [self._spawn_process(rank) for rank in range(int(self.world_size))]

    def tearDown(self):
        """Terminate the subprocesses and release the temporary file."""
        super(MultiProcessTestCase, self).tearDown()
        for p in self.processes:
            p.terminate()

        # Local import: the module-level imports do not include os.
        import os

        # Release the handle opened in setUp, then remove the file it left
        # on disk (it was created with delete=False).
        self.file.close()
        try:
            os.remove(self.file.name)
        except OSError:
            # Deliberately best-effort: the file may already be gone or
            # may not be removable on this platform.
            pass

    def _spawn_process(self, rank):
        """Start and return a subprocess that runs the current test as ``rank``."""
        name = 'process ' + str(rank)
        process = multiprocessing.Process(target=self._run, name=name, args=(rank,))
        process.start()
        return process

    def _run(self, rank):
        """Subprocess entry point: run the current test as ``rank``, then exit 0."""
        self.rank = rank

        # self.id() == e.g. '__main__.TestDistributed.test_get_rank'
        # We're retrieving a corresponding test and executing it.
        # The method name is the LAST dotted component; a fixed index would
        # pick the wrong component when the module path itself is dotted.
        getattr(self, self.id().split(".")[-1])()
        sys.exit(0)

    def _join_processes(self, fn):
        """Wait for all subprocesses and verify their exit codes.

        Exit codes are checked unless ``fn`` is listed in
        ``self.skip_return_code_checks``.
        """
        timeout = get_timeout(self.id())
        start_time = time.time()
        # Note: the timeout applies to each join separately, so the total
        # wait can be up to len(self.processes) * timeout.
        for p in self.processes:
            p.join(timeout)
        elapsed_time = time.time() - start_time
        # Only tests that explicitly opted out skip the exit-code check.
        if fn not in self.skip_return_code_checks:
            self._check_return_codes(elapsed_time)

    def _check_return_codes(self, elapsed_time):
        """
        Checks that the return codes of all spawned processes match, and skips
        tests if they returned a return code indicating a skipping condition.

        Raises RuntimeError if a process has no exit code yet (it is still
        running after the join), and unittest.SkipTest if the shared exit
        code matches an entry in TEST_SKIPS.
        """
        first_process = self.processes[0]
        for i, p in enumerate(self.processes):
            if p.exitcode is None:
                raise RuntimeError('Process {} terminated or timed out after {} seconds'.format(i, elapsed_time))
            self.assertEqual(p.exitcode, first_process.exitcode)
        for skip in TEST_SKIPS.values():
            if first_process.exitcode == skip.exit_code:
                raise unittest.SkipTest(skip.message)
        self.assertEqual(first_process.exitcode, 0)

    @property
    def is_master(self):
        """True when this process has rank 0."""
        return self.rank == 0
|