[Typing] Refactor torch.types.Device in torch/cuda/__init__.py (#153447)

Part of: #152952
Follow up: #153027

Here is the definition of `torch.types.Device`:

ab997d9ff5/torch/types.py (L74)

So `Optional[Union[Device, int]]` is equivalent to `torch.types.Device`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153447
Approved by: https://github.com/cyyever, https://github.com/Skylion007
This commit is contained in:
Yuanhao Ji 2025-05-28 10:09:31 +00:00 committed by PyTorch MergeBot
parent fdc339003b
commit f58143b945

View File

@ -18,13 +18,12 @@ import threading
import traceback import traceback
import warnings import warnings
from functools import lru_cache from functools import lru_cache
from typing import Any, Callable, cast, Optional, Union from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union
import torch import torch
import torch._C import torch._C
from torch import device as _device from torch import device as _device
from torch._utils import _dummy_type, _LazySeedTracker, classproperty from torch._utils import _dummy_type, _LazySeedTracker, classproperty
from torch.types import Device
from . import gds from . import gds
from ._utils import _get_device_index from ._utils import _get_device_index
@ -38,6 +37,9 @@ from .graphs import (
from .streams import Event, ExternalStream, Stream from .streams import Event, ExternalStream, Stream
if TYPE_CHECKING:
from torch.types import Device
try: try:
from torch._C import _cudart # type: ignore[attr-defined] from torch._C import _cudart # type: ignore[attr-defined]
except ImportError: except ImportError:
@ -50,7 +52,6 @@ _queued_calls: list[
tuple[Callable[[], None], list[str]] tuple[Callable[[], None], list[str]]
] = [] # don't invoke these until initialization occurs ] = [] # don't invoke these until initialization occurs
_is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False) _is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False)
_device_t = Union[_device, str, int, None]
_HAS_PYNVML = False _HAS_PYNVML = False
_PYNVML_ERR = None _PYNVML_ERR = None
@ -208,7 +209,7 @@ def is_bf16_supported(including_emulation: bool = True):
@lru_cache(maxsize=16) @lru_cache(maxsize=16)
def _check_bf16_tensor_supported(device: _device_t): def _check_bf16_tensor_supported(device: "Device"):
try: try:
torch.tensor([1.0], dtype=torch.bfloat16, device=device) torch.tensor([1.0], dtype=torch.bfloat16, device=device)
return True return True
@ -524,7 +525,7 @@ class device_of(device):
super().__init__(idx) super().__init__(idx)
def set_device(device: _device_t) -> None: def set_device(device: "Device") -> None:
r"""Set the current device. r"""Set the current device.
Usage of this function is discouraged in favor of :any:`device`. In most Usage of this function is discouraged in favor of :any:`device`. In most
@ -539,7 +540,7 @@ def set_device(device: _device_t) -> None:
torch._C._cuda_setDevice(device) torch._C._cuda_setDevice(device)
def get_device_name(device: Optional[_device_t] = None) -> str: def get_device_name(device: "Device" = None) -> str:
r"""Get the name of a device. r"""Get the name of a device.
Args: Args:
@ -554,7 +555,7 @@ def get_device_name(device: Optional[_device_t] = None) -> str:
return get_device_properties(device).name return get_device_properties(device).name
def get_device_capability(device: Optional[_device_t] = None) -> tuple[int, int]: def get_device_capability(device: "Device" = None) -> tuple[int, int]:
r"""Get the cuda capability of a device. r"""Get the cuda capability of a device.
Args: Args:
@ -571,7 +572,7 @@ def get_device_capability(device: Optional[_device_t] = None) -> tuple[int, int]
return prop.major, prop.minor return prop.major, prop.minor
def get_device_properties(device: Optional[_device_t] = None) -> _CudaDeviceProperties: def get_device_properties(device: "Device" = None) -> _CudaDeviceProperties:
r"""Get the properties of a device. r"""Get the properties of a device.
Args: Args:
@ -590,7 +591,7 @@ def get_device_properties(device: Optional[_device_t] = None) -> _CudaDeviceProp
return _get_device_properties(device) # type: ignore[name-defined] return _get_device_properties(device) # type: ignore[name-defined]
def can_device_access_peer(device: _device_t, peer_device: _device_t) -> bool: def can_device_access_peer(device: "Device", peer_device: "Device") -> bool:
r"""Check if peer access between two devices is possible.""" r"""Check if peer access between two devices is possible."""
_lazy_init() _lazy_init()
device = _get_device_index(device, optional=True) device = _get_device_index(device, optional=True)
@ -967,7 +968,7 @@ def _device_count_nvml() -> int:
return len(visible_devices) return len(visible_devices)
def _get_nvml_device_index(device: Optional[Union[int, Device]]) -> int: def _get_nvml_device_index(device: "Device") -> int:
r"""Return the NVML index of the device, taking CUDA_VISIBLE_DEVICES into account.""" r"""Return the NVML index of the device, taking CUDA_VISIBLE_DEVICES into account."""
idx = _get_device_index(device, optional=True) idx = _get_device_index(device, optional=True)
visible_devices = _parse_visible_devices() visible_devices = _parse_visible_devices()
@ -1042,7 +1043,7 @@ def current_device() -> int:
return torch._C._cuda_getDevice() return torch._C._cuda_getDevice()
def synchronize(device: Optional[_device_t] = None) -> None: def synchronize(device: "Device" = None) -> None:
r"""Wait for all kernels in all streams on a CUDA device to complete. r"""Wait for all kernels in all streams on a CUDA device to complete.
Args: Args:
@ -1068,7 +1069,7 @@ def ipc_collect():
return torch._C._cuda_ipc_collect() return torch._C._cuda_ipc_collect()
def current_stream(device: Optional[_device_t] = None) -> Stream: def current_stream(device: "Device" = None) -> Stream:
r"""Return the currently selected :class:`Stream` for a given device. r"""Return the currently selected :class:`Stream` for a given device.
Args: Args:
@ -1086,7 +1087,7 @@ def current_stream(device: Optional[_device_t] = None) -> Stream:
) )
def default_stream(device: Optional[_device_t] = None) -> Stream: def default_stream(device: "Device" = None) -> Stream:
r"""Return the default :class:`Stream` for a given device. r"""Return the default :class:`Stream` for a given device.
Args: Args:
@ -1104,9 +1105,7 @@ def default_stream(device: Optional[_device_t] = None) -> Stream:
) )
def get_stream_from_external( def get_stream_from_external(data_ptr: int, device: "Device" = None) -> Stream:
data_ptr: int, device: Optional[_device_t] = None
) -> Stream:
r"""Return a :class:`Stream` from an externally allocated CUDA stream. r"""Return a :class:`Stream` from an externally allocated CUDA stream.
This function is used to wrap streams allocated in other libraries in order This function is used to wrap streams allocated in other libraries in order
@ -1171,7 +1170,7 @@ def get_sync_debug_mode() -> int:
return torch._C._cuda_get_sync_debug_mode() return torch._C._cuda_get_sync_debug_mode()
def _get_pynvml_handler(device: Optional[Union[Device, int]] = None): def _get_pynvml_handler(device: "Device" = None):
if not _HAS_PYNVML: if not _HAS_PYNVML:
raise ModuleNotFoundError( raise ModuleNotFoundError(
"pynvml does not seem to be installed or it can't be imported." "pynvml does not seem to be installed or it can't be imported."
@ -1188,7 +1187,7 @@ def _get_pynvml_handler(device: Optional[Union[Device, int]] = None):
return handle return handle
def _get_amdsmi_handler(device: Optional[Union[Device, int]] = None): def _get_amdsmi_handler(device: "Device" = None):
if not _HAS_PYNVML: if not _HAS_PYNVML:
raise ModuleNotFoundError( raise ModuleNotFoundError(
"amdsmi does not seem to be installed or it can't be imported." "amdsmi does not seem to be installed or it can't be imported."
@ -1204,7 +1203,7 @@ def _get_amdsmi_handler(device: Optional[Union[Device, int]] = None):
return handle return handle
def _get_amdsmi_device_index(device: Optional[Union[int, Device]]) -> int: def _get_amdsmi_device_index(device: "Device") -> int:
r"""Return the amdsmi index of the device, taking visible_devices into account.""" r"""Return the amdsmi index of the device, taking visible_devices into account."""
idx = _get_device_index(device, optional=True) idx = _get_device_index(device, optional=True)
visible_devices = _parse_visible_devices() visible_devices = _parse_visible_devices()
@ -1224,7 +1223,7 @@ def _get_amdsmi_device_index(device: Optional[Union[int, Device]]) -> int:
return idx_map[idx] return idx_map[idx]
def _get_amdsmi_device_memory_used(device: Optional[Union[Device, int]] = None) -> int: def _get_amdsmi_device_memory_used(device: "Device" = None) -> int:
handle = _get_amdsmi_handler(device) handle = _get_amdsmi_handler(device)
# amdsmi_get_gpu_vram_usage returns mem usage in megabytes # amdsmi_get_gpu_vram_usage returns mem usage in megabytes
mem_mega_bytes = amdsmi.amdsmi_get_gpu_vram_usage(handle)["vram_used"] mem_mega_bytes = amdsmi.amdsmi_get_gpu_vram_usage(handle)["vram_used"]
@ -1232,17 +1231,17 @@ def _get_amdsmi_device_memory_used(device: Optional[Union[Device, int]] = None)
return mem_bytes return mem_bytes
def _get_amdsmi_memory_usage(device: Optional[Union[Device, int]] = None) -> int: def _get_amdsmi_memory_usage(device: "Device" = None) -> int:
handle = _get_amdsmi_handler(device) handle = _get_amdsmi_handler(device)
return amdsmi.amdsmi_get_gpu_activity(handle)["umc_activity"] return amdsmi.amdsmi_get_gpu_activity(handle)["umc_activity"]
def _get_amdsmi_utilization(device: Optional[Union[Device, int]] = None) -> int: def _get_amdsmi_utilization(device: "Device" = None) -> int:
handle = _get_amdsmi_handler(device) handle = _get_amdsmi_handler(device)
return amdsmi.amdsmi_get_gpu_activity(handle)["gfx_activity"] return amdsmi.amdsmi_get_gpu_activity(handle)["gfx_activity"]
def _get_amdsmi_temperature(device: Optional[Union[Device, int]] = None) -> int: def _get_amdsmi_temperature(device: "Device" = None) -> int:
handle = _get_amdsmi_handler(device) handle = _get_amdsmi_handler(device)
return amdsmi.amdsmi_get_temp_metric( return amdsmi.amdsmi_get_temp_metric(
handle, handle,
@ -1251,7 +1250,7 @@ def _get_amdsmi_temperature(device: Optional[Union[Device, int]] = None) -> int:
) )
def _get_amdsmi_power_draw(device: Optional[Union[Device, int]] = None) -> int: def _get_amdsmi_power_draw(device: "Device" = None) -> int:
handle = _get_amdsmi_handler(device) handle = _get_amdsmi_handler(device)
socket_power = amdsmi.amdsmi_get_power_info(handle)["average_socket_power"] socket_power = amdsmi.amdsmi_get_power_info(handle)["average_socket_power"]
if socket_power != "N/A": if socket_power != "N/A":
@ -1264,7 +1263,7 @@ def _get_amdsmi_power_draw(device: Optional[Union[Device, int]] = None) -> int:
return 0 return 0
def _get_amdsmi_clock_rate(device: Optional[Union[Device, int]] = None) -> int: def _get_amdsmi_clock_rate(device: "Device" = None) -> int:
handle = _get_amdsmi_handler(device) handle = _get_amdsmi_handler(device)
clock_info = amdsmi.amdsmi_get_clock_info(handle, amdsmi.AmdSmiClkType.GFX) clock_info = amdsmi.amdsmi_get_clock_info(handle, amdsmi.AmdSmiClkType.GFX)
if "cur_clk" in clock_info: # ROCm 6.2 deprecation if "cur_clk" in clock_info: # ROCm 6.2 deprecation
@ -1277,7 +1276,7 @@ def _get_amdsmi_clock_rate(device: Optional[Union[Device, int]] = None) -> int:
return 0 return 0
def device_memory_used(device: Optional[Union[Device, int]] = None) -> int: def device_memory_used(device: "Device" = None) -> int:
r"""Return used global (device) memory in bytes as given by `nvidia-smi` or `amd-smi`. r"""Return used global (device) memory in bytes as given by `nvidia-smi` or `amd-smi`.
Args: Args:
@ -1295,7 +1294,7 @@ def device_memory_used(device: Optional[Union[Device, int]] = None) -> int:
return _get_amdsmi_device_memory_used(device) return _get_amdsmi_device_memory_used(device)
def memory_usage(device: Optional[Union[Device, int]] = None) -> int: def memory_usage(device: "Device" = None) -> int:
r"""Return the percent of time over the past sample period during which global (device) r"""Return the percent of time over the past sample period during which global (device)
memory was being read or written as given by `nvidia-smi`. memory was being read or written as given by `nvidia-smi`.
@ -1316,7 +1315,7 @@ def memory_usage(device: Optional[Union[Device, int]] = None) -> int:
return _get_amdsmi_memory_usage(device) return _get_amdsmi_memory_usage(device)
def utilization(device: Optional[Union[Device, int]] = None) -> int: def utilization(device: "Device" = None) -> int:
r"""Return the percent of time over the past sample period during which one or r"""Return the percent of time over the past sample period during which one or
more kernels was executing on the GPU as given by `nvidia-smi`. more kernels was executing on the GPU as given by `nvidia-smi`.
@ -1337,7 +1336,7 @@ def utilization(device: Optional[Union[Device, int]] = None) -> int:
return _get_amdsmi_utilization(device) return _get_amdsmi_utilization(device)
def temperature(device: Optional[Union[Device, int]] = None) -> int: def temperature(device: "Device" = None) -> int:
r"""Return the average temperature of the GPU sensor in Degrees C (Centigrades). r"""Return the average temperature of the GPU sensor in Degrees C (Centigrades).
The average temperature is computed based on past sample period as given by `nvidia-smi`. The average temperature is computed based on past sample period as given by `nvidia-smi`.
@ -1358,7 +1357,7 @@ def temperature(device: Optional[Union[Device, int]] = None) -> int:
return _get_amdsmi_temperature(device) return _get_amdsmi_temperature(device)
def power_draw(device: Optional[Union[Device, int]] = None) -> int: def power_draw(device: "Device" = None) -> int:
r"""Return the average power draw of the GPU sensor in mW (MilliWatts) r"""Return the average power draw of the GPU sensor in mW (MilliWatts)
over the past sample period as given by `nvidia-smi` for Fermi or newer fully supported devices. over the past sample period as given by `nvidia-smi` for Fermi or newer fully supported devices.
@ -1377,7 +1376,7 @@ def power_draw(device: Optional[Union[Device, int]] = None) -> int:
return _get_amdsmi_power_draw(device) return _get_amdsmi_power_draw(device)
def clock_rate(device: Optional[Union[Device, int]] = None) -> int: def clock_rate(device: "Device" = None) -> int:
r"""Return the clock speed of the GPU SM in MHz (megahertz) over the past sample period as given by `nvidia-smi`. r"""Return the clock speed of the GPU SM in MHz (megahertz) over the past sample period as given by `nvidia-smi`.
Args: Args: