mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Introduce a new API torch.xpu.get_per_process_memory_fraction (#165511)
# Motivation Aligned with other backends, this PR introduces a new API torch.xpu.get_per_process_memory_fraction to allow user to retrieve the allowed memory fraction per a single process. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165511 Approved by: https://github.com/EikanWang, https://github.com/ezyang ghstack dependencies: #165508, #165509, #165510
This commit is contained in:
parent
8221ee6db9
commit
0ec0549823
|
|
@ -554,6 +554,17 @@ class DeviceCachingAllocator {
|
|||
}
|
||||
}
|
||||
|
||||
double getMemoryFraction() {
|
||||
if (!set_fraction) {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
c10::xpu::DeviceProp device_prop;
|
||||
c10::xpu::get_device_properties(&device_prop, device_index);
|
||||
return static_cast<double>(allowed_memory_maximum) /
|
||||
static_cast<double>(device_prop.global_mem_size);
|
||||
}
|
||||
|
||||
void setMemoryFraction(double fraction) {
|
||||
c10::xpu::DeviceProp device_prop;
|
||||
c10::xpu::get_device_properties(&device_prop, device_index);
|
||||
|
|
@ -724,6 +735,11 @@ class XPUAllocator : public DeviceAllocator {
|
|||
device_allocators[device]->resetAccumulatedStats();
|
||||
}
|
||||
|
||||
double getMemoryFraction(DeviceIndex device) {
|
||||
assertValidDevice(device);
|
||||
return device_allocators[device]->getMemoryFraction();
|
||||
}
|
||||
|
||||
void setMemoryFraction(double fraction, DeviceIndex device) {
|
||||
assertValidDevice(device);
|
||||
TORCH_CHECK_VALUE(
|
||||
|
|
@ -777,6 +793,10 @@ void recordStream(const DataPtr& dataPtr, XPUStream stream) {
|
|||
return allocator.recordStream(dataPtr, stream);
|
||||
}
|
||||
|
||||
double getMemoryFraction(DeviceIndex device) {
|
||||
return allocator.getMemoryFraction(device);
|
||||
}
|
||||
|
||||
void setMemoryFraction(double fraction, DeviceIndex device) {
|
||||
return allocator.setMemoryFraction(fraction, device);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ C10_XPU_API void raw_delete(void* ptr);
|
|||
|
||||
C10_XPU_API void recordStream(const DataPtr& dataPtr, XPUStream stream);
|
||||
|
||||
C10_XPU_API double getMemoryFraction(DeviceIndex device);
|
||||
|
||||
C10_XPU_API void setMemoryFraction(double fraction, DeviceIndex device);
|
||||
|
||||
} // namespace c10::xpu::XPUCachingAllocator
|
||||
|
|
|
|||
|
|
@ -76,6 +76,7 @@
|
|||
:nosignatures:
|
||||
|
||||
empty_cache
|
||||
get_per_process_memory_fraction
|
||||
max_memory_allocated
|
||||
max_memory_reserved
|
||||
mem_get_info
|
||||
|
|
|
|||
|
|
@ -489,6 +489,7 @@ if __name__ == "__main__":
|
|||
torch.xpu.empty_cache()
|
||||
total_memory = torch.xpu.get_device_properties().total_memory
|
||||
fraction = 0.5
|
||||
orig_fraction = torch.xpu.get_per_process_memory_fraction()
|
||||
with self.assertRaisesRegex(ValueError, "invalid fraction:"):
|
||||
torch.xpu.set_per_process_memory_fraction(-0.1)
|
||||
with self.assertRaisesRegex(ValueError, "invalid fraction:"):
|
||||
|
|
@ -503,11 +504,13 @@ if __name__ == "__main__":
|
|||
gc.collect()
|
||||
torch.xpu.empty_cache()
|
||||
|
||||
self.assertEqual(fraction, torch.xpu.get_per_process_memory_fraction())
|
||||
|
||||
application_memory = int(total_memory * 0.51)
|
||||
with self.assertRaises(torch.OutOfMemoryError):
|
||||
_ = torch.empty(application_memory, dtype=torch.int8, device="xpu")
|
||||
|
||||
torch.xpu.set_per_process_memory_fraction(1.0)
|
||||
torch.xpu.set_per_process_memory_fraction(orig_fraction)
|
||||
|
||||
def test_memory_allocation(self):
|
||||
torch.xpu.empty_cache()
|
||||
|
|
|
|||
|
|
@ -2398,6 +2398,7 @@ def _xpu_resetAccumulatedMemoryStats(device: _int) -> None: ...
|
|||
def _xpu_resetPeakMemoryStats(device: _int) -> None: ...
|
||||
def _xpu_getMemoryInfo(device: _int) -> tuple[_int, _int]: ...
|
||||
def _xpu_canDeviceAccessPeer(device: _int, peer: _int) -> _bool: ...
|
||||
def _xpu_getMemoryFraction(device: _int) -> _float: ...
|
||||
def _xpu_setMemoryFraction(fraction: _float, device: _int) -> None: ...
|
||||
|
||||
class _XpuDeviceProperties:
|
||||
|
|
|
|||
|
|
@ -420,6 +420,9 @@ static void initXpuMethodBindings(PyObject* module) {
|
|||
[](c10::DeviceIndex device, c10::DeviceIndex peer) {
|
||||
return at::xpu::canDeviceAccessPeer(device, peer);
|
||||
});
|
||||
m.def("_xpu_getMemoryFraction", [](c10::DeviceIndex device) {
|
||||
return c10::xpu::XPUCachingAllocator::getMemoryFraction(device);
|
||||
});
|
||||
m.def("_xpu_setMemoryFraction", [](double fraction, c10::DeviceIndex device) {
|
||||
c10::xpu::XPUCachingAllocator::setMemoryFraction(fraction, device);
|
||||
});
|
||||
|
|
|
|||
|
|
@ -521,6 +521,7 @@ def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int:
|
|||
# import here to avoid circular import
|
||||
from .memory import (
|
||||
empty_cache,
|
||||
get_per_process_memory_fraction,
|
||||
max_memory_allocated,
|
||||
max_memory_reserved,
|
||||
mem_get_info,
|
||||
|
|
@ -562,6 +563,7 @@ __all__ = [
|
|||
"get_device_name",
|
||||
"get_device_properties",
|
||||
"get_gencode_flags",
|
||||
"get_per_process_memory_fraction",
|
||||
"get_rng_state",
|
||||
"get_rng_state_all",
|
||||
"get_stream_from_external",
|
||||
|
|
|
|||
|
|
@ -194,6 +194,26 @@ def mem_get_info(device: _device_t = None) -> tuple[int, int]:
|
|||
return torch._C._xpu_getMemoryInfo(device)
|
||||
|
||||
|
||||
def get_per_process_memory_fraction(device: _device_t = None) -> float:
|
||||
r"""
|
||||
Retrieve the memory fraction currently set for a process on a given XPU device.
|
||||
This fraction represents the portion of the total device memory that
|
||||
the caching allocator is allowed to use. The allowed memory is calculated as:
|
||||
|
||||
.. math:: \text{allowed\_memory} = \text{total\_memory} \times \text{fraction}
|
||||
|
||||
Args:
|
||||
device (torch.device or int or str, optional): selected device. It uses the current device,
|
||||
given by :func:`~torch.xpu.current_device`, if :attr:`device` is ``None`` (default).
|
||||
|
||||
Returns:
|
||||
float: The memory fraction in the range 0.0 to 1.0.
|
||||
"""
|
||||
_lazy_init()
|
||||
device = _get_device_index(device, optional=True)
|
||||
return torch._C._xpu_getMemoryFraction(device)
|
||||
|
||||
|
||||
def set_per_process_memory_fraction(fraction: float, device: _device_t = None) -> None:
|
||||
r"""
|
||||
Set the memory fraction for a single process on XPU device.
|
||||
|
|
@ -222,6 +242,7 @@ def set_per_process_memory_fraction(fraction: float, device: _device_t = None) -
|
|||
|
||||
__all__ = [
|
||||
"empty_cache",
|
||||
"get_per_process_memory_fraction",
|
||||
"max_memory_allocated",
|
||||
"max_memory_reserved",
|
||||
"mem_get_info",
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user