pytorch/torch/mtia/memory.py
Hyunho Yeo 001ebbf734 [MTIA] (4/n) Implement PyTorch APIs to query/reset device peak memory usage (#146751)
Summary: Public summary (shared with Github): This diff updates the unit test for the PyTorch API "reset_peak_memory_stats".

Test Plan:
```
buck2 test //mtia/host_runtime/torch_mtia/tests:test_torch_mtia_api -- -r test_reset_peak_memory_stats
```

https://www.internalfb.com/intern/testinfra/testrun/9007199321947161

Reviewed By: yuhc

Differential Revision: D68989900

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146751
Approved by: https://github.com/nautsimon
2025-02-11 03:51:48 +00:00

58 lines
1.7 KiB
Python

# pyre-strict
r"""This package adds support for device memory management implemented in MTIA."""
from typing import Any, Optional
import torch
from . import _device_t, is_initialized
from ._utils import _get_device_index
def memory_stats(device: Optional[_device_t] = None) -> dict[str, Any]:
    r"""Return a dictionary of MTIA memory allocator statistics for a given device.

    Args:
        device (torch.device, str, or int, optional): selected device. Returns
            statistics for the current device, given by current_device(),
            if device is None (default).
    """
    # An uninitialized MTIA runtime has no allocator, hence no stats.
    if not is_initialized():
        return {}
    device_index = _get_device_index(device, optional=True)
    return torch._C._mtia_memoryStats(device_index)
def max_memory_allocated(device: Optional[_device_t] = None) -> int:
    r"""Return the maximum memory allocated in bytes for a given device.

    Args:
        device (torch.device, str, or int, optional): selected device. Returns
            statistics for the current device, given by current_device(),
            if device is None (default).
    """
    if not is_initialized():
        return 0
    # Fall back to an empty dict for a missing "dram" entry: the previous
    # default of 0 made the chained .get() call raise AttributeError
    # (int has no .get) instead of returning 0 as intended.
    return memory_stats(device).get("dram", {}).get("peak_bytes", 0)
def reset_peak_memory_stats(device: Optional[_device_t] = None) -> None:
    r"""Reset the peak memory stats for a given device.

    Args:
        device (torch.device, str, or int, optional): selected device. Returns
            statistics for the current device, given by current_device(),
            if device is None (default).
    """
    # Nothing to reset if the MTIA runtime was never brought up.
    if not is_initialized():
        return
    device_index = _get_device_index(device, optional=True)
    torch._C._mtia_resetPeakMemoryStats(device_index)
# Public API of this module (controls `from torch.mtia.memory import *`).
__all__ = [
    "memory_stats",
    "max_memory_allocated",
    "reset_peak_memory_stats",
]