# Owner(s): ["oncall: distributed"]

import os
import sys

import torch
import torch.distributed as dist

# Disable TF32 for CUDA matmuls so the tests run with full FP32 precision.
torch.backends.cuda.matmul.allow_tf32 = False

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN, NO_MULTIPROCESSING_SPAWN
from torch.testing._internal.distributed.distributed_test import (
    DistributedTest, TestDistBackend
)

if TEST_WITH_DEV_DBG_ASAN:
    print("Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr)
    sys.exit(0)

if NO_MULTIPROCESSING_SPAWN:
    print("Spawn not available, skipping tests.", file=sys.stderr)
    sys.exit(0)

BACKEND = os.environ["BACKEND"]

if BACKEND == "gloo" or BACKEND == "nccl" or BACKEND == "ucc":

    class TestDistBackendWithSpawn(TestDistBackend, DistributedTest._DistTestBase):

        def setUp(self):
            super().setUp()
            self._spawn_processes()
            # Also disable TF32 for cuDNN; the flags context manager is entered
            # here and left open for the remainder of the test process.
            torch.backends.cudnn.flags(enabled=True, allow_tf32=False).__enter__()

else:
    print(f"Invalid backend {BACKEND}. Tests will not be run!")


if __name__ == "__main__":
    run_tests()
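
# Usage sketch (an illustrative addition, not part of the upstream file): this
# module reads the BACKEND environment variable directly, so it is typically
# invoked with BACKEND set to one of the supported backends, for example:
#
#   BACKEND=gloo python this_test_file.py   # filename is a placeholder
#
# Other environment variables (e.g. WORLD_SIZE) may be consumed by
# TestDistBackend and DistributedTest themselves; see
# torch.testing._internal.distributed.distributed_test for the exact set.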