import contextlib
import os
import sys
from typing import Any, Optional

import torch
import torch.distributed as dist

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    require_n_gpus_for_nccl_backend,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN

if TEST_WITH_DEV_DBG_ASAN:
    print("Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr)
    sys.exit(0)
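
# Use NCCL when GPUs are available and Gloo otherwise; the world size is
# clamped to between 2 and 4 ranks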
BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
WORLD_SIZE = min(4, max(2, torch.cuda.device_count()))

# Constants used for testing post-hooks
BEFORE_CONSTANT = 41
AFTER_CONSTANT = 42


class AllReducerJoinHook(JoinHook):
    r"""
    Join hook for :class:`AllReducer`.

    Arguments:
        allreducer (AllReducer): the :class:`AllReducer` object using this
            hook.
        num_allreduces (int): the number of all-reduces to shadow per
            iteration.
        run_post_hook (bool): a flag enabling the post-hook logic.
    """
    def __init__(
        self,
        allreducer,
        num_allreduces,
        run_post_hook
    ):
        self.allreducer = allreducer
        self.num_allreduces = num_allreduces
        self.run_post_hook = run_post_hook

    def main_hook(self):
        r"""
        Shadows each all-reduce; the number of all-reduces is passed into the
        constructor as ``num_allreduces``.
        """
        device = self.allreducer.device
        for _ in range(self.num_allreduces):
            t = torch.zeros(1, device=device)
            dist.all_reduce(t)

    def post_hook(self, is_last_joiner: bool):
        r"""
        Broadcasts a tensor containing a magic constant ``AFTER_CONSTANT`` from
        the last joiner to all other processes.
        """
        if not self.run_post_hook:
            return
        rank = dist.get_rank(self.allreducer.process_group)
        common_rank = self.allreducer.find_common_rank(rank, is_last_joiner)
        device = self.allreducer.device
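        # The agreed-upon rank seeds `post_hook_tensor` with `AFTER_CONSTANT`;
        # the broadcast then propagates it to every other rank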
        if rank == common_rank:
            self.allreducer.post_hook_tensor = torch.tensor([AFTER_CONSTANT], device=device)
        dist.broadcast(self.allreducer.post_hook_tensor, src=common_rank)


class AllReducer(Joinable):
    r"""
    Example :class:`Joinable` that performs some number of all-reduces as its
    per-iteration collective communication.
    """
    def __init__(self, device, process_group):
        super(AllReducer, self).__init__()
        self.device = device
        self.process_group = process_group
        self.post_hook_tensor = torch.tensor([BEFORE_CONSTANT], device=self.device)

    def __call__(self, num_allreduces=1):
        r"""
        All-reduces a dim-1 tensor of ones ``num_allreduces``-many times and
        returns the total result.
        """
        Join.notify_join_context(self)
        device = self.device
        total = 0
        for _ in range(num_allreduces):
            t = torch.ones(1, device=device)
            dist.all_reduce(t)
            total += t.item()
        return total

    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Returns a join hook that shadows some number of all-reduces; by default,
        this number is 1.
        """
        num_allreduces = kwargs.get("num_allreduces", 1)
        run_post_hook = kwargs.get("run_post_hooks", False)
        return AllReducerJoinHook(
            self,
            num_allreduces,
            run_post_hook
        )

    @property
    def join_device(self) -> torch.device:
        return self.device

    @property
    def join_process_group(self) -> Any:
        return self.process_group

    def find_common_rank(self, rank, to_consider):
        r"""
        Returns the max rank of the ones to consider over the process group.
        """
        common_rank = torch.tensor(
            [rank if to_consider else -1],
            device=self.device
        )
        dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
        common_rank = common_rank.item()
        assert common_rank >= 0
        return common_rank


class TestJoin(MultiProcessTestCase):
    r"""Test cases for the generic join context."""
    def setUp(self):
        super(TestJoin, self).setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = BACKEND
        self._spawn_processes()

    @property
    def device(self):
        return torch.device(self.rank) if BACKEND == dist.Backend.NCCL \
            else torch.device("cpu")

    @property
    def world_size(self):
        return WORLD_SIZE

    @property
    def process_group(self):
        return dist.group.WORLD

    def tearDown(self):
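        # The process group may never have been initialized (e.g. if a test
        # exits early), so ignore failures during teardown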
        try:
            dist.destroy_process_group()
        except AssertionError:
            pass
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def dist_init(self, rank, world_size, backend=BACKEND):
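        # Rendezvous over a file-based store shared by all spawned processes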
        store = dist.FileStore(self.file_name, world_size)
        return dist.init_process_group(
            backend=backend,
            store=store,
            rank=rank,
            world_size=world_size
        )

    def construct_uneven_inputs(self, base, offset, device=None):
        r"""
        Returns uneven inputs: rank i gets ``base`` + i * ``offset`` inputs.
        """
        if device is None:
            device = self.device
        return [torch.zeros(1, device=device) for _ in range(base + self.rank * offset)]

    def construct_even_inputs(self, base, device=None):
        r"""Returns even inputs: each rank gets ``base`` inputs."""
        if device is None:
            device = self.device
        return [torch.zeros(1, device=device) for _ in range(base)]

    @property
    def base_num_inputs(self):
        r"""Base number of inputs to be used by all ranks."""
        return 3

    @property
    def offset(self):
        r"""Rank i gets i * ``offset`` additional inputs."""
        return 1

    def _test_join_base(
        self,
        uneven_inputs: bool,
        num_joinables: int,
        enable: bool,
        throw_on_early_termination: bool,
        num_allreduces: int,
        run_post_hooks: bool,
        expected_total: Optional[int] = None,
    ):
        r"""
        Skeleton for all :class:`Join` tests.

        Arguments:
            uneven_inputs (bool): ``True`` to use uneven inputs; ``False``
                otherwise.
            num_joinables (int): number of :class:`AllReducer` s to construct.
            enable (bool): ``True`` to enable the join context manager;
                ``False`` otherwise.
            throw_on_early_termination (bool): ``True`` to raise an exception
                upon detecting uneven inputs; ``False`` otherwise.
            num_allreduces (int): number of all-reduces to perform per input.
            run_post_hooks (bool): ``True`` to run post-hooks; ``False``
                otherwise.
            expected_total (Optional[int]): ``None`` to not check the expected
                all-reduce total; otherwise, the expected total; default is
                ``None``.
        """
        self.dist_init(self.rank, self.world_size)

        allreducers = [
            AllReducer(self.device, self.process_group)
            for _ in range(num_joinables)
        ]
        for allreducer in allreducers:
            self.assertEqual(allreducer.post_hook_tensor.item(), BEFORE_CONSTANT)

        inputs = self.construct_uneven_inputs(self.base_num_inputs, self.offset) \
            if uneven_inputs \
            else self.construct_even_inputs(self.base_num_inputs)
        allreduce_total = 0

        # Expect a `RuntimeError` if `throw_on_early_termination=True`
        # Rank 0 exhausts its inputs first
        expected_msg = "Rank 0 exhausted all inputs." if self.rank == 0 \
            else "Detected at least one rank that exhausted inputs. " \
            "Throwing across all ranks."
        with self.assertRaisesRegex(
            RuntimeError,
            expected_msg
        ) if throw_on_early_termination else contextlib.suppress():
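            # Keyword arguments passed to `Join` are forwarded to each
            # joinable's `join_hook()` (here, `num_allreduces` and
            # `run_post_hooks`)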
            with Join(
                allreducers,
                enable=enable,
                throw_on_early_termination=throw_on_early_termination,
                num_allreduces=num_allreduces,
                run_post_hooks=run_post_hooks
            ):
                for _ in inputs:
                    for allreducer in allreducers:
                        allreduce_total += allreducer(num_allreduces)

        if throw_on_early_termination:
            return

        # Check `expected_total` if not `None`
        if expected_total is not None:
            self.assertEqual(allreduce_total, expected_total)

        # All `AllReducer` instances should receive the updated
        # `post_hook_tensor` from the last-joined process
        if run_post_hooks:
            for allreducer in allreducers:
                self.assertEqual(allreducer.post_hook_tensor.item(), AFTER_CONSTANT)

    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_single_joinable_main_hooks(self):
        r"""Tests the main hooks of a single :class:`Joinable`."""
        num_joinables = 1
        num_allreduces = 1
        run_post_hooks = False
        # Non-joined processes all-reduce a 1, so this rank's all-reduce total
        # should be precisely equal to the total number of inputs processed
        # before it joined
        expected_total = self.world_size * self.base_num_inputs
        # Rank i runs for i additional iterations
        for num_joined in range(1, self.rank + 1):
            expected_total += (self.world_size - num_joined) * self.offset

        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total
        )

    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_single_joinable_post_hooks(self):
        r"""Tests the post-hooks of a single :class:`Joinable`."""
        num_joinables = 1
        num_allreduces = 0  # set to 0 to skip the main hooks
        run_post_hooks = True

        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=None
        )

    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_single_joinable(self):
        r"""
        Tests the main hooks and post-hooks of a single :class:`Joinable`
        together.

        This combines ``test_single_joinable_main_hooks()`` and
        ``test_single_joinable_post_hooks()`` into a single test to ensure that
        main hooks and post-hooks operate correctly together.
        """
        num_joinables = 1
        num_allreduces = 1
        run_post_hooks = True
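
        # Same expected total as in `test_single_joinable_main_hooks()`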
        expected_total = self.world_size * self.base_num_inputs
        for num_joined in range(1, self.rank + 1):
            expected_total += (self.world_size - num_joined) * self.offset

        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total
        )

    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_multiple_joinables(self):
        r"""
        Tests the main hooks and post-hooks of multiple :class:`Joinable` s
        together.

        This generalizes ``test_single_joinable()`` to multiple
        :class:`Joinable` s.
        """
        num_joinables = 3
        num_allreduces = 1
        run_post_hooks = True

        expected_total = self.world_size * self.base_num_inputs
        for num_joined in range(1, self.rank + 1):
            expected_total += (self.world_size - num_joined) * self.offset
        # The expected total is now multiplied by a factor of `num_joinables`
        expected_total *= num_joinables

        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total
        )

    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_single_joinable_disable(self):
        r"""Tests ``enable=False`` for a single :class:`Joinable`."""
        num_joinables = 1
        num_allreduces = 1
        uneven_inputs = False
        enable = False
        run_post_hooks = False
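
        # Each rank processes `base_num_inputs` inputs, and each all-reduce of
        # ones sums to `world_size`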
        expected_total = self.world_size * self.base_num_inputs

        self._test_join_base(
            uneven_inputs=uneven_inputs,
            num_joinables=num_joinables,
            enable=enable,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total
        )

    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_multiple_joinable_disable(self):
        r"""
        Tests ``enable=False`` for multiple :class:`Joinable` s.

        This generalizes ``test_single_joinable_disable`` to multiple
        :class:`Joinable` s.
        """
        num_joinables = 3
        num_allreduces = 1
        uneven_inputs = False
        enable = False
        run_post_hooks = False

        expected_total = self.world_size * self.base_num_inputs * num_joinables

        self._test_join_base(
            uneven_inputs=uneven_inputs,
            num_joinables=num_joinables,
            enable=enable,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total
        )

    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_single_joinable_throw(self):
        r"""
        Tests ``throw_on_early_termination=True`` for a single
        :class:`Joinable`.
        """
        num_joinables = 1
        num_allreduces = 1
        throw_on_early_termination = True
        run_post_hooks = False

        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=throw_on_early_termination,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=None
        )

    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_multiple_joinables_throw(self):
        r"""
        Tests ``throw_on_early_termination=True`` for multiple
        :class:`Joinable` s together.

        This generalizes ``test_single_joinable_throw`` to multiple
        :class:`Joinable` s.
        """
        num_joinables = 3
        num_allreduces = 1
        throw_on_early_termination = True
        run_post_hooks = False

        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=throw_on_early_termination,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=None
        )

    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_join_kwargs(self):
        r"""
        Tests passing keyword arguments to the context manager.
        """
        num_joinables = 1
        num_allreduces = 2
        run_post_hooks = False

        expected_total = self.world_size * self.base_num_inputs
        for num_joined in range(1, self.rank + 1):
            expected_total += (self.world_size - num_joined) * self.offset
        # The expected total is now multiplied by a factor of `num_allreduces`
        expected_total *= num_allreduces

        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total
        )


if __name__ == "__main__":
    run_tests()