Introduce a new API torch.xpu.get_per_process_memory_fraction (#165511)

# Motivation
Aligned with other backends, this PR introduces a new API `torch.xpu.get_per_process_memory_fraction` that allows users to retrieve the memory fraction allowed for a single process.
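
For illustration, a minimal usage sketch of the new getter alongside the existing `torch.xpu.set_per_process_memory_fraction` (assumes an XPU-enabled build with at least one device; not part of this PR):

```python
import torch

if torch.xpu.is_available():
    # Before any fraction is set, the allocator reports the default of 1.0.
    print(torch.xpu.get_per_process_memory_fraction(0))

    # Cap this process's caching allocator at 50% of device 0's total memory,
    # then read the value back.
    torch.xpu.set_per_process_memory_fraction(0.5, 0)
    print(torch.xpu.get_per_process_memory_fraction(0))  # ~0.5
```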

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165511
Approved by: https://github.com/EikanWang, https://github.com/ezyang
ghstack dependencies: #165508, #165509, #165510
Yu, Guangye 2025-10-15 23:38:06 +00:00 committed by PyTorch MergeBot
parent 8221ee6db9
commit 0ec0549823
8 changed files with 54 additions and 1 deletion


@@ -554,6 +554,17 @@ class DeviceCachingAllocator {
    }
  }

  double getMemoryFraction() {
    if (!set_fraction) {
      return 1.0;
    }
    c10::xpu::DeviceProp device_prop;
    c10::xpu::get_device_properties(&device_prop, device_index);
    return static_cast<double>(allowed_memory_maximum) /
        static_cast<double>(device_prop.global_mem_size);
  }

  void setMemoryFraction(double fraction) {
    c10::xpu::DeviceProp device_prop;
    c10::xpu::get_device_properties(&device_prop, device_index);

@@ -724,6 +735,11 @@ class XPUAllocator : public DeviceAllocator {
    device_allocators[device]->resetAccumulatedStats();
  }

  double getMemoryFraction(DeviceIndex device) {
    assertValidDevice(device);
    return device_allocators[device]->getMemoryFraction();
  }

  void setMemoryFraction(double fraction, DeviceIndex device) {
    assertValidDevice(device);
    TORCH_CHECK_VALUE(

@@ -777,6 +793,10 @@ void recordStream(const DataPtr& dataPtr, XPUStream stream) {
  return allocator.recordStream(dataPtr, stream);
}

double getMemoryFraction(DeviceIndex device) {
  return allocator.getMemoryFraction(device);
}

void setMemoryFraction(double fraction, DeviceIndex device) {
  return allocator.setMemoryFraction(fraction, device);
}
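
The allocator-side getter above returns `allowed_memory_maximum / global_mem_size`, falling back to 1.0 when no fraction has ever been set. A short Python sketch of the same relationship using only the public API (the variable names are illustrative, not from this PR):

```python
import torch

device = 0
total = torch.xpu.get_device_properties(device).total_memory
torch.xpu.set_per_process_memory_fraction(0.6, device)

# The fraction read back corresponds to the byte cap the allocator enforces:
# roughly total_memory * fraction bytes may be cached by this process.
fraction = torch.xpu.get_per_process_memory_fraction(device)
budget_bytes = int(total * fraction)
print(f"allocator cap: {budget_bytes} of {total} bytes ({fraction:.0%})")
```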


@@ -25,6 +25,8 @@ C10_XPU_API void raw_delete(void* ptr);
C10_XPU_API void recordStream(const DataPtr& dataPtr, XPUStream stream);
C10_XPU_API double getMemoryFraction(DeviceIndex device);
C10_XPU_API void setMemoryFraction(double fraction, DeviceIndex device);
} // namespace c10::xpu::XPUCachingAllocator


@@ -76,6 +76,7 @@
    :nosignatures:

    empty_cache
    get_per_process_memory_fraction
    max_memory_allocated
    max_memory_reserved
    mem_get_info


@@ -489,6 +489,7 @@ if __name__ == "__main__":
        torch.xpu.empty_cache()
        total_memory = torch.xpu.get_device_properties().total_memory
        fraction = 0.5
        orig_fraction = torch.xpu.get_per_process_memory_fraction()
        with self.assertRaisesRegex(ValueError, "invalid fraction:"):
            torch.xpu.set_per_process_memory_fraction(-0.1)
        with self.assertRaisesRegex(ValueError, "invalid fraction:"):

@@ -503,11 +504,13 @@ if __name__ == "__main__":
        gc.collect()
        torch.xpu.empty_cache()
        self.assertEqual(fraction, torch.xpu.get_per_process_memory_fraction())
        application_memory = int(total_memory * 0.51)
        with self.assertRaises(torch.OutOfMemoryError):
            _ = torch.empty(application_memory, dtype=torch.int8, device="xpu")
        torch.xpu.set_per_process_memory_fraction(1.0)
        torch.xpu.set_per_process_memory_fraction(orig_fraction)

    def test_memory_allocation(self):
        torch.xpu.empty_cache()
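
The test above captures the current fraction up front and restores it at the end. Outside of the test suite, the same save-and-restore pattern can be wrapped in try/finally so the cap is reverted even if an allocation fails (a sketch, not part of this PR; `run_with_memory_cap` is an illustrative helper name):

```python
import torch

def run_with_memory_cap(fn, fraction=0.5, device=0):
    # Remember the current cap so it can be restored even if fn() raises,
    # e.g. torch.OutOfMemoryError once the lowered budget is exhausted.
    orig = torch.xpu.get_per_process_memory_fraction(device)
    torch.xpu.set_per_process_memory_fraction(fraction, device)
    try:
        return fn()
    finally:
        torch.xpu.set_per_process_memory_fraction(orig, device)
```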


@@ -2398,6 +2398,7 @@ def _xpu_resetAccumulatedMemoryStats(device: _int) -> None: ...
def _xpu_resetPeakMemoryStats(device: _int) -> None: ...
def _xpu_getMemoryInfo(device: _int) -> tuple[_int, _int]: ...
def _xpu_canDeviceAccessPeer(device: _int, peer: _int) -> _bool: ...
def _xpu_getMemoryFraction(device: _int) -> _float: ...
def _xpu_setMemoryFraction(fraction: _float, device: _int) -> None: ...
class _XpuDeviceProperties:


@@ -420,6 +420,9 @@ static void initXpuMethodBindings(PyObject* module) {
      [](c10::DeviceIndex device, c10::DeviceIndex peer) {
        return at::xpu::canDeviceAccessPeer(device, peer);
      });
  m.def("_xpu_getMemoryFraction", [](c10::DeviceIndex device) {
    return c10::xpu::XPUCachingAllocator::getMemoryFraction(device);
  });
  m.def("_xpu_setMemoryFraction", [](double fraction, c10::DeviceIndex device) {
    c10::xpu::XPUCachingAllocator::setMemoryFraction(fraction, device);
  });


@@ -521,6 +521,7 @@ def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int:
# import here to avoid circular import
from .memory import (
    empty_cache,
    get_per_process_memory_fraction,
    max_memory_allocated,
    max_memory_reserved,
    mem_get_info,

@@ -562,6 +563,7 @@ __all__ = [
    "get_device_name",
    "get_device_properties",
    "get_gencode_flags",
    "get_per_process_memory_fraction",
    "get_rng_state",
    "get_rng_state_all",
    "get_stream_from_external",


@@ -194,6 +194,26 @@ def mem_get_info(device: _device_t = None) -> tuple[int, int]:
    return torch._C._xpu_getMemoryInfo(device)


def get_per_process_memory_fraction(device: _device_t = None) -> float:
    r"""
    Retrieve the memory fraction currently set for a process on a given XPU device.

    This fraction represents the portion of the total device memory that
    the caching allocator is allowed to use. The allowed memory is calculated as:

    .. math:: \text{allowed\_memory} = \text{total\_memory} \times \text{fraction}

    Args:
        device (torch.device or int or str, optional): selected device. It uses the current device,
            given by :func:`~torch.xpu.current_device`, if :attr:`device` is ``None`` (default).

    Returns:
        float: The memory fraction in the range 0.0 to 1.0.
    """
    _lazy_init()
    device = _get_device_index(device, optional=True)
    return torch._C._xpu_getMemoryFraction(device)


def set_per_process_memory_fraction(fraction: float, device: _device_t = None) -> None:
    r"""
    Set the memory fraction for a single process on XPU device.

@@ -222,6 +242,7 @@ def set_per_process_memory_fraction(fraction: float, device: _device_t = None) -> None:
__all__ = [
    "empty_cache",
    "get_per_process_memory_fraction",
    "max_memory_allocated",
    "max_memory_reserved",
    "mem_get_info",