pytorch/test/distributed/test_distributed_spawn.py
Rohan Varma b22abbe381 Enable test_distributed to work with spawn mode (#41769)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41769

Currently the tests in `test_distributed` only work with `fork` mode multiprocessing. This PR adds support for `spawn` mode multiprocessing as well, while keeping `fork` mode intact.

Motivations for the change:
1) Spawn multiprocessing is the default on macOS, so it better emulates how macOS users would use distributed
2) With python 3.8+, spawn is the default on linux, so we should have test coverage for this
3) The PyTorch multiprocessing docs suggest using spawn/forkserver over fork for sharing CUDA tensors: https://pytorch.org/docs/stable/multiprocessing.html
4) Spawn is better supported with respect to certain sanitizers such as TSAN, so adding this sanitizer coverage may help us uncover issues.

How it is done:
1) Move `test_distributed` tests in `_DistTestBase` class to a shared file `distributed_test` (similar to how the RPC tests are structured)
2) For `Barrier`, refactor the setup of temp directories. The current version did not work with spawn: each process would get a different randomly generated directory and thus would write to a different barrier.
3) Add all the relevant builds to run internally and in OSS.
Running test_distributed with spawn mode in OSS can be done with:
`python test/run_test.py -i distributed/test_distributed_spawn -v`
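
The temp directory problem in step 2, and one way around it, can be sketched as follows. This is an illustration under stated assumptions, not the code from this PR: the environment variable name `SKETCH_BARRIER_DIR` and the helper `run_workers` are hypothetical, and `subprocess` stands in for spawn mode because both start a fresh interpreter that shares no module state with the parent. The parent creates the directory once and passes its path through the environment, so every worker writes to the same place.

```python
import os
import subprocess
import sys
import tempfile

# Hypothetical variable name used only in this sketch.
BARRIER_DIR_ENV = "SKETCH_BARRIER_DIR"

# Source run by each worker in a fresh interpreter. Like a spawned process,
# it shares no module state with the parent, only the inherited environment.
WORKER_SOURCE = f"""
import os, sys
rank = sys.argv[1]
barrier_dir = os.environ["{BARRIER_DIR_ENV}"]
with open(os.path.join(barrier_dir, rank), "w") as f:
    f.write("arrived")
"""


def run_workers(world_size):
    """Start world_size workers and return the ranks that reached the barrier."""
    # Create the directory once, in the parent, before any worker starts.
    barrier_dir = tempfile.mkdtemp(prefix="barrier_")
    env = dict(os.environ)
    env[BARRIER_DIR_ENV] = barrier_dir
    procs = [
        subprocess.Popen([sys.executable, "-c", WORKER_SOURCE, str(rank)], env=env)
        for rank in range(world_size)
    ]
    for proc in procs:
        proc.wait()
    # Every worker wrote into the same directory, so all ranks are visible here.
    return sorted(os.listdir(barrier_dir))
```

If each worker instead called `tempfile.mkdtemp()` itself at import time, as would happen when module-level setup is re-executed in a spawned process, every rank would write into its own directory and no rank would ever see the others arrive.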

Reviewed By: izdeby

Differential Revision: D22408023

fbshipit-source-id: e206be16961fd80438f995e221f18139d7e6d2a9
2020-09-08 23:11:12 -07:00

36 lines
992 B
Python

from __future__ import absolute_import, division, print_function, unicode_literals

import os
import sys
import unittest

import torch.distributed as dist
from torch.testing._internal.common_utils import run_tests, TEST_WITH_ASAN, NO_MULTIPROCESSING_SPAWN
from torch.testing._internal.distributed.distributed_test import (
    DistributedTest, TestDistBackend
)

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

BACKEND = os.environ["BACKEND"]

if BACKEND == "gloo" or BACKEND == "nccl":

    @unittest.skipIf(
        TEST_WITH_ASAN, "Skip ASAN as torch + multiprocessing spawn have known issues"
    )
    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN, "Spawn not available, skipping tests."
    )
    class TestDistBackendWithSpawn(TestDistBackend, DistributedTest._DistTestBase):

        def setUp(self):
            super().setUp()
            self._spawn_processes()


if __name__ == "__main__":
    run_tests()