Use accelerator API in common_dtensor (#163498)

Fixes #ISSUE_NUMBER

Unify the device checks in the common_dtensor testing module via the accelerator API.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163498
Approved by: https://github.com/albanD, https://github.com/H-Huang
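
For context, the device discovery this patch switches to looks roughly like the following. This is a minimal sketch assuming a PyTorch build recent enough to provide torch.accelerator; the variable names are illustrative, not the exact test-suite code.

    import torch

    # current_accelerator() returns a torch.device for the active accelerator
    # (cuda, xpu, hpu, a PrivateUse1 backend, ...) or None on CPU-only builds.
    acc = torch.accelerator.current_accelerator()
    if acc is not None:
        device_type = acc.type  # e.g. "cuda"
        device_count = torch.accelerator.device_count()
    else:
        device_type, device_count = "cpu", 0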
Author: dilililiwhy (2025-09-23 16:30:20 +00:00), committed by PyTorch MergeBot
Parent: ebddbe787a
Commit: 6e5dddba64


@@ -13,7 +13,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
-from torch._utils import _get_device_module
 from torch.distributed.tensor import (
     DeviceMesh,
     distribute_tensor,
@@ -38,24 +37,21 @@ from torch.testing._internal.common_distributed import (
     skip_if_lt_x_gpu,
     TEST_SKIPS,
 )
-from torch.testing._internal.common_utils import TEST_CUDA, TEST_HPU, TEST_XPU
+from torch.testing._internal.common_utils import (
+    TEST_CUDA,
+    TEST_HPU,
+    TEST_PRIVATEUSE1,
+    TEST_XPU,
+)
 from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec


 DEVICE_COUNT: int
-if TEST_CUDA:
-    DEVICE_TYPE = "cuda"
-    PG_BACKEND = "nccl"
-    DEVICE_COUNT = _get_device_module("cuda").device_count()
-elif TEST_HPU:
-    DEVICE_TYPE = "hpu"
-    PG_BACKEND = "hccl"
-    DEVICE_COUNT = _get_device_module("hpu").device_count()
-elif TEST_XPU:
-    DEVICE_TYPE = "xpu"
-    PG_BACKEND = "xccl"
-    DEVICE_COUNT = _get_device_module("xpu").device_count()
+if TEST_CUDA or TEST_XPU or TEST_HPU or TEST_PRIVATEUSE1:
+    DEVICE_TYPE = torch.accelerator.current_accelerator().type
+    DEVICE_COUNT = torch.accelerator.device_count()
+    PG_BACKEND = dist.Backend.default_device_backend_map[DEVICE_TYPE]
 else:
     DEVICE_TYPE = "cpu"
     PG_BACKEND = "gloo"
@@ -63,7 +59,7 @@ else:

 NUM_DEVICES = 4

 # We use this as a proxy for "multiple GPUs exist"
-if (TEST_CUDA or TEST_XPU or TEST_HPU) and DEVICE_COUNT > 1:
+if (TEST_CUDA or TEST_XPU or TEST_HPU or TEST_PRIVATEUSE1) and DEVICE_COUNT > 1:
     # when we actually have multiple GPUs, relax the requirement to smaller counts.
     NUM_DEVICES = min(NUM_DEVICES, DEVICE_COUNT)
@@ -341,7 +337,10 @@ class DTensorContinuousTestBase(MultiProcContinuousTest):
     @classmethod
     def device_type(cls) -> str:
         # if enough GPU/XPU/HPU we can use those devices, otherwise we fallback to CPU
-        if not (TEST_CUDA or TEST_XPU or TEST_HPU) or DEVICE_COUNT < cls.world_size:
+        if (
+            not (TEST_CUDA or TEST_XPU or TEST_HPU or TEST_PRIVATEUSE1)
+            or DEVICE_COUNT < cls.world_size
+        ):
             return "cpu"
         else:
             return DEVICE_TYPE
@@ -360,7 +359,10 @@ class DTensorTestBase(MultiProcessTestCase):
     @property
     def device_type(self) -> str:
         # if enough GPU/XPU/HPU we can use those devices, otherwise we fallback to CPU
-        if not (TEST_CUDA or TEST_XPU or TEST_HPU) or DEVICE_COUNT < self.world_size:
+        if (
+            not (TEST_CUDA or TEST_XPU or TEST_HPU or TEST_PRIVATEUSE1)
+            or DEVICE_COUNT < self.world_size
+        ):
             return "cpu"
         else:
             return DEVICE_TYPE
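
A note on the backend lookup above: rather than hard-coding nccl/xccl/hccl per device type, the new code reads dist.Backend.default_device_backend_map, which maps a device type to its default process-group backend. A quick sketch (the exact map contents depend on the build and on which backends are registered):

    import torch.distributed as dist

    # e.g. {"cpu": "gloo", "cuda": "nccl", "xpu": "xccl"} on a typical build
    backend_map = dist.Backend.default_device_backend_map
    pg_backend = backend_map.get("cuda", "gloo")  # fall back to gloo if unmapped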