Use accelerator API in common_dtensor (#163498)
Fixes #ISSUE_NUMBER

Unify the device checking in common_dtensor (testing module) via the accelerator API.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163498
Approved by: https://github.com/albanD, https://github.com/H-Huang
commit 6e5dddba64
parent ebddbe787a
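For context, here is a minimal sketch of the accelerator API calls this patch builds on. The snippet is illustrative only (the printed values depend on the local build and hardware); the accelerator and backend-map calls are the ones the patch switches to, while torch.accelerator.is_available() is added here only to keep the sketch safe on CPU-only machines:

import torch
import torch.distributed as dist

# torch.accelerator is the device-agnostic entry point: it reports the active
# accelerator (cuda, xpu, hpu, or an out-of-tree PrivateUse1 backend) without
# branching on each device type explicitly.
if torch.accelerator.is_available():
    device_type = torch.accelerator.current_accelerator().type  # e.g. "cuda"
    device_count = torch.accelerator.device_count()
    # Default process-group backend for that device type (e.g. "cuda" -> "nccl").
    pg_backend = dist.Backend.default_device_backend_map.get(device_type, "gloo")
else:
    device_type, device_count, pg_backend = "cpu", 0, "gloo"

print(device_type, device_count, pg_backend)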
@@ -13,7 +13,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
-from torch._utils import _get_device_module
 from torch.distributed.tensor import (
     DeviceMesh,
     distribute_tensor,
@@ -38,24 +37,21 @@ from torch.testing._internal.common_distributed import (
     skip_if_lt_x_gpu,
     TEST_SKIPS,
 )
-from torch.testing._internal.common_utils import TEST_CUDA, TEST_HPU, TEST_XPU
+from torch.testing._internal.common_utils import (
+    TEST_CUDA,
+    TEST_HPU,
+    TEST_PRIVATEUSE1,
+    TEST_XPU,
+)
 from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec


 DEVICE_COUNT: int

-if TEST_CUDA:
-    DEVICE_TYPE = "cuda"
-    PG_BACKEND = "nccl"
-    DEVICE_COUNT = _get_device_module("cuda").device_count()
-elif TEST_HPU:
-    DEVICE_TYPE = "hpu"
-    PG_BACKEND = "hccl"
-    DEVICE_COUNT = _get_device_module("hpu").device_count()
-elif TEST_XPU:
-    DEVICE_TYPE = "xpu"
-    PG_BACKEND = "xccl"
-    DEVICE_COUNT = _get_device_module("xpu").device_count()
+if TEST_CUDA or TEST_XPU or TEST_HPU or TEST_PRIVATEUSE1:
+    DEVICE_TYPE = torch.accelerator.current_accelerator().type
+    DEVICE_COUNT = torch.accelerator.device_count()
+    PG_BACKEND = dist.Backend.default_device_backend_map[DEVICE_TYPE]
 else:
     DEVICE_TYPE = "cpu"
     PG_BACKEND = "gloo"
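The main change in this hunk is that the per-device branches (hard-coded "nccl", "hccl", "xccl", plus _get_device_module(...).device_count()) collapse into two accelerator calls and a single backend lookup. A small sketch of that lookup, assuming a build where the map is populated as usual (its exact contents vary by build and by which out-of-tree backends are registered):

import torch.distributed as dist

# default_device_backend_map maps a device type to its default process-group
# backend, e.g. "cpu" -> "gloo" and "cuda" -> "nccl" on a typical CUDA build.
# Out-of-tree (PrivateUse1) backends can register their own entry, which is
# why the unified branch can also cover TEST_PRIVATEUSE1.
for device_type, backend in dist.Backend.default_device_backend_map.items():
    print(f"{device_type} -> {backend}")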
@@ -63,7 +59,7 @@ else:
 NUM_DEVICES = 4

 # We use this as a proxy for "multiple GPUs exist"
-if (TEST_CUDA or TEST_XPU or TEST_HPU) and DEVICE_COUNT > 1:
+if (TEST_CUDA or TEST_XPU or TEST_HPU or TEST_PRIVATEUSE1) and DEVICE_COUNT > 1:
     # when we actually have multiple GPUs, relax the requirement to smaller counts.
     NUM_DEVICES = min(NUM_DEVICES, DEVICE_COUNT)

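To make the capping concrete, a small worked illustration with made-up numbers (not part of the patch):

# The harness asks for 4 devices by default; on a machine where the
# accelerator API reports only 2, the requirement relaxes to 2.
NUM_DEVICES = 4
DEVICE_COUNT = 2  # e.g. torch.accelerator.device_count() on a 2-GPU box
if DEVICE_COUNT > 1:
    NUM_DEVICES = min(NUM_DEVICES, DEVICE_COUNT)
assert NUM_DEVICES == 2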
@@ -341,7 +337,10 @@ class DTensorContinuousTestBase(MultiProcContinuousTest):
     @classmethod
     def device_type(cls) -> str:
         # if enough GPU/XPU/HPU we can use those devices, otherwise we fallback to CPU
-        if not (TEST_CUDA or TEST_XPU or TEST_HPU) or DEVICE_COUNT < cls.world_size:
+        if (
+            not (TEST_CUDA or TEST_XPU or TEST_HPU or TEST_PRIVATEUSE1)
+            or DEVICE_COUNT < cls.world_size
+        ):
             return "cpu"
         else:
             return DEVICE_TYPE
@@ -360,7 +359,10 @@ class DTensorTestBase(MultiProcessTestCase):
     @property
     def device_type(self) -> str:
         # if enough GPU/XPU/HPU we can use those devices, otherwise we fallback to CPU
-        if not (TEST_CUDA or TEST_XPU or TEST_HPU) or DEVICE_COUNT < self.world_size:
+        if (
+            not (TEST_CUDA or TEST_XPU or TEST_HPU or TEST_PRIVATEUSE1)
+            or DEVICE_COUNT < self.world_size
+        ):
             return "cpu"
         else:
             return DEVICE_TYPE
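Both test bases now share the same rule: use the detected accelerator type when at least world_size devices are visible, otherwise fall back to "cpu". A hypothetical sketch of how a test consumes this property; the class ExampleDTensorTest, its test body, and the import path for common_dtensor are assumptions for illustration, not part of the patch:

import torch
from torch.distributed.tensor import DeviceMesh, distribute_tensor, Shard
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


class ExampleDTensorTest(DTensorTestBase):
    @with_comms
    def test_shard_roundtrip(self):
        # self.device_type resolves to the accelerator type, or "cpu" when
        # fewer than world_size devices are visible, so the same mesh works
        # on CUDA, XPU, HPU, PrivateUse1, or CPU-only machines.
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        local = torch.randn(8, 8)
        dtensor = distribute_tensor(local, mesh, [Shard(0)])
        self.assertEqual(dtensor.full_tensor().shape, local.shape)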