mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Take CUDA_VISIBLE_DEVICES into account for nvml calls (#94568)
Fixes #94472
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94568
Approved by: https://github.com/ngimel
This commit is contained in:
parent
ea657726d9
commit
dc4f2af6f6
|
|
@ -642,6 +642,20 @@ def _device_count_nvml() -> int:
|
||||||
return -1
|
return -1
|
||||||
return len(visible_devices)
|
return len(visible_devices)
|
||||||
|
|
||||||
|
def _get_nvml_device_index(device: Optional[Union[int, Device]]) -> int:
    r"""Returns the NVML index of the device, taking CUDA_VISIBLE_DEVICES into account.

    Args:
        device: device selector, forwarded to ``_get_device_index`` with
            ``optional=True`` to obtain an integer index.

    Returns:
        The entry of the visible-device list found at the resolved index.

    Raises:
        RuntimeError: if the visible devices are given as strings and
            ``_raw_device_uuid_nvml`` returns ``None``, or if the resolved
            index does not address an entry of the visible-device list.
    """
    idx = _get_device_index(device, optional=True)
    visible_devices = _parse_visible_devices()
    # Check for emptiness before peeking at the first entry: an empty list
    # must end in the "not visible" RuntimeError below, not a bare IndexError.
    if visible_devices and isinstance(visible_devices[0], str):
        # String entries are treated as UUIDs and translated to ordinals.
        # NOTE(review): assumes the list is homogeneous (all str or all int),
        # as the original first-element check did -- confirm against
        # _parse_visible_devices.
        uuids = _raw_device_uuid_nvml()
        if uuids is None:
            raise RuntimeError("Can't get device UUIDs")
        visible_devices = _transform_uuid_to_ordinals(cast(List[str], visible_devices), uuids)
    ordinals = cast(List[int], visible_devices)
    # The list position is the index being looked up, so a plain bounds check
    # replaces the intermediate {position: ordinal} dict. Negative indices
    # are rejected, exactly as the dict lookup rejected them.
    if not 0 <= idx < len(ordinals):
        raise RuntimeError(f"device {idx} is not visible (CUDA_VISIBLE_DEVICES={visible_devices})")
    return ordinals[idx]
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
@lru_cache(maxsize=1)
|
||||||
def device_count() -> int:
|
def device_count() -> int:
|
||||||
r"""Returns the number of GPUs available."""
|
r"""Returns the number of GPUs available."""
|
||||||
|
|
@ -789,7 +803,7 @@ def memory_usage(device: Optional[Union[Device, int]] = None) -> int:
|
||||||
pynvml.nvmlInit()
|
pynvml.nvmlInit()
|
||||||
except NVMLError_DriverNotLoaded as e:
|
except NVMLError_DriverNotLoaded as e:
|
||||||
raise RuntimeError("cuda driver can't be loaded, is cuda enabled?") from e
|
raise RuntimeError("cuda driver can't be loaded, is cuda enabled?") from e
|
||||||
device = _get_device_index(device, optional=True)
|
device = _get_nvml_device_index(device)
|
||||||
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
||||||
return pynvml.nvmlDeviceGetUtilizationRates(handle).memory
|
return pynvml.nvmlDeviceGetUtilizationRates(handle).memory
|
||||||
|
|
||||||
|
|
@ -815,7 +829,7 @@ def utilization(device: Optional[Union[Device, int]] = None) -> int:
|
||||||
pynvml.nvmlInit()
|
pynvml.nvmlInit()
|
||||||
except NVMLError_DriverNotLoaded as e:
|
except NVMLError_DriverNotLoaded as e:
|
||||||
raise RuntimeError("cuda driver can't be loaded, is cuda enabled?") from e
|
raise RuntimeError("cuda driver can't be loaded, is cuda enabled?") from e
|
||||||
device = _get_device_index(device, optional=True)
|
device = _get_nvml_device_index(device)
|
||||||
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
||||||
return pynvml.nvmlDeviceGetUtilizationRates(handle).gpu
|
return pynvml.nvmlDeviceGetUtilizationRates(handle).gpu
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import warnings
|
||||||
from typing import Any, Dict, Union, Tuple
|
from typing import Any, Dict, Union, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from . import is_initialized, _get_device_index, _lazy_init
|
from . import is_initialized, _get_device_index, _lazy_init, _get_nvml_device_index
|
||||||
from ._utils import _dummy_type
|
from ._utils import _dummy_type
|
||||||
|
|
||||||
from ._memory_viz import segments as _segments, memory as _memory
|
from ._memory_viz import segments as _segments, memory as _memory
|
||||||
|
|
@ -587,7 +587,7 @@ def list_gpu_processes(device: Union[Device, int] = None) -> str:
|
||||||
pynvml.nvmlInit()
|
pynvml.nvmlInit()
|
||||||
except NVMLError_DriverNotLoaded:
|
except NVMLError_DriverNotLoaded:
|
||||||
return ("cuda driver can't be loaded, is cuda enabled?")
|
return ("cuda driver can't be loaded, is cuda enabled?")
|
||||||
device = _get_device_index(device, optional=True)
|
device = _get_nvml_device_index(device)
|
||||||
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
||||||
procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
|
procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
|
||||||
lines = []
|
lines = []
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user