mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Generalization of distributed test cases for non-CUDA devices (#138216)
# Motivation This pr is an extension of #131758. As described in #131758, these changes are looking to make distributed UTs more accessible to users of all device types. It is a demonstration of a few changes discussed by @kwen2501 and @jgong5 in the discussion for #131758(https://github.com/pytorch/pytorch/pull/131758#discussion_r1762422784) This PR contains two types of changes, the first is to the common distributed folder where we have added a new class derived from MultiProcessTestCase which helps abstracts out the process group creation /deletion and other functionality for a given device. The new generalized content can be added by deriving from this base class. Also includes other misc changes for gaudi support The second changed file is test_functional_api. a test file in common distributed. This file is a POC for how we can use this new class to write more device agnostic distributed test cases. The following changes have been made to test_functional_api.py: -Functionality has been added to test for non cuda devices using intel HPU as an example -Multiple set up steps previously required by MultiProcessTestCase have been abstracted out -Misc adaptations to allow for general call to accelerators while adding test skips instead explicitly skipping for multiple GPUs -Skipifhpu flags have been added to enable skipping a few Multithreaded test cases which are as yet not supported on HPUs NOTE: Within test functional api, there are tests which require the use of some multithreading functions which are as yet not supported on HPUs. These have been skipped for hpu using skipHPU decorator. I will be raising a separate PR to improve usability pf said decorators in a device agnostic setting in the manner suggested by @kwen2501 in a comment on this PR. This pr is a cleaned up version of a previous PR(#136988) which I closed due to human error. I have addressed some of the comments made by @kwen2501 in this as well Pull Request resolved: https://github.com/pytorch/pytorch/pull/138216 Approved by: https://github.com/kwen2501, https://github.com/guangyey
This commit is contained in:
parent
06dde8c157
commit
b379a28a95
|
|
@ -1,6 +1,5 @@
|
|||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from functools import partial, wraps
|
||||
|
|
@ -13,6 +12,7 @@ import torch.distributed.distributed_c10d as c10d
|
|||
from functorch import make_fx
|
||||
from torch._inductor.utils import run_and_get_code
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.distributed.fake_pg import FakeStore
|
||||
from torch.testing._internal.inductor_utils import HAS_GPU
|
||||
|
||||
|
|
@ -22,7 +22,7 @@ if not dist.is_available():
|
|||
sys.exit(0)
|
||||
|
||||
from torch.testing._internal.common_distributed import (
|
||||
MultiProcessTestCase,
|
||||
DistributedTestBase,
|
||||
MultiThreadedTestCase,
|
||||
requires_nccl,
|
||||
TEST_SKIPS,
|
||||
|
|
@ -31,10 +31,43 @@ from torch.testing._internal.common_utils import (
|
|||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
skipIfHpu,
|
||||
TEST_CUDA,
|
||||
TEST_HPU,
|
||||
TestCase,
|
||||
)
|
||||
|
||||
|
||||
# NOTE: Instructions for adding new device types to this test file
|
||||
#
|
||||
# This test file contains two types of tests:
|
||||
# 1. Tests that run on both CPUs and accelerators
|
||||
# 2. Tests that run only on accelerators
|
||||
#
|
||||
# We use two variables to manage device types:
|
||||
# - `devices`: A list containing device types for both CPU and accelerator tests
|
||||
# - `DEVICE`: A string containing only the accelerator type for accelerator-only tests
|
||||
#
|
||||
# To add a new device type:
|
||||
# 1. Add a new `elif` statement in the if-else ladder below
|
||||
# 2. Check for the presence of your device (e.g., TEST_NEW_DEVICE)
|
||||
# 3. Append your device type to the `devices` list
|
||||
# 4. Assign your device type string to `DEVICE`
|
||||
#
|
||||
# Example:
|
||||
# elif TEST_NEW_DEVICE:
|
||||
# devices.append("new_device")
|
||||
# DEVICE = "new_device"
|
||||
|
||||
DEVICE = "cuda"
|
||||
devices = ["cpu"]
|
||||
if TEST_HPU:
|
||||
devices.append("hpu")
|
||||
DEVICE = "hpu"
|
||||
elif TEST_CUDA:
|
||||
devices.append("cuda")
|
||||
|
||||
|
||||
def new_subgroups(group_size: int, pg_tag=None):
|
||||
world_size = dist.get_world_size()
|
||||
subgroups = []
|
||||
|
|
@ -57,6 +90,7 @@ def new_subgroups(group_size: int, pg_tag=None):
|
|||
return cur_subgroup, subgroups
|
||||
|
||||
|
||||
@skipIfHpu
|
||||
class TestExpand(MultiThreadedTestCase):
|
||||
@property
|
||||
def world_size(self):
|
||||
|
|
@ -146,6 +180,7 @@ class TestExpand(MultiThreadedTestCase):
|
|||
self.assertEqual(2, group_size)
|
||||
|
||||
|
||||
@skipIfHpu
|
||||
class TestPgTag(MultiThreadedTestCase):
|
||||
@property
|
||||
def world_size(self):
|
||||
|
|
@ -222,6 +257,7 @@ class TestPgTag(MultiThreadedTestCase):
|
|||
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
@skipIfHpu
|
||||
class TestTraceableCollectives(MultiThreadedTestCase):
|
||||
@property
|
||||
def world_size(self):
|
||||
|
|
@ -231,7 +267,7 @@ class TestTraceableCollectives(MultiThreadedTestCase):
|
|||
super().setUp()
|
||||
self._spawn_threads()
|
||||
|
||||
@parametrize("device", ["cpu", "cuda"])
|
||||
@parametrize("device", devices)
|
||||
def test_broadcast(self, device):
|
||||
if device == "cuda":
|
||||
if torch.cuda.device_count() < self.world_size:
|
||||
|
|
@ -247,7 +283,7 @@ class TestTraceableCollectives(MultiThreadedTestCase):
|
|||
res = ft_c.broadcast(tensor, 0, mesh)
|
||||
self.assertEqual(res, torch.ones([4], device=device))
|
||||
|
||||
@parametrize("device", ["cpu", "cuda"])
|
||||
@parametrize("device", devices)
|
||||
def test_all_reduce_eager(self, device):
|
||||
if device == "cuda":
|
||||
if torch.cuda.device_count() < self.world_size:
|
||||
|
|
@ -264,7 +300,7 @@ class TestTraceableCollectives(MultiThreadedTestCase):
|
|||
res2 = ft_c.all_reduce(tensor, "sum", (mesh, 1))
|
||||
self.assertEqual(res2, torch.tensor([2, 2, 2, 2], dtype=torch.float))
|
||||
|
||||
@parametrize("device", ["cpu", "cuda"])
|
||||
@parametrize("device", devices)
|
||||
def test_all_reduce_coalesced_eager(self, device):
|
||||
if device == "cuda":
|
||||
if torch.cuda.device_count() < self.world_size:
|
||||
|
|
@ -279,7 +315,7 @@ class TestTraceableCollectives(MultiThreadedTestCase):
|
|||
self.assertEqual(res[0], t0 * 4)
|
||||
self.assertEqual(res[1], t1 * 4)
|
||||
|
||||
@parametrize("device", ["cpu", "cuda"])
|
||||
@parametrize("device", devices)
|
||||
def test_all_gather_tensor(self, device):
|
||||
if device == "cuda":
|
||||
if torch.cuda.device_count() < self.world_size:
|
||||
|
|
@ -301,7 +337,7 @@ class TestTraceableCollectives(MultiThreadedTestCase):
|
|||
)
|
||||
self.assertEqual(gathered_tensor, torch.ones(output_size))
|
||||
|
||||
@parametrize("device", ["cpu", "cuda"])
|
||||
@parametrize("device", devices)
|
||||
def test_all_gather_into_tensor_coalesced(self, device):
|
||||
if device == "cuda":
|
||||
if torch.cuda.device_count() < self.world_size:
|
||||
|
|
@ -318,7 +354,7 @@ class TestTraceableCollectives(MultiThreadedTestCase):
|
|||
torch.ones([4 * dist.get_world_size()], device=device) + 1, res[1]
|
||||
)
|
||||
|
||||
@parametrize("device", ["cpu", "cuda"])
|
||||
@parametrize("device", devices)
|
||||
def test_reduce_scatter_tensor(self, device):
|
||||
if device == "cuda":
|
||||
if torch.cuda.device_count() < self.world_size:
|
||||
|
|
@ -342,7 +378,7 @@ class TestTraceableCollectives(MultiThreadedTestCase):
|
|||
)
|
||||
self.assertEqual(rs_tensor, torch.ones(input_size) * res_num)
|
||||
|
||||
@parametrize("device", ["cpu", "cuda"])
|
||||
@parametrize("device", devices)
|
||||
def test_reduce_scatter_into_tensor_coalesced(self, device):
|
||||
if device == "cuda":
|
||||
if torch.cuda.device_count() < self.world_size:
|
||||
|
|
@ -367,6 +403,7 @@ class TestMetaCollectives(TestCase):
|
|||
self.assertEqual(x.size(), out.size())
|
||||
|
||||
|
||||
@skipIfHpu
|
||||
class TestGradCollectives(MultiThreadedTestCase):
|
||||
@property
|
||||
def world_size(self):
|
||||
|
|
@ -430,82 +467,55 @@ class TestMakeFx(TestCase):
|
|||
|
||||
|
||||
BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
|
||||
WORLD_SIZE = 2
|
||||
|
||||
# Adding support for HCCL backend
|
||||
# To add a different backend
|
||||
# add an elif to the same chain with a conditional checking for the device type (along the lines of TEST_HPU or TEST_CUDA)
|
||||
# And then set the BACKEND variable appropriately.
|
||||
if TEST_HPU:
|
||||
BACKEND = dist.Backend.HCCL
|
||||
|
||||
|
||||
def exit_if_lt_x_gpu(x):
|
||||
if torch.cuda.device_count() < x:
|
||||
sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
|
||||
# allows you to check for multiple accelerator irrespective of device type
|
||||
# to add new device types to this check simply follow the same format
|
||||
# and append an elif with the conditional and appropriate device count function for your new device
|
||||
def exit_if_lt_x_accelerators(x):
|
||||
if TEST_CUDA:
|
||||
if torch.cuda.device_count() < x:
|
||||
sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
|
||||
elif TEST_HPU:
|
||||
if torch.hpu.device_count() < x:
|
||||
sys.exit(TEST_SKIPS[f"multi-hpu-{x}"].exit_code)
|
||||
|
||||
|
||||
def with_comms(func=None):
|
||||
if func is None:
|
||||
return partial(
|
||||
with_comms,
|
||||
)
|
||||
return partial(with_comms)
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
global BACKEND
|
||||
|
||||
if "BACKEND" in os.environ:
|
||||
BACKEND = os.environ["BACKEND"]
|
||||
if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
|
||||
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
|
||||
self.dist_init()
|
||||
func(self)
|
||||
self.destroy_comms()
|
||||
|
||||
kwargs["device"] = DEVICE
|
||||
self.pg = self.create_pg(device=DEVICE)
|
||||
try:
|
||||
return func(self, *args, **kwargs)
|
||||
finally:
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class TestCollectivesWithNCCL(MultiProcessTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
os.environ["WORLD_SIZE"] = str(self.world_size)
|
||||
os.environ["BACKEND"] = dist.Backend.NCCL
|
||||
BACKEND = dist.Backend.NCCL
|
||||
self._spawn_processes()
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return torch.device(self.rank)
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
return WORLD_SIZE
|
||||
|
||||
@property
|
||||
def process_group(self):
|
||||
return dist.group.WORLD
|
||||
|
||||
def dist_init(self):
|
||||
dist.init_process_group(
|
||||
backend=BACKEND,
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
init_method=f"file://{self.file_name}",
|
||||
)
|
||||
|
||||
# set device for nccl pg for collectives
|
||||
if BACKEND == "nccl":
|
||||
torch.cuda.set_device(self.rank)
|
||||
|
||||
def destroy_comms(self):
|
||||
# Wait for all ranks to reach here before starting shutdown.
|
||||
dist.barrier()
|
||||
dist.destroy_process_group()
|
||||
|
||||
@requires_nccl()
|
||||
class TestCollectivesWithDistributedBackend(DistributedTestBase):
|
||||
@with_comms()
|
||||
def test_all_gather_into_tensor_coalesced(self):
|
||||
exit_if_lt_x_gpu(self.world_size)
|
||||
|
||||
def test_all_gather_into_tensor_coalesced(self, device):
|
||||
exit_if_lt_x_accelerators(self.world_size)
|
||||
tensors = [
|
||||
torch.ones([4], device=f"cuda:{self.rank}"),
|
||||
torch.ones([4], device=f"cuda:{self.rank}") + 1,
|
||||
torch.ones([4], device=device),
|
||||
torch.ones([4], device=device) + 1,
|
||||
]
|
||||
mesh = dt.DeviceMesh(f"cuda:{self.rank}", torch.arange(self.world_size))
|
||||
mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
|
||||
|
||||
res = ft_c.all_gather_into_tensor_coalesced(tensors, mesh)
|
||||
self.assertEqual(2, len(res))
|
||||
|
|
@ -513,8 +523,7 @@ class TestCollectivesWithNCCL(MultiProcessTestCase):
|
|||
self.assertEqual(torch.ones([4 * dist.get_world_size()]) + 1, res[1])
|
||||
|
||||
@with_comms()
|
||||
def test_all_to_all_single(self):
|
||||
device = "cuda" if BACKEND == dist.Backend.NCCL else "cpu"
|
||||
def test_all_to_all_single(self, device):
|
||||
mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
|
||||
rank = dist.get_rank()
|
||||
|
||||
|
|
@ -531,8 +540,7 @@ class TestCollectivesWithNCCL(MultiProcessTestCase):
|
|||
self.assertEqual(y, expected)
|
||||
|
||||
@with_comms()
|
||||
def test_all_to_all_single_1d_input(self):
|
||||
device = "cuda" if BACKEND == dist.Backend.NCCL else "cpu"
|
||||
def test_all_to_all_single_1d_input(self, device):
|
||||
mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
|
||||
rank = dist.get_rank()
|
||||
|
||||
|
|
@ -549,8 +557,7 @@ class TestCollectivesWithNCCL(MultiProcessTestCase):
|
|||
self.assertEqual(y, expected)
|
||||
|
||||
@with_comms()
|
||||
def test_all_to_all_single_split_sizes_none(self):
|
||||
device = "cuda" if BACKEND == dist.Backend.NCCL else "cpu"
|
||||
def test_all_to_all_single_split_sizes_none(self, device):
|
||||
mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
|
||||
rank = dist.get_rank()
|
||||
|
||||
|
|
@ -567,16 +574,16 @@ class TestCollectivesWithNCCL(MultiProcessTestCase):
|
|||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@requires_nccl()
|
||||
@with_comms()
|
||||
def test_tracing(self):
|
||||
def test_tracing(self, device):
|
||||
def allreduce(t, pg):
|
||||
return ft_c.all_reduce(t, "sum", pg)
|
||||
|
||||
compiled_allreduce = torch.compile(allreduce, fullgraph=True)
|
||||
compiled_allreduce(torch.randn(8, device=self.device), self.process_group)
|
||||
compiled_allreduce(torch.randn(8, device=device), self.pg)
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
def test_tracing_with_fakepg(self):
|
||||
exit_if_lt_x_gpu(self.world_size)
|
||||
def test_tracing_with_fakepg(self, device=DEVICE):
|
||||
exit_if_lt_x_accelerators(self.world_size)
|
||||
|
||||
def allreduce(t, pg):
|
||||
return ft_c.all_reduce(t, "sum", pg)
|
||||
|
|
@ -588,12 +595,13 @@ class TestCollectivesWithNCCL(MultiProcessTestCase):
|
|||
world_size=8,
|
||||
store=FakeStore(),
|
||||
)
|
||||
allreduce(torch.randn(8, device=self.device), pg=dist.group.WORLD)
|
||||
allreduce(torch.randn(8, device=device), pg=dist.group.WORLD)
|
||||
dist.destroy_process_group()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@requires_nccl()
|
||||
@with_comms()
|
||||
def test_tracing_with_dce_code(self):
|
||||
def test_tracing_with_dce_code(self, device):
|
||||
if self.world_size > 2:
|
||||
return
|
||||
|
||||
|
|
@ -608,22 +616,21 @@ class TestCollectivesWithNCCL(MultiProcessTestCase):
|
|||
|
||||
compiled_func = torch.compile(func)
|
||||
ret = compiled_func(
|
||||
torch.ones((100,), device="cuda"), self.process_group, self.rank
|
||||
torch.ones((100,), device=device), self.process_group, self.rank
|
||||
)
|
||||
dist.barrier()
|
||||
|
||||
|
||||
class TestNCCLCollectivesWithWorldSize4(TestCollectivesWithNCCL):
|
||||
class TestDistributedBackendCollectivesWithWorldSize4(
|
||||
TestCollectivesWithDistributedBackend
|
||||
):
|
||||
@property
|
||||
def world_size(self):
|
||||
return 4
|
||||
|
||||
@requires_nccl()
|
||||
@with_comms()
|
||||
def test_permute_tensor_with_sub_group(self):
|
||||
exit_if_lt_x_gpu(self.world_size)
|
||||
|
||||
device = "cuda"
|
||||
def test_permute_tensor_with_sub_group(self, device):
|
||||
exit_if_lt_x_accelerators(self.world_size)
|
||||
mesh_dim_names = ["dp", "tp"]
|
||||
|
||||
mesh_2d = dt.init_device_mesh(
|
||||
|
|
@ -651,6 +658,7 @@ class TestNCCLCollectivesWithWorldSize4(TestCollectivesWithNCCL):
|
|||
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
@skipIfHpu
|
||||
class TestFunctionalAutograd(MultiThreadedTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
|
@ -784,48 +792,11 @@ class TestFunctionalAutograd(MultiThreadedTestCase):
|
|||
self.assertEqual(input_tensor.grad, torch.full(output_size, fill_value=1.0))
|
||||
|
||||
|
||||
class TestFunctionalAutogradWithNCCL(MultiProcessTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
os.environ["WORLD_SIZE"] = str(self.world_size)
|
||||
os.environ["BACKEND"] = dist.Backend.NCCL
|
||||
self._spawn_processes()
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return torch.device(self.rank)
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
return 2
|
||||
|
||||
@property
|
||||
def process_group(self):
|
||||
return dist.group.WORLD
|
||||
|
||||
def dist_init(self):
|
||||
dist.init_process_group(
|
||||
backend=BACKEND,
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
init_method=f"file://{self.file_name}",
|
||||
)
|
||||
|
||||
# set device for nccl pg for collectives
|
||||
if BACKEND == "nccl":
|
||||
torch.cuda.set_device(self.rank)
|
||||
|
||||
def destroy_comms(self):
|
||||
# Wait for all ranks to reach here before starting shutdown.
|
||||
dist.barrier()
|
||||
dist.destroy_process_group()
|
||||
|
||||
@requires_nccl()
|
||||
class TestFunctionalAutogradWithDistributedBackend(DistributedTestBase):
|
||||
@with_comms()
|
||||
def test_all_to_all_single(self) -> None:
|
||||
group = self.process_group.group_name
|
||||
|
||||
t = torch.ones((self.world_size, 2), requires_grad=True, device=self.device)
|
||||
def test_all_to_all_single(self, device) -> None:
|
||||
group = self.pg
|
||||
t = torch.ones((self.world_size, 2), requires_grad=True, device=device)
|
||||
|
||||
sizes = [1] * self.world_size
|
||||
assert t.requires_grad
|
||||
|
|
@ -840,5 +811,16 @@ class TestFunctionalAutogradWithNCCL(MultiProcessTestCase):
|
|||
self.assertEqual(t.grad, torch.full_like(t, 2.0))
|
||||
|
||||
|
||||
# Update the supported devices in DEVICE
|
||||
instantiate_device_type_tests(
|
||||
TestCollectivesWithDistributedBackend, globals(), only_for=DEVICE
|
||||
)
|
||||
instantiate_device_type_tests(
|
||||
TestDistributedBackendCollectivesWithWorldSize4, globals(), only_for=DEVICE
|
||||
)
|
||||
instantiate_device_type_tests(
|
||||
TestFunctionalAutogradWithDistributedBackend, globals(), only_for=DEVICE
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ from torch.testing._internal.common_utils import (
|
|||
TEST_WITH_TSAN,
|
||||
TestCase,
|
||||
run_tests,
|
||||
TEST_HPU,
|
||||
)
|
||||
from torch.testing._internal.distributed.multi_threaded_pg import (
|
||||
_install_threaded_pg,
|
||||
|
|
@ -82,6 +83,7 @@ TEST_SKIPS = {
|
|||
86, "Test skipped at subprocess level, look at subprocess log for skip reason"
|
||||
),
|
||||
"importerror": TestSkip(88, "Test skipped due to missing import"),
|
||||
"no_accelerator": TestSkip(89, "accelerator is not available."),
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -101,6 +103,8 @@ class DistTestCases:
|
|||
backend_feature["ddp"] = {"nccl", "gloo", "ucc"}
|
||||
backend_feature["subgroup"] = {"nccl", "gloo", "ucc"}
|
||||
backend_feature["plugin"] = set()
|
||||
if TEST_HPU:
|
||||
backend_feature["hpu"] = {"hccl"}
|
||||
|
||||
|
||||
def skip_if_no_gpu(func):
|
||||
|
|
@ -114,6 +118,8 @@ def skip_if_no_gpu(func):
|
|||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
if torch.cuda.device_count() < world_size:
|
||||
sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
|
||||
if TEST_HPU and torch.hpu.device_count < world_size:
|
||||
sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
|
@ -191,6 +197,8 @@ def skip_if_lt_x_gpu(x):
|
|||
def wrapper(*args, **kwargs):
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() >= x:
|
||||
return func(*args, **kwargs)
|
||||
if TEST_HPU and torch.hpu.device_count() >= x:
|
||||
return func(*args, **kwargs)
|
||||
sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
|
||||
|
||||
return wrapper
|
||||
|
|
@ -500,6 +508,9 @@ def init_multigpu_helper(world_size: int, backend: str):
|
|||
divided to subsets, each process only uses a subset.
|
||||
"""
|
||||
nGPUs = torch.cuda.device_count()
|
||||
if TEST_HPU:
|
||||
nGPUs = torch.hpu.device_count()
|
||||
|
||||
visible_devices = range(nGPUs)
|
||||
|
||||
# If rank is less than or equal to number of available GPU's
|
||||
|
|
@ -900,6 +911,47 @@ class MultiProcessTestCase(TestCase):
|
|||
def is_master(self) -> bool:
|
||||
return self.rank == 0
|
||||
|
||||
# Utility base class for distributed Multi Process Test cases
|
||||
# This abstracts the PG creation and deletion, the backends are selected based
|
||||
# on device type. The tests functions can be instantiated per device type using
|
||||
# common_device_type.instantiate_device_type_tests
|
||||
# other backends can add entry in backend() function
|
||||
class DistributedTestBase(MultiProcessTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self._spawn_processes()
|
||||
|
||||
def tearDown(self):
|
||||
try:
|
||||
os.remove(self.file_name)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def backend(self, device) -> str:
|
||||
if "cuda" in device:
|
||||
return "nccl"
|
||||
elif "hpu" in device : # intel gaudi
|
||||
return "hccl"
|
||||
else :
|
||||
return "gloo"
|
||||
|
||||
def create_pg(self, device):
|
||||
num_visible_devices = torch.get_device_module(device).device_count()
|
||||
store = torch.distributed.FileStore(self.file_name, num_visible_devices)
|
||||
torch.distributed.init_process_group(
|
||||
backend=self.backend(device),
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
store=store
|
||||
)
|
||||
if "nccl" in self.backend(device):
|
||||
torch.cuda.set_device(self.rank)
|
||||
return torch.distributed.distributed_c10d._get_default_group()
|
||||
|
||||
def rank_to_device(self, device):
|
||||
num_visible_devices = torch.get_device_module(device).device_count()
|
||||
return {i: [i % num_visible_devices] for i in range(self.world_size)}
|
||||
|
||||
def run_subtests(
|
||||
cls_inst,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user