From 861945100ec5480a304c09843e4c7a0826941c1d Mon Sep 17 00:00:00 2001
From: Zizeng Meng
Date: Sun, 27 Apr 2025 15:56:41 +0000
Subject: [PATCH] [Kineto] Enable OOM observer (#152160)

Summary:

# Context:
When a memory leak happens, it usually triggers an OOM in a later iteration. A snapshot of the full iteration would be huge and hard to interpret. On the CUDA side, PyTorch provides an OOM observer that, when an OOM happens, generates a snapshot containing the latest 1,500,000 entries for debugging. This diff implements the same feature on the MTIA side.

Test Plan:
Run this test with the last diff in the stack.
```
buck run @//mode/opt kineto/libkineto/fb/mtia/integration_tests:mtia_memory_auto_trace_test
```
As shown, the memory snapshot is generated when an OOM happens.

Log: P1794792326

Snapshot: https://fburl.com/pytorch_memory_visualizer/lx73y6s3

{F1977402355}

Differential Revision: D71993315

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152160
Approved by: https://github.com/sraikund16
---
 aten/src/ATen/detail/MTIAHooksInterface.h | 4 ++++
 docs/source/mtia.rst                      | 1 +
 torch/_C/__init__.pyi.in                  | 3 +++
 torch/csrc/mtia/Module.cpp                | 5 +++++
 torch/mtia/__init__.py                    | 8 ++++++++
 5 files changed, 21 insertions(+)

diff --git a/aten/src/ATen/detail/MTIAHooksInterface.h b/aten/src/ATen/detail/MTIAHooksInterface.h
index ba2b537dba2..218376db225 100644
--- a/aten/src/ATen/detail/MTIAHooksInterface.h
+++ b/aten/src/ATen/detail/MTIAHooksInterface.h
@@ -141,6 +141,10 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
     FAIL_MTIAHOOKS_FUNC(__func__);
   }
 
+  virtual void attachOutOfMemoryObserver(PyObject* observer) const {
+    FAIL_MTIAHOOKS_FUNC(__func__);
+    return;
+  }
 };
 
 struct TORCH_API MTIAHooksArgs {};
diff --git a/docs/source/mtia.rst b/docs/source/mtia.rst
index 7f625ebfee2..7572e6bf56b 100644
--- a/docs/source/mtia.rst
+++ b/docs/source/mtia.rst
@@ -23,6 +23,7 @@ The MTIA backend is implemented out of the tree, only interfaces are be defined
     empty_cache
     record_memory_history
     snapshot
+    attach_out_of_memory_observer
     set_device
     set_stream
     stream
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 3bbd7f628ba..149df10195d 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1868,6 +1868,9 @@ def _mtia_recordMemoryHistory(
     max_entries
 ) -> None: ...
 def _mtia_memorySnapshot() -> Dict[str, Any]: ...
+def _mtia_attachOutOfMemoryObserver(
+    observer: Callable[[_int, _int, _int, _int], None]
+) -> None: ...
 def _mtia_getDeviceCount() -> _int: ...
 def _mtia_resetPeakMemoryStats(device: _int) -> None: ...
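A note on the stub above: it only pins down the observer's arity (`Callable[[_int, _int, _int, _int], None]`), not what the four integers mean. By analogy with the CUDA caching allocator's `OutOfMemoryObserver`, they are presumably the device index, the requested allocation size, the total device memory, and the free device memory, but that mapping is an assumption rather than something this patch states. A minimal sketch under that assumption, using the public wrapper added later in this patch:

```python
import torch


def log_oom(device: int, alloc: int, device_total: int, device_free: int) -> None:
    # Assumed meanings of the four integers (by analogy with CUDA's
    # OutOfMemoryObserver): device index, bytes requested by the failing
    # allocation, total device memory, and currently free device memory.
    print(
        f"MTIA OOM on device {device}: failed to allocate {alloc} bytes "
        f"({device_free} bytes free of {device_total})"
    )


torch.mtia.attach_out_of_memory_observer(log_oom)
```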
diff --git a/torch/csrc/mtia/Module.cpp b/torch/csrc/mtia/Module.cpp
index ec6229967e0..1ea6c6396f1 100644
--- a/torch/csrc/mtia/Module.cpp
+++ b/torch/csrc/mtia/Module.cpp
@@ -100,6 +100,11 @@ void initModule(PyObject* module) {
     return py::reinterpret_steal<py::object>(raw_pyobject);
   });
 
+  m.def("_mtia_attachOutOfMemoryObserver", [](const py::function& observer) {
+    at::detail::getMTIAHooks().attachOutOfMemoryObserver(observer.ptr());
+    return;
+  });
+
   m.def("_mtia_getDeviceCount", []() {
     return at::detail::getMTIAHooks().deviceCount();
   });
diff --git a/torch/mtia/__init__.py b/torch/mtia/__init__.py
index 759093257a8..da66159640e 100644
--- a/torch/mtia/__init__.py
+++ b/torch/mtia/__init__.py
@@ -197,6 +197,13 @@ def snapshot() -> dict[str, Any]:
     return torch._C._mtia_memorySnapshot()
 
 
+def attach_out_of_memory_observer(
+    observer: Callable[[int, int, int, int], None]
+) -> None:
+    r"""Attach an out-of-memory observer to the MTIA memory allocator."""
+    torch._C._mtia_attachOutOfMemoryObserver(observer)
+
+
 def get_device_capability(device: Optional[_device_t] = None) -> tuple[int, int]:
     r"""Return capability of a given device as a tuple of (major version, minor version).
 
@@ -378,6 +385,7 @@ __all__ = [
     "get_device_capability",
     "record_memory_history",
     "snapshot",
+    "attach_out_of_memory_observer",
     "empty_cache",
     "set_device",
     "set_stream",
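For context, here is a sketch of how the new API is typically wired up, mirroring the usual CUDA OOM-observer workflow: start recording allocator history so the snapshot has entries, attach an observer, and dump a snapshot from inside the observer when an OOM fires. The `record_memory_history()` call with default arguments and the output file name are assumptions for illustration; the patch itself only adds `attach_out_of_memory_observer`.

```python
import pickle

import torch


def dump_snapshot_on_oom(
    device: int, alloc: int, device_total: int, device_free: int
) -> None:
    # Capture the allocator state at the moment of the failed allocation and
    # persist it so it can be loaded into the PyTorch memory visualizer.
    snapshot = torch.mtia.snapshot()
    with open(f"mtia_oom_device{device}.pickle", "wb") as f:
        pickle.dump(snapshot, f)


# Assumes record_memory_history() with its defaults enables history recording,
# as the CUDA equivalent does; without it the snapshot would have no entries.
torch.mtia.record_memory_history()
torch.mtia.attach_out_of_memory_observer(dump_snapshot_on_oom)
```

A snapshot produced this way can then be inspected with the same memory visualizer linked in the summary above.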