mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Typing] Refactor torch.types.Device in torch/cuda/__init__.py (#153447)
Part of: #152952
Follow-up to: #153027
Here is the definition of `torch.types.Device`:
https://github.com/pytorch/pytorch/blob/ab997d9ff5/torch/types.py#L74
That alias already accepts `int` and `None`, so `Optional[Union[Device, int]]` is equivalent to plain `torch.types.Device`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153447
Approved by: https://github.com/cyyever, https://github.com/Skylion007
This commit is contained in:
parent
fdc339003b
commit
f58143b945
|
|
@ -18,13 +18,12 @@ import threading
|
||||||
import traceback
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Any, Callable, cast, Optional, Union
|
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch._C
|
import torch._C
|
||||||
from torch import device as _device
|
from torch import device as _device
|
||||||
from torch._utils import _dummy_type, _LazySeedTracker, classproperty
|
from torch._utils import _dummy_type, _LazySeedTracker, classproperty
|
||||||
from torch.types import Device
|
|
||||||
|
|
||||||
from . import gds
|
from . import gds
|
||||||
from ._utils import _get_device_index
|
from ._utils import _get_device_index
|
||||||
|
|
@ -38,6 +37,9 @@ from .graphs import (
|
||||||
from .streams import Event, ExternalStream, Stream
|
from .streams import Event, ExternalStream, Stream
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from torch.types import Device
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torch._C import _cudart # type: ignore[attr-defined]
|
from torch._C import _cudart # type: ignore[attr-defined]
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|
@ -50,7 +52,6 @@ _queued_calls: list[
|
||||||
tuple[Callable[[], None], list[str]]
|
tuple[Callable[[], None], list[str]]
|
||||||
] = [] # don't invoke these until initialization occurs
|
] = [] # don't invoke these until initialization occurs
|
||||||
_is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False)
|
_is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False)
|
||||||
_device_t = Union[_device, str, int, None]
|
|
||||||
|
|
||||||
_HAS_PYNVML = False
|
_HAS_PYNVML = False
|
||||||
_PYNVML_ERR = None
|
_PYNVML_ERR = None
|
||||||
|
|
@ -208,7 +209,7 @@ def is_bf16_supported(including_emulation: bool = True):
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=16)
|
@lru_cache(maxsize=16)
|
||||||
def _check_bf16_tensor_supported(device: _device_t):
|
def _check_bf16_tensor_supported(device: "Device"):
|
||||||
try:
|
try:
|
||||||
torch.tensor([1.0], dtype=torch.bfloat16, device=device)
|
torch.tensor([1.0], dtype=torch.bfloat16, device=device)
|
||||||
return True
|
return True
|
||||||
|
|
@ -524,7 +525,7 @@ class device_of(device):
|
||||||
super().__init__(idx)
|
super().__init__(idx)
|
||||||
|
|
||||||
|
|
||||||
def set_device(device: _device_t) -> None:
|
def set_device(device: "Device") -> None:
|
||||||
r"""Set the current device.
|
r"""Set the current device.
|
||||||
|
|
||||||
Usage of this function is discouraged in favor of :any:`device`. In most
|
Usage of this function is discouraged in favor of :any:`device`. In most
|
||||||
|
|
@ -539,7 +540,7 @@ def set_device(device: _device_t) -> None:
|
||||||
torch._C._cuda_setDevice(device)
|
torch._C._cuda_setDevice(device)
|
||||||
|
|
||||||
|
|
||||||
def get_device_name(device: Optional[_device_t] = None) -> str:
|
def get_device_name(device: "Device" = None) -> str:
|
||||||
r"""Get the name of a device.
|
r"""Get the name of a device.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -554,7 +555,7 @@ def get_device_name(device: Optional[_device_t] = None) -> str:
|
||||||
return get_device_properties(device).name
|
return get_device_properties(device).name
|
||||||
|
|
||||||
|
|
||||||
def get_device_capability(device: Optional[_device_t] = None) -> tuple[int, int]:
|
def get_device_capability(device: "Device" = None) -> tuple[int, int]:
|
||||||
r"""Get the cuda capability of a device.
|
r"""Get the cuda capability of a device.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -571,7 +572,7 @@ def get_device_capability(device: Optional[_device_t] = None) -> tuple[int, int]
|
||||||
return prop.major, prop.minor
|
return prop.major, prop.minor
|
||||||
|
|
||||||
|
|
||||||
def get_device_properties(device: Optional[_device_t] = None) -> _CudaDeviceProperties:
|
def get_device_properties(device: "Device" = None) -> _CudaDeviceProperties:
|
||||||
r"""Get the properties of a device.
|
r"""Get the properties of a device.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -590,7 +591,7 @@ def get_device_properties(device: Optional[_device_t] = None) -> _CudaDeviceProp
|
||||||
return _get_device_properties(device) # type: ignore[name-defined]
|
return _get_device_properties(device) # type: ignore[name-defined]
|
||||||
|
|
||||||
|
|
||||||
def can_device_access_peer(device: _device_t, peer_device: _device_t) -> bool:
|
def can_device_access_peer(device: "Device", peer_device: "Device") -> bool:
|
||||||
r"""Check if peer access between two devices is possible."""
|
r"""Check if peer access between two devices is possible."""
|
||||||
_lazy_init()
|
_lazy_init()
|
||||||
device = _get_device_index(device, optional=True)
|
device = _get_device_index(device, optional=True)
|
||||||
|
|
@ -967,7 +968,7 @@ def _device_count_nvml() -> int:
|
||||||
return len(visible_devices)
|
return len(visible_devices)
|
||||||
|
|
||||||
|
|
||||||
def _get_nvml_device_index(device: Optional[Union[int, Device]]) -> int:
|
def _get_nvml_device_index(device: "Device") -> int:
|
||||||
r"""Return the NVML index of the device, taking CUDA_VISIBLE_DEVICES into account."""
|
r"""Return the NVML index of the device, taking CUDA_VISIBLE_DEVICES into account."""
|
||||||
idx = _get_device_index(device, optional=True)
|
idx = _get_device_index(device, optional=True)
|
||||||
visible_devices = _parse_visible_devices()
|
visible_devices = _parse_visible_devices()
|
||||||
|
|
@ -1042,7 +1043,7 @@ def current_device() -> int:
|
||||||
return torch._C._cuda_getDevice()
|
return torch._C._cuda_getDevice()
|
||||||
|
|
||||||
|
|
||||||
def synchronize(device: Optional[_device_t] = None) -> None:
|
def synchronize(device: "Device" = None) -> None:
|
||||||
r"""Wait for all kernels in all streams on a CUDA device to complete.
|
r"""Wait for all kernels in all streams on a CUDA device to complete.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -1068,7 +1069,7 @@ def ipc_collect():
|
||||||
return torch._C._cuda_ipc_collect()
|
return torch._C._cuda_ipc_collect()
|
||||||
|
|
||||||
|
|
||||||
def current_stream(device: Optional[_device_t] = None) -> Stream:
|
def current_stream(device: "Device" = None) -> Stream:
|
||||||
r"""Return the currently selected :class:`Stream` for a given device.
|
r"""Return the currently selected :class:`Stream` for a given device.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -1086,7 +1087,7 @@ def current_stream(device: Optional[_device_t] = None) -> Stream:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def default_stream(device: Optional[_device_t] = None) -> Stream:
|
def default_stream(device: "Device" = None) -> Stream:
|
||||||
r"""Return the default :class:`Stream` for a given device.
|
r"""Return the default :class:`Stream` for a given device.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -1104,9 +1105,7 @@ def default_stream(device: Optional[_device_t] = None) -> Stream:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_stream_from_external(
|
def get_stream_from_external(data_ptr: int, device: "Device" = None) -> Stream:
|
||||||
data_ptr: int, device: Optional[_device_t] = None
|
|
||||||
) -> Stream:
|
|
||||||
r"""Return a :class:`Stream` from an externally allocated CUDA stream.
|
r"""Return a :class:`Stream` from an externally allocated CUDA stream.
|
||||||
|
|
||||||
This function is used to wrap streams allocated in other libraries in order
|
This function is used to wrap streams allocated in other libraries in order
|
||||||
|
|
@ -1171,7 +1170,7 @@ def get_sync_debug_mode() -> int:
|
||||||
return torch._C._cuda_get_sync_debug_mode()
|
return torch._C._cuda_get_sync_debug_mode()
|
||||||
|
|
||||||
|
|
||||||
def _get_pynvml_handler(device: Optional[Union[Device, int]] = None):
|
def _get_pynvml_handler(device: "Device" = None):
|
||||||
if not _HAS_PYNVML:
|
if not _HAS_PYNVML:
|
||||||
raise ModuleNotFoundError(
|
raise ModuleNotFoundError(
|
||||||
"pynvml does not seem to be installed or it can't be imported."
|
"pynvml does not seem to be installed or it can't be imported."
|
||||||
|
|
@ -1188,7 +1187,7 @@ def _get_pynvml_handler(device: Optional[Union[Device, int]] = None):
|
||||||
return handle
|
return handle
|
||||||
|
|
||||||
|
|
||||||
def _get_amdsmi_handler(device: Optional[Union[Device, int]] = None):
|
def _get_amdsmi_handler(device: "Device" = None):
|
||||||
if not _HAS_PYNVML:
|
if not _HAS_PYNVML:
|
||||||
raise ModuleNotFoundError(
|
raise ModuleNotFoundError(
|
||||||
"amdsmi does not seem to be installed or it can't be imported."
|
"amdsmi does not seem to be installed or it can't be imported."
|
||||||
|
|
@ -1204,7 +1203,7 @@ def _get_amdsmi_handler(device: Optional[Union[Device, int]] = None):
|
||||||
return handle
|
return handle
|
||||||
|
|
||||||
|
|
||||||
def _get_amdsmi_device_index(device: Optional[Union[int, Device]]) -> int:
|
def _get_amdsmi_device_index(device: "Device") -> int:
|
||||||
r"""Return the amdsmi index of the device, taking visible_devices into account."""
|
r"""Return the amdsmi index of the device, taking visible_devices into account."""
|
||||||
idx = _get_device_index(device, optional=True)
|
idx = _get_device_index(device, optional=True)
|
||||||
visible_devices = _parse_visible_devices()
|
visible_devices = _parse_visible_devices()
|
||||||
|
|
@ -1224,7 +1223,7 @@ def _get_amdsmi_device_index(device: Optional[Union[int, Device]]) -> int:
|
||||||
return idx_map[idx]
|
return idx_map[idx]
|
||||||
|
|
||||||
|
|
||||||
def _get_amdsmi_device_memory_used(device: Optional[Union[Device, int]] = None) -> int:
|
def _get_amdsmi_device_memory_used(device: "Device" = None) -> int:
|
||||||
handle = _get_amdsmi_handler(device)
|
handle = _get_amdsmi_handler(device)
|
||||||
# amdsmi_get_gpu_vram_usage returns mem usage in megabytes
|
# amdsmi_get_gpu_vram_usage returns mem usage in megabytes
|
||||||
mem_mega_bytes = amdsmi.amdsmi_get_gpu_vram_usage(handle)["vram_used"]
|
mem_mega_bytes = amdsmi.amdsmi_get_gpu_vram_usage(handle)["vram_used"]
|
||||||
|
|
@ -1232,17 +1231,17 @@ def _get_amdsmi_device_memory_used(device: Optional[Union[Device, int]] = None)
|
||||||
return mem_bytes
|
return mem_bytes
|
||||||
|
|
||||||
|
|
||||||
def _get_amdsmi_memory_usage(device: Optional[Union[Device, int]] = None) -> int:
|
def _get_amdsmi_memory_usage(device: "Device" = None) -> int:
|
||||||
handle = _get_amdsmi_handler(device)
|
handle = _get_amdsmi_handler(device)
|
||||||
return amdsmi.amdsmi_get_gpu_activity(handle)["umc_activity"]
|
return amdsmi.amdsmi_get_gpu_activity(handle)["umc_activity"]
|
||||||
|
|
||||||
|
|
||||||
def _get_amdsmi_utilization(device: Optional[Union[Device, int]] = None) -> int:
|
def _get_amdsmi_utilization(device: "Device" = None) -> int:
|
||||||
handle = _get_amdsmi_handler(device)
|
handle = _get_amdsmi_handler(device)
|
||||||
return amdsmi.amdsmi_get_gpu_activity(handle)["gfx_activity"]
|
return amdsmi.amdsmi_get_gpu_activity(handle)["gfx_activity"]
|
||||||
|
|
||||||
|
|
||||||
def _get_amdsmi_temperature(device: Optional[Union[Device, int]] = None) -> int:
|
def _get_amdsmi_temperature(device: "Device" = None) -> int:
|
||||||
handle = _get_amdsmi_handler(device)
|
handle = _get_amdsmi_handler(device)
|
||||||
return amdsmi.amdsmi_get_temp_metric(
|
return amdsmi.amdsmi_get_temp_metric(
|
||||||
handle,
|
handle,
|
||||||
|
|
@ -1251,7 +1250,7 @@ def _get_amdsmi_temperature(device: Optional[Union[Device, int]] = None) -> int:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_amdsmi_power_draw(device: Optional[Union[Device, int]] = None) -> int:
|
def _get_amdsmi_power_draw(device: "Device" = None) -> int:
|
||||||
handle = _get_amdsmi_handler(device)
|
handle = _get_amdsmi_handler(device)
|
||||||
socket_power = amdsmi.amdsmi_get_power_info(handle)["average_socket_power"]
|
socket_power = amdsmi.amdsmi_get_power_info(handle)["average_socket_power"]
|
||||||
if socket_power != "N/A":
|
if socket_power != "N/A":
|
||||||
|
|
@ -1264,7 +1263,7 @@ def _get_amdsmi_power_draw(device: Optional[Union[Device, int]] = None) -> int:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
def _get_amdsmi_clock_rate(device: Optional[Union[Device, int]] = None) -> int:
|
def _get_amdsmi_clock_rate(device: "Device" = None) -> int:
|
||||||
handle = _get_amdsmi_handler(device)
|
handle = _get_amdsmi_handler(device)
|
||||||
clock_info = amdsmi.amdsmi_get_clock_info(handle, amdsmi.AmdSmiClkType.GFX)
|
clock_info = amdsmi.amdsmi_get_clock_info(handle, amdsmi.AmdSmiClkType.GFX)
|
||||||
if "cur_clk" in clock_info: # ROCm 6.2 deprecation
|
if "cur_clk" in clock_info: # ROCm 6.2 deprecation
|
||||||
|
|
@ -1277,7 +1276,7 @@ def _get_amdsmi_clock_rate(device: Optional[Union[Device, int]] = None) -> int:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
def device_memory_used(device: Optional[Union[Device, int]] = None) -> int:
|
def device_memory_used(device: "Device" = None) -> int:
|
||||||
r"""Return used global (device) memory in bytes as given by `nvidia-smi` or `amd-smi`.
|
r"""Return used global (device) memory in bytes as given by `nvidia-smi` or `amd-smi`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -1295,7 +1294,7 @@ def device_memory_used(device: Optional[Union[Device, int]] = None) -> int:
|
||||||
return _get_amdsmi_device_memory_used(device)
|
return _get_amdsmi_device_memory_used(device)
|
||||||
|
|
||||||
|
|
||||||
def memory_usage(device: Optional[Union[Device, int]] = None) -> int:
|
def memory_usage(device: "Device" = None) -> int:
|
||||||
r"""Return the percent of time over the past sample period during which global (device)
|
r"""Return the percent of time over the past sample period during which global (device)
|
||||||
memory was being read or written as given by `nvidia-smi`.
|
memory was being read or written as given by `nvidia-smi`.
|
||||||
|
|
||||||
|
|
@ -1316,7 +1315,7 @@ def memory_usage(device: Optional[Union[Device, int]] = None) -> int:
|
||||||
return _get_amdsmi_memory_usage(device)
|
return _get_amdsmi_memory_usage(device)
|
||||||
|
|
||||||
|
|
||||||
def utilization(device: Optional[Union[Device, int]] = None) -> int:
|
def utilization(device: "Device" = None) -> int:
|
||||||
r"""Return the percent of time over the past sample period during which one or
|
r"""Return the percent of time over the past sample period during which one or
|
||||||
more kernels was executing on the GPU as given by `nvidia-smi`.
|
more kernels was executing on the GPU as given by `nvidia-smi`.
|
||||||
|
|
||||||
|
|
@ -1337,7 +1336,7 @@ def utilization(device: Optional[Union[Device, int]] = None) -> int:
|
||||||
return _get_amdsmi_utilization(device)
|
return _get_amdsmi_utilization(device)
|
||||||
|
|
||||||
|
|
||||||
def temperature(device: Optional[Union[Device, int]] = None) -> int:
|
def temperature(device: "Device" = None) -> int:
|
||||||
r"""Return the average temperature of the GPU sensor in Degrees C (Centigrades).
|
r"""Return the average temperature of the GPU sensor in Degrees C (Centigrades).
|
||||||
|
|
||||||
The average temperature is computed based on past sample period as given by `nvidia-smi`.
|
The average temperature is computed based on past sample period as given by `nvidia-smi`.
|
||||||
|
|
@ -1358,7 +1357,7 @@ def temperature(device: Optional[Union[Device, int]] = None) -> int:
|
||||||
return _get_amdsmi_temperature(device)
|
return _get_amdsmi_temperature(device)
|
||||||
|
|
||||||
|
|
||||||
def power_draw(device: Optional[Union[Device, int]] = None) -> int:
|
def power_draw(device: "Device" = None) -> int:
|
||||||
r"""Return the average power draw of the GPU sensor in mW (MilliWatts)
|
r"""Return the average power draw of the GPU sensor in mW (MilliWatts)
|
||||||
over the past sample period as given by `nvidia-smi` for Fermi or newer fully supported devices.
|
over the past sample period as given by `nvidia-smi` for Fermi or newer fully supported devices.
|
||||||
|
|
||||||
|
|
@ -1377,7 +1376,7 @@ def power_draw(device: Optional[Union[Device, int]] = None) -> int:
|
||||||
return _get_amdsmi_power_draw(device)
|
return _get_amdsmi_power_draw(device)
|
||||||
|
|
||||||
|
|
||||||
def clock_rate(device: Optional[Union[Device, int]] = None) -> int:
|
def clock_rate(device: "Device" = None) -> int:
|
||||||
r"""Return the clock speed of the GPU SM in MHz (megahertz) over the past sample period as given by `nvidia-smi`.
|
r"""Return the clock speed of the GPU SM in MHz (megahertz) over the past sample period as given by `nvidia-smi`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user