diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index 2730fb4e4a1..bce66bc4921 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -642,6 +642,20 @@ def _device_count_nvml() -> int:
         return -1
     return len(visible_devices)
 
+def _get_nvml_device_index(device: Optional[Union[int, Device]]) -> int:
+    r"""Returns the NVML index of the device, taking CUDA_VISIBLE_DEVICES into account."""
+    idx = _get_device_index(device, optional=True)
+    visible_devices = _parse_visible_devices()
+    if type(visible_devices[0]) is str:
+        uuids = _raw_device_uuid_nvml()
+        if uuids is None:
+            raise RuntimeError("Can't get device UUIDs")
+        visible_devices = _transform_uuid_to_ordinals(cast(List[str], visible_devices), uuids)
+    idx_map = {idx: real_idx for idx, real_idx in enumerate(cast(List[int], visible_devices))}
+    if idx not in idx_map:
+        raise RuntimeError(f"device {idx} is not visible (CUDA_VISIBLE_DEVICES={visible_devices})")
+    return idx_map[idx]
+
 @lru_cache(maxsize=1)
 def device_count() -> int:
     r"""Returns the number of GPUs available."""
@@ -789,7 +803,7 @@ def memory_usage(device: Optional[Union[Device, int]] = None) -> int:
         pynvml.nvmlInit()
     except NVMLError_DriverNotLoaded as e:
         raise RuntimeError("cuda driver can't be loaded, is cuda enabled?") from e
-    device = _get_device_index(device, optional=True)
+    device = _get_nvml_device_index(device)
     handle = pynvml.nvmlDeviceGetHandleByIndex(device)
     return pynvml.nvmlDeviceGetUtilizationRates(handle).memory
 
@@ -815,7 +829,7 @@ def utilization(device: Optional[Union[Device, int]] = None) -> int:
         pynvml.nvmlInit()
     except NVMLError_DriverNotLoaded as e:
         raise RuntimeError("cuda driver can't be loaded, is cuda enabled?") from e
-    device = _get_device_index(device, optional=True)
+    device = _get_nvml_device_index(device)
     handle = pynvml.nvmlDeviceGetHandleByIndex(device)
     return pynvml.nvmlDeviceGetUtilizationRates(handle).gpu
 
diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py
index 0a19604e07e..6e63ab2bf4d 100644
--- a/torch/cuda/memory.py
+++ b/torch/cuda/memory.py
@@ -5,7 +5,7 @@ import warnings
 from typing import Any, Dict, Union, Tuple
 
 import torch
-from . import is_initialized, _get_device_index, _lazy_init
+from . import is_initialized, _get_device_index, _lazy_init, _get_nvml_device_index
 from ._utils import _dummy_type
 from ._memory_viz import segments as _segments, memory as _memory
 
@@ -587,7 +587,7 @@ def list_gpu_processes(device: Union[Device, int] = None) -> str:
         pynvml.nvmlInit()
     except NVMLError_DriverNotLoaded:
         return ("cuda driver can't be loaded, is cuda enabled?")
-    device = _get_device_index(device, optional=True)
+    device = _get_nvml_device_index(device)
     handle = pynvml.nvmlDeviceGetHandleByIndex(device)
     procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
     lines = []