[Typing] Refactor torch.types.Device in torch/cuda/__init__.py (#153447)

Part of: #152952
Follow up: #153027

Here is the definition of `torch.types.Device`:

https://github.com/pytorch/pytorch/blob/ab997d9ff5/torch/types.py#L74

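`torch.types.Device` already admits `int` and `None`, so `Optional[Union[Device, int]]` is equivalent to plain `torch.types.Device`. The same reasoning applies to the module-local `_device_t` alias (`Union[_device, str, int, None]`) and to `Optional[_device_t]`, so this change removes `_device_t` and annotates those parameters with `Device` as well.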

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153447
Approved by: https://github.com/cyyever, https://github.com/Skylion007
This commit is contained in:
Yuanhao Ji 2025-05-28 10:09:31 +00:00 committed by PyTorch MergeBot
parent fdc339003b
commit f58143b945

View File

@ -18,13 +18,12 @@ import threading
import traceback
import warnings
from functools import lru_cache
from typing import Any, Callable, cast, Optional, Union
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union
import torch
import torch._C
from torch import device as _device
from torch._utils import _dummy_type, _LazySeedTracker, classproperty
from torch.types import Device
from . import gds
from ._utils import _get_device_index
@ -38,6 +37,9 @@ from .graphs import (
from .streams import Event, ExternalStream, Stream
if TYPE_CHECKING:
from torch.types import Device
try:
from torch._C import _cudart # type: ignore[attr-defined]
except ImportError:
@ -50,7 +52,6 @@ _queued_calls: list[
tuple[Callable[[], None], list[str]]
] = [] # don't invoke these until initialization occurs
_is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False)
_device_t = Union[_device, str, int, None]
_HAS_PYNVML = False
_PYNVML_ERR = None
@ -208,7 +209,7 @@ def is_bf16_supported(including_emulation: bool = True):
@lru_cache(maxsize=16)
def _check_bf16_tensor_supported(device: _device_t):
def _check_bf16_tensor_supported(device: "Device"):
try:
torch.tensor([1.0], dtype=torch.bfloat16, device=device)
return True
@ -524,7 +525,7 @@ class device_of(device):
super().__init__(idx)
def set_device(device: _device_t) -> None:
def set_device(device: "Device") -> None:
r"""Set the current device.
Usage of this function is discouraged in favor of :any:`device`. In most
@ -539,7 +540,7 @@ def set_device(device: _device_t) -> None:
torch._C._cuda_setDevice(device)
def get_device_name(device: Optional[_device_t] = None) -> str:
def get_device_name(device: "Device" = None) -> str:
r"""Get the name of a device.
Args:
@ -554,7 +555,7 @@ def get_device_name(device: Optional[_device_t] = None) -> str:
return get_device_properties(device).name
def get_device_capability(device: Optional[_device_t] = None) -> tuple[int, int]:
def get_device_capability(device: "Device" = None) -> tuple[int, int]:
r"""Get the cuda capability of a device.
Args:
@ -571,7 +572,7 @@ def get_device_capability(device: Optional[_device_t] = None) -> tuple[int, int]
return prop.major, prop.minor
def get_device_properties(device: Optional[_device_t] = None) -> _CudaDeviceProperties:
def get_device_properties(device: "Device" = None) -> _CudaDeviceProperties:
r"""Get the properties of a device.
Args:
@ -590,7 +591,7 @@ def get_device_properties(device: Optional[_device_t] = None) -> _CudaDeviceProp
return _get_device_properties(device) # type: ignore[name-defined]
def can_device_access_peer(device: _device_t, peer_device: _device_t) -> bool:
def can_device_access_peer(device: "Device", peer_device: "Device") -> bool:
r"""Check if peer access between two devices is possible."""
_lazy_init()
device = _get_device_index(device, optional=True)
@ -967,7 +968,7 @@ def _device_count_nvml() -> int:
return len(visible_devices)
def _get_nvml_device_index(device: Optional[Union[int, Device]]) -> int:
def _get_nvml_device_index(device: "Device") -> int:
r"""Return the NVML index of the device, taking CUDA_VISIBLE_DEVICES into account."""
idx = _get_device_index(device, optional=True)
visible_devices = _parse_visible_devices()
@ -1042,7 +1043,7 @@ def current_device() -> int:
return torch._C._cuda_getDevice()
def synchronize(device: Optional[_device_t] = None) -> None:
def synchronize(device: "Device" = None) -> None:
r"""Wait for all kernels in all streams on a CUDA device to complete.
Args:
@ -1068,7 +1069,7 @@ def ipc_collect():
return torch._C._cuda_ipc_collect()
def current_stream(device: Optional[_device_t] = None) -> Stream:
def current_stream(device: "Device" = None) -> Stream:
r"""Return the currently selected :class:`Stream` for a given device.
Args:
@ -1086,7 +1087,7 @@ def current_stream(device: Optional[_device_t] = None) -> Stream:
)
def default_stream(device: Optional[_device_t] = None) -> Stream:
def default_stream(device: "Device" = None) -> Stream:
r"""Return the default :class:`Stream` for a given device.
Args:
@ -1104,9 +1105,7 @@ def default_stream(device: Optional[_device_t] = None) -> Stream:
)
def get_stream_from_external(
data_ptr: int, device: Optional[_device_t] = None
) -> Stream:
def get_stream_from_external(data_ptr: int, device: "Device" = None) -> Stream:
r"""Return a :class:`Stream` from an externally allocated CUDA stream.
This function is used to wrap streams allocated in other libraries in order
@ -1171,7 +1170,7 @@ def get_sync_debug_mode() -> int:
return torch._C._cuda_get_sync_debug_mode()
def _get_pynvml_handler(device: Optional[Union[Device, int]] = None):
def _get_pynvml_handler(device: "Device" = None):
if not _HAS_PYNVML:
raise ModuleNotFoundError(
"pynvml does not seem to be installed or it can't be imported."
@ -1188,7 +1187,7 @@ def _get_pynvml_handler(device: Optional[Union[Device, int]] = None):
return handle
def _get_amdsmi_handler(device: Optional[Union[Device, int]] = None):
def _get_amdsmi_handler(device: "Device" = None):
if not _HAS_PYNVML:
raise ModuleNotFoundError(
"amdsmi does not seem to be installed or it can't be imported."
@ -1204,7 +1203,7 @@ def _get_amdsmi_handler(device: Optional[Union[Device, int]] = None):
return handle
def _get_amdsmi_device_index(device: Optional[Union[int, Device]]) -> int:
def _get_amdsmi_device_index(device: "Device") -> int:
r"""Return the amdsmi index of the device, taking visible_devices into account."""
idx = _get_device_index(device, optional=True)
visible_devices = _parse_visible_devices()
@ -1224,7 +1223,7 @@ def _get_amdsmi_device_index(device: Optional[Union[int, Device]]) -> int:
return idx_map[idx]
def _get_amdsmi_device_memory_used(device: Optional[Union[Device, int]] = None) -> int:
def _get_amdsmi_device_memory_used(device: "Device" = None) -> int:
handle = _get_amdsmi_handler(device)
# amdsmi_get_gpu_vram_usage returns mem usage in megabytes
mem_mega_bytes = amdsmi.amdsmi_get_gpu_vram_usage(handle)["vram_used"]
@ -1232,17 +1231,17 @@ def _get_amdsmi_device_memory_used(device: Optional[Union[Device, int]] = None)
return mem_bytes
def _get_amdsmi_memory_usage(device: Optional[Union[Device, int]] = None) -> int:
def _get_amdsmi_memory_usage(device: "Device" = None) -> int:
handle = _get_amdsmi_handler(device)
return amdsmi.amdsmi_get_gpu_activity(handle)["umc_activity"]
def _get_amdsmi_utilization(device: Optional[Union[Device, int]] = None) -> int:
def _get_amdsmi_utilization(device: "Device" = None) -> int:
handle = _get_amdsmi_handler(device)
return amdsmi.amdsmi_get_gpu_activity(handle)["gfx_activity"]
def _get_amdsmi_temperature(device: Optional[Union[Device, int]] = None) -> int:
def _get_amdsmi_temperature(device: "Device" = None) -> int:
handle = _get_amdsmi_handler(device)
return amdsmi.amdsmi_get_temp_metric(
handle,
@ -1251,7 +1250,7 @@ def _get_amdsmi_temperature(device: Optional[Union[Device, int]] = None) -> int:
)
def _get_amdsmi_power_draw(device: Optional[Union[Device, int]] = None) -> int:
def _get_amdsmi_power_draw(device: "Device" = None) -> int:
handle = _get_amdsmi_handler(device)
socket_power = amdsmi.amdsmi_get_power_info(handle)["average_socket_power"]
if socket_power != "N/A":
@ -1264,7 +1263,7 @@ def _get_amdsmi_power_draw(device: Optional[Union[Device, int]] = None) -> int:
return 0
def _get_amdsmi_clock_rate(device: Optional[Union[Device, int]] = None) -> int:
def _get_amdsmi_clock_rate(device: "Device" = None) -> int:
handle = _get_amdsmi_handler(device)
clock_info = amdsmi.amdsmi_get_clock_info(handle, amdsmi.AmdSmiClkType.GFX)
if "cur_clk" in clock_info: # ROCm 6.2 deprecation
@ -1277,7 +1276,7 @@ def _get_amdsmi_clock_rate(device: Optional[Union[Device, int]] = None) -> int:
return 0
def device_memory_used(device: Optional[Union[Device, int]] = None) -> int:
def device_memory_used(device: "Device" = None) -> int:
r"""Return used global (device) memory in bytes as given by `nvidia-smi` or `amd-smi`.
Args:
@ -1295,7 +1294,7 @@ def device_memory_used(device: Optional[Union[Device, int]] = None) -> int:
return _get_amdsmi_device_memory_used(device)
def memory_usage(device: Optional[Union[Device, int]] = None) -> int:
def memory_usage(device: "Device" = None) -> int:
r"""Return the percent of time over the past sample period during which global (device)
memory was being read or written as given by `nvidia-smi`.
@ -1316,7 +1315,7 @@ def memory_usage(device: Optional[Union[Device, int]] = None) -> int:
return _get_amdsmi_memory_usage(device)
def utilization(device: Optional[Union[Device, int]] = None) -> int:
def utilization(device: "Device" = None) -> int:
r"""Return the percent of time over the past sample period during which one or
more kernels was executing on the GPU as given by `nvidia-smi`.
@ -1337,7 +1336,7 @@ def utilization(device: Optional[Union[Device, int]] = None) -> int:
return _get_amdsmi_utilization(device)
def temperature(device: Optional[Union[Device, int]] = None) -> int:
def temperature(device: "Device" = None) -> int:
r"""Return the average temperature of the GPU sensor in Degrees C (Centigrades).
The average temperature is computed based on past sample period as given by `nvidia-smi`.
@ -1358,7 +1357,7 @@ def temperature(device: Optional[Union[Device, int]] = None) -> int:
return _get_amdsmi_temperature(device)
def power_draw(device: Optional[Union[Device, int]] = None) -> int:
def power_draw(device: "Device" = None) -> int:
r"""Return the average power draw of the GPU sensor in mW (MilliWatts)
over the past sample period as given by `nvidia-smi` for Fermi or newer fully supported devices.
@ -1377,7 +1376,7 @@ def power_draw(device: Optional[Union[Device, int]] = None) -> int:
return _get_amdsmi_power_draw(device)
def clock_rate(device: Optional[Union[Device, int]] = None) -> int:
def clock_rate(device: "Device" = None) -> int:
r"""Return the clock speed of the GPU SM in MHz (megahertz) over the past sample period as given by `nvidia-smi`.
Args: