# Owner(s): ["oncall: distributed"]

import os
import sys

import torch
import torch.distributed as dist

# Disable TF32 for CUDA matmuls so the shared tests run at full FP32 precision.
torch.backends.cuda.matmul.allow_tf32 = False

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_utils import (
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
    NO_MULTIPROCESSING_SPAWN,
)
from torch.testing._internal.distributed.distributed_test import (
    DistributedTest,
    TestDistBackend,
)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

if NO_MULTIPROCESSING_SPAWN:
    print("Spawn not available, skipping tests.", file=sys.stderr)
    sys.exit(0)

BACKEND = os.environ["BACKEND"]

if BACKEND == "gloo" or BACKEND == "nccl":

    class TestDistBackendWithSpawn(TestDistBackend, DistributedTest._DistTestBase):
        # Runs the shared distributed test suite with worker processes started
        # via multiprocessing spawn rather than fork.
        def setUp(self):
            super().setUp()
            self._spawn_processes()
            # Disable cuDNN TF32 for the lifetime of each spawned test process.
            torch.backends.cudnn.flags(allow_tf32=False).__enter__()


if __name__ == "__main__":
    run_tests()
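
# Usage sketch (an assumption, not part of this file): the module reads the
# BACKEND environment variable directly, and the spawned TestDistBackend
# harness typically expects additional variables (for example WORLD_SIZE) to
# be provided by the test runner. A hypothetical manual invocation might look
# like:
#
#   BACKEND=gloo WORLD_SIZE=2 python test_distributed_spawn.py
#
# The exact variable set depends on the PyTorch test infrastructure in use.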