diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index 244db3c91e0..76be395ef47 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -554,6 +554,17 @@ class DeviceCachingAllocator { } } + double getMemoryFraction() { + if (!set_fraction) { + return 1.0; + } + + c10::xpu::DeviceProp device_prop; + c10::xpu::get_device_properties(&device_prop, device_index); + return static_cast<double>(allowed_memory_maximum) / + static_cast<double>(device_prop.global_mem_size); + } + void setMemoryFraction(double fraction) { c10::xpu::DeviceProp device_prop; c10::xpu::get_device_properties(&device_prop, device_index); @@ -724,6 +735,11 @@ class XPUAllocator : public DeviceAllocator { device_allocators[device]->resetAccumulatedStats(); } + double getMemoryFraction(DeviceIndex device) { + assertValidDevice(device); + return device_allocators[device]->getMemoryFraction(); + } + void setMemoryFraction(double fraction, DeviceIndex device) { assertValidDevice(device); TORCH_CHECK_VALUE( @@ -777,6 +793,10 @@ void recordStream(const DataPtr& dataPtr, XPUStream stream) { return allocator.recordStream(dataPtr, stream); } +double getMemoryFraction(DeviceIndex device) { + return allocator.getMemoryFraction(device); +} + void setMemoryFraction(double fraction, DeviceIndex device) { return allocator.setMemoryFraction(fraction, device); } diff --git a/c10/xpu/XPUCachingAllocator.h b/c10/xpu/XPUCachingAllocator.h index 44ac34fe9a9..b0b0f2ca969 100644 --- a/c10/xpu/XPUCachingAllocator.h +++ b/c10/xpu/XPUCachingAllocator.h @@ -25,6 +25,8 @@ C10_XPU_API void raw_delete(void* ptr); C10_XPU_API void recordStream(const DataPtr& dataPtr, XPUStream stream); +C10_XPU_API double getMemoryFraction(DeviceIndex device); + C10_XPU_API void setMemoryFraction(double fraction, DeviceIndex device); } // namespace c10::xpu::XPUCachingAllocator diff --git a/docs/source/xpu.md b/docs/source/xpu.md index 7a10e29b6af..6cd82aa9841 100644 --- a/docs/source/xpu.md +++ 
b/docs/source/xpu.md @@ -76,6 +76,7 @@ :nosignatures: empty_cache + get_per_process_memory_fraction max_memory_allocated max_memory_reserved mem_get_info diff --git a/test/test_xpu.py b/test/test_xpu.py index 9daa4b55011..61dd91e5bfa 100644 --- a/test/test_xpu.py +++ b/test/test_xpu.py @@ -489,6 +489,7 @@ if __name__ == "__main__": torch.xpu.empty_cache() total_memory = torch.xpu.get_device_properties().total_memory fraction = 0.5 + orig_fraction = torch.xpu.get_per_process_memory_fraction() with self.assertRaisesRegex(ValueError, "invalid fraction:"): torch.xpu.set_per_process_memory_fraction(-0.1) with self.assertRaisesRegex(ValueError, "invalid fraction:"): @@ -503,11 +504,13 @@ if __name__ == "__main__": gc.collect() torch.xpu.empty_cache() + self.assertEqual(fraction, torch.xpu.get_per_process_memory_fraction()) + application_memory = int(total_memory * 0.51) with self.assertRaises(torch.OutOfMemoryError): _ = torch.empty(application_memory, dtype=torch.int8, device="xpu") - torch.xpu.set_per_process_memory_fraction(1.0) + torch.xpu.set_per_process_memory_fraction(orig_fraction) def test_memory_allocation(self): torch.xpu.empty_cache() diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index f2ec106d710..10178c9fbf4 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -2398,6 +2398,7 @@ def _xpu_resetAccumulatedMemoryStats(device: _int) -> None: ... def _xpu_resetPeakMemoryStats(device: _int) -> None: ... def _xpu_getMemoryInfo(device: _int) -> tuple[_int, _int]: ... def _xpu_canDeviceAccessPeer(device: _int, peer: _int) -> _bool: ... +def _xpu_getMemoryFraction(device: _int) -> _float: ... def _xpu_setMemoryFraction(fraction: _float, device: _int) -> None: ... 
class _XpuDeviceProperties: diff --git a/torch/csrc/xpu/Module.cpp b/torch/csrc/xpu/Module.cpp index ff5e82af42f..44d11a5bd97 100644 --- a/torch/csrc/xpu/Module.cpp +++ b/torch/csrc/xpu/Module.cpp @@ -420,6 +420,9 @@ static void initXpuMethodBindings(PyObject* module) { [](c10::DeviceIndex device, c10::DeviceIndex peer) { return at::xpu::canDeviceAccessPeer(device, peer); }); + m.def("_xpu_getMemoryFraction", [](c10::DeviceIndex device) { + return c10::xpu::XPUCachingAllocator::getMemoryFraction(device); + }); m.def("_xpu_setMemoryFraction", [](double fraction, c10::DeviceIndex device) { c10::xpu::XPUCachingAllocator::setMemoryFraction(fraction, device); }); diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py index 5fec24c74de..0d5500d68d0 100644 --- a/torch/xpu/__init__.py +++ b/torch/xpu/__init__.py @@ -521,6 +521,7 @@ def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int: # import here to avoid circular import from .memory import ( empty_cache, + get_per_process_memory_fraction, max_memory_allocated, max_memory_reserved, mem_get_info, @@ -562,6 +563,7 @@ __all__ = [ "get_device_name", "get_device_properties", "get_gencode_flags", + "get_per_process_memory_fraction", "get_rng_state", "get_rng_state_all", "get_stream_from_external", diff --git a/torch/xpu/memory.py b/torch/xpu/memory.py index dd08573b48c..069d93cefa9 100644 --- a/torch/xpu/memory.py +++ b/torch/xpu/memory.py @@ -194,6 +194,26 @@ def mem_get_info(device: _device_t = None) -> tuple[int, int]: return torch._C._xpu_getMemoryInfo(device) +def get_per_process_memory_fraction(device: _device_t = None) -> float: + r""" + Retrieve the memory fraction currently set for a process on a given XPU device. + This fraction represents the portion of the total device memory that + the caching allocator is allowed to use. The allowed memory is calculated as: + + .. 
math:: \text{allowed\_memory} = \text{total\_memory} \times \text{fraction} + + Args: + device (torch.device or int or str, optional): selected device. It uses the current device, + given by :func:`~torch.xpu.current_device`, if :attr:`device` is ``None`` (default). + + Returns: + float: The memory fraction in the range 0.0 to 1.0. + """ + _lazy_init() + device = _get_device_index(device, optional=True) + return torch._C._xpu_getMemoryFraction(device) + + def set_per_process_memory_fraction(fraction: float, device: _device_t = None) -> None: r""" Set the memory fraction for a single process on XPU device. @@ -222,6 +242,7 @@ def set_per_process_memory_fraction(fraction: float, device: _device_t = None) - __all__ = [ "empty_cache", + "get_per_process_memory_fraction", "max_memory_allocated", "max_memory_reserved", "mem_get_info",