pytorch/test/distributed/algorithms/test_join.py
Pritam Damania f7611b31aa [4/N] Enable opt-asan for distributed unit tests. (#62051)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62051

The goal here is to enable opt-asan for "spawn" based unit tests since
this works for "spawn" unlike "dev-asan". As a result, we can run ASAN for
"spawn" unit tests as well.

This means we can completely remove fork unit tests from the code base since
the only purpose for these tests was to run ASAN.
ghstack-source-id: 135523770

Test Plan: waitforbuildbot

Reviewed By: SciPioneer

Differential Revision: D29854514

fbshipit-source-id: 02a5bfcfae2afc21badecff77082c7a6ad83636b
2021-08-10 22:38:31 -07:00


import contextlib
import os
import sys
from typing import Any, Optional
import torch
import torch.distributed as dist
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    require_n_gpus_for_nccl_backend,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
if TEST_WITH_DEV_DBG_ASAN:
print("Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr)
sys.exit(0)
BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
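# Use between 2 and 4 processes, scaled by the number of visible CUDA devices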
WORLD_SIZE = min(4, max(2, torch.cuda.device_count()))
# Constants used for testing post-hooks
BEFORE_CONSTANT = 41
AFTER_CONSTANT = 42
class AllReducerJoinHook(JoinHook):
r"""
Join hook for :class:`AllReducer`.
Arguments:
allreducer (AllReducer): the :class:`AllReducer` object using this
hook.
num_allreduces (int): the number of all-reduces to shadow per
iteration.
run_post_hook (bool): a flag enabling the post-hook logic.
"""
def __init__(
self,
allreducer,
num_allreduces,
run_post_hook
):
self.allreducer = allreducer
self.num_allreduces = num_allreduces
self.run_post_hook = run_post_hook
def main_hook(self):
r"""
Shadows each all-reduce; the number of all-reduces is passed into the
constructor as ``num_allreduces``.
"""
device = self.allreducer.device
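        # All-reduce a zero tensor: the joined rank still participates in the
        # collective, but contributes nothing to the sum seen by active ranks.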
        for _ in range(self.num_allreduces):
            t = torch.zeros(1, device=device)
            dist.all_reduce(t)
    def post_hook(self, is_last_joiner: bool):
        r"""
        Broadcasts a tensor containing a magic constant ``AFTER_CONSTANT`` from
        the last joiner to all other processes.
        """
        if not self.run_post_hook:
            return
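        # Agree on a single source rank (one of the last processes to join) and
        # have it broadcast ``AFTER_CONSTANT`` to every process in the group.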
        rank = dist.get_rank(self.allreducer.process_group)
        common_rank = self.allreducer.find_common_rank(rank, is_last_joiner)
        device = self.allreducer.device
        if rank == common_rank:
            self.allreducer.post_hook_tensor = torch.tensor([AFTER_CONSTANT], device=device)
        dist.broadcast(self.allreducer.post_hook_tensor, src=common_rank)
class AllReducer(Joinable):
r"""
Example :class:`Joinable` that performs some number of all-reduces as its
per-iteration collective communication.
"""
def __init__(self, device, process_group):
super(AllReducer, self).__init__()
self.device = device
self.process_group = process_group
self.post_hook_tensor = torch.tensor([BEFORE_CONSTANT], device=self.device)
def __call__(self, num_allreduces=1):
r"""
All-reduces a dim-1 one tensor ``num_allreduces``-many times, and
returns the total result.
"""
        Join.notify_join_context(self)
        device = self.device
        total = 0
        for _ in range(num_allreduces):
            t = torch.ones(1, device=device)
            dist.all_reduce(t)
            total += t.item()
        return total
    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Returns a join hook that shadows some number of all-reduces; by default,
        this number is 1.
        """
        num_allreduces = kwargs.get("num_allreduces", 1)
        run_post_hook = kwargs.get("run_post_hooks", False)
        return AllReducerJoinHook(
            self,
            num_allreduces,
            run_post_hook
        )
    @property
    def join_device(self) -> torch.device:
        return self.device
    @property
    def join_process_group(self) -> Any:
        return self.process_group
    def find_common_rank(self, rank, to_consider):
        r"""
        Returns the max rank of the ones to consider over the process group.
        """
        common_rank = torch.tensor(
            [rank if to_consider else -1],
            device=self.device
        )
        dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
        common_rank = common_rank.item()
        assert common_rank >= 0
        return common_rank
class TestJoin(MultiProcessTestCase):
r"""Test cases for the generic join context."""
def setUp(self):
super(TestJoin, self).setUp()
os.environ["WORLD_SIZE"] = str(self.world_size)
os.environ["BACKEND"] = BACKEND
self._spawn_processes()
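    # With the NCCL backend each rank uses the GPU whose index matches its
    # rank; with the Gloo backend everything runs on the CPU.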
    @property
    def device(self):
        return torch.device(self.rank) if BACKEND == dist.Backend.NCCL \
            else torch.device("cpu")
    @property
    def world_size(self):
        return WORLD_SIZE
    @property
    def process_group(self):
        return dist.group.WORLD
    def tearDown(self):
        try:
            dist.destroy_process_group()
        except AssertionError:
            pass
        try:
            os.remove(self.file_name)
        except OSError:
            pass
    def dist_init(self, rank, world_size, backend=BACKEND):
        store = dist.FileStore(self.file_name, world_size)
        return dist.init_process_group(
            backend=backend,
            store=store,
            rank=rank,
            world_size=world_size
        )
    def construct_uneven_inputs(self, base, offset, device=None):
        r"""
        Returns uneven inputs: rank i gets ``base`` + i * ``offset`` inputs.
        """
        if device is None:
            device = self.device
        return [torch.zeros(1, device=device) for _ in range(base + self.rank * offset)]
    def construct_even_inputs(self, base, device=None):
        r"""Returns even inputs: each rank gets ``base`` inputs."""
        if device is None:
            device = self.device
        return [torch.zeros(1, device=device) for _ in range(base)]
    @property
    def base_num_inputs(self):
        r"""Base number of inputs to be used by all ranks."""
        return 3
    @property
    def offset(self):
        r"""Rank i gets i * ``offset`` additional inputs."""
        return 1
    def _test_join_base(
        self,
        uneven_inputs: bool,
        num_joinables: int,
        enable: bool,
        throw_on_early_termination: bool,
        num_allreduces: int,
        run_post_hooks: bool,
        expected_total: Optional[int] = None,
    ):
        r"""
        Skeleton for all :class:`Join` tests.
        Arguments:
            uneven_inputs (bool): ``True`` to use uneven inputs; ``False``
                otherwise.
            num_joinables (int): number of :class:`AllReducer` s to construct.
            enable (bool): ``True`` to enable the join context manager;
                ``False`` otherwise.
            throw_on_early_termination (bool): ``True`` to raise an exception
                upon detecting uneven inputs; ``False`` otherwise.
            num_allreduces (int): number of all-reduces to perform per input.
            run_post_hooks (bool): ``True`` to run post-hooks; ``False``
                otherwise.
            expected_total (Optional[int]): ``None`` to not check the expected
                all-reduce total; otherwise, the expected total; default is
                ``None``.
        """
        self.dist_init(self.rank, self.world_size)
        allreducers = [
            AllReducer(self.device, self.process_group)
            for _ in range(num_joinables)
        ]
        for allreducer in allreducers:
            self.assertEqual(allreducer.post_hook_tensor.item(), BEFORE_CONSTANT)
        inputs = self.construct_uneven_inputs(self.base_num_inputs, self.offset) \
            if uneven_inputs \
            else self.construct_even_inputs(self.base_num_inputs)
        allreduce_total = 0
        # Expect a `RuntimeError` if `throw_on_early_termination=True`
        # Rank 0 exhausts its inputs first
        expected_msg = "Rank 0 exhausted all inputs." if self.rank == 0 \
            else "Detected at least one rank that exhausted inputs. " \
                 "Throwing across all ranks."
        with self.assertRaisesRegex(
            RuntimeError,
            expected_msg
        ) if throw_on_early_termination else contextlib.suppress():
            with Join(
                allreducers,
                enable=enable,
                throw_on_early_termination=throw_on_early_termination,
                num_allreduces=num_allreduces,
                run_post_hooks=run_post_hooks
            ):
                for _ in inputs:
                    for allreducer in allreducers:
                        allreduce_total += allreducer(num_allreduces)
        if throw_on_early_termination:
            return
        # Check `expected_total` if it is not `None`
        if expected_total is not None:
            self.assertEqual(allreduce_total, expected_total)
        # All `AllReducer` instances should receive the updated
        # `post_hook_tensor` from the last-joined process
        if run_post_hooks:
            for allreducer in allreducers:
                self.assertEqual(allreducer.post_hook_tensor.item(), AFTER_CONSTANT)
    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_single_joinable_main_hooks(self):
        r"""Tests the main hooks of a single :class:`Joinable`."""
        num_joinables = 1
        num_allreduces = 1
        run_post_hooks = False
        # Non-joined processes all-reduce a 1, so this rank's all-reduce total
        # should be precisely equal to the total number of inputs processed
        # before it joined
        expected_total = self.world_size * self.base_num_inputs
        # Rank i runs for i additional iterations
        for num_joined in range(1, self.rank + 1):
            expected_total += (self.world_size - num_joined) * self.offset
        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total
        )
    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_single_joinable_post_hooks(self):
        r"""Tests the post-hooks of a single :class:`Joinable`."""
        num_joinables = 1
        num_allreduces = 0  # set to 0 to skip the main hooks
        run_post_hooks = True
        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=None
        )
    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_single_joinable(self):
        r"""
        Tests the main hooks and post-hooks of a single :class:`Joinable`
        together.
        This combines ``test_single_joinable_main_hooks()`` and
        ``test_single_joinable_post_hooks()`` into a single test to ensure that
        main hooks and post-hooks operate correctly together.
        """
        num_joinables = 1
        num_allreduces = 1
        run_post_hooks = True
        expected_total = self.world_size * self.base_num_inputs
        for num_joined in range(1, self.rank + 1):
            expected_total += (self.world_size - num_joined) * self.offset
        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total
        )
    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_multiple_joinables(self):
        r"""
        Tests the main hooks and post-hooks of multiple :class:`Joinable` s
        together.
        This generalizes ``test_single_joinable()`` to multiple
        :class:`Joinable` s.
        """
        num_joinables = 3
        num_allreduces = 1
        run_post_hooks = True
        expected_total = self.world_size * self.base_num_inputs
        for num_joined in range(1, self.rank + 1):
            expected_total += (self.world_size - num_joined) * self.offset
        # The expected total is now multiplied by a factor of `NUM_JOINABLES`
        expected_total *= num_joinables
        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total
        )
    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_single_joinable_disable(self):
        r"""Tests ``enable=False`` for a single :class:`Joinable`."""
        num_joinables = 1
        num_allreduces = 1
        uneven_inputs = False
        enable = False
        run_post_hooks = False
        expected_total = self.world_size * self.base_num_inputs
        self._test_join_base(
            uneven_inputs=uneven_inputs,
            num_joinables=num_joinables,
            enable=enable,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total
        )
    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_multiple_joinable_disable(self):
        r"""
        Tests ``enable=False`` for multiple :class:`Joinable` s.
        This generalizes ``test_single_joinable_disable`` to multiple
        :class:`Joinable` s.
        """
        num_joinables = 3
        num_allreduces = 1
        uneven_inputs = False
        enable = False
        run_post_hooks = False
        expected_total = self.world_size * self.base_num_inputs * num_joinables
        self._test_join_base(
            uneven_inputs=uneven_inputs,
            num_joinables=num_joinables,
            enable=enable,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total
        )
    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_single_joinable_throw(self):
        r"""
        Tests ``throw_on_early_termination=True`` for a single
        :class:`Joinable`.
        """
        num_joinables = 1
        num_allreduces = 1
        throw_on_early_termination = True
        run_post_hooks = False
        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=throw_on_early_termination,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=None
        )
    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_multiple_joinables_throw(self):
        r"""
        Tests ``throw_on_early_termination=True`` for multiple
        :class:`Joinable` s together.
        This generalizes ``test_single_joinable_throw`` to multiple
        :class:`Joinable` s.
        """
        num_joinables = 3
        num_allreduces = 1
        throw_on_early_termination = True
        run_post_hooks = False
        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=throw_on_early_termination,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=None
        )
    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_join_kwargs(self):
        r"""
        Tests passing keyword arguments to the context manager.
        """
        num_joinables = 1
        num_allreduces = 2
        run_post_hooks = False
        expected_total = self.world_size * self.base_num_inputs
        for num_joined in range(1, self.rank + 1):
            expected_total += (self.world_size - num_joined) * self.offset
        # The expected total is now multiplied by a factor of `NUM_ALLREDUCES`
        expected_total *= num_allreduces
        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total
        )
if __name__ == "__main__":
    run_tests()