[MTIA] Support torch.mtia.empty_cache() (#141533)

Summary: As title

Test Plan:
Passed a local unit test: `buck2 test //mtia/host_runtime/torch_mtia/tests:test_torch_mtia_api`

https://www.internalfb.com/intern/testinfra/testrun/4785074861101240

Reviewed By: nautsimon

Differential Revision: D66481778

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141533
Approved by: https://github.com/nautsimon
This commit is contained in:
Hyunho Yeo 2024-11-28 02:24:19 +00:00 committed by PyTorch MergeBot
parent f35bb55256
commit d70b7029c8
5 changed files with 15 additions and 1 deletions

View File

@ -109,6 +109,11 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
FAIL_MTIAHOOKS_FUNC(__func__);
return nullptr;
}
// Releases all cached device memory held by the MTIA caching allocator.
// This base implementation always fails via FAIL_MTIAHOOKS_FUNC — the MTIA
// backend is implemented out of tree, so a real hooks object must override it.
virtual void emptyCache() const {
FAIL_MTIAHOOKS_FUNC(__func__);
}
};
struct TORCH_API MTIAHooksArgs {};

View File

@ -20,6 +20,7 @@ The MTIA backend is implemented out of the tree, only interfaces are defined
is_initialized
memory_stats
get_device_capability
empty_cache
set_device
set_stream
stream

View File

@ -1789,7 +1789,7 @@ def _mtia_setCurrentStream(stream: Stream) -> None: ...
def _mtia_getDefaultStream(device: _int) -> Stream: ...
def _mtia_memoryStats(device: _int) -> Dict[str, Any]: ...
def _mtia_getDeviceCapability(device: _int) -> Tuple[_int, _int]: ...
def _mtia_emptyCache() -> None: ...
# Defined in torch/csrc/mps/Module.cpp
def _mps_deviceSynchronize() -> None: ...

View File

@ -85,6 +85,8 @@ void initModule(PyObject* module) {
at::detail::getMTIAHooks().getDeviceCapability(device_index);
return py::reinterpret_steal<py::object>(raw_pyobject);
});
m.def("_mtia_emptyCache", []() { at::detail::getMTIAHooks().emptyCache(); });
}
} // namespace torch::mtia

View File

@ -175,6 +175,11 @@ def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int]
return torch._C._mtia_getDeviceCapability(_get_device_index(device, optional=True))
def empty_cache() -> None:
    r"""Release all cached memory held by the MTIA device allocator."""
    # Binding is declared `-> None`, so there is nothing to propagate back.
    torch._C._mtia_emptyCache()
def set_stream(stream: Stream):
r"""Set the current stream.This is a wrapper API to set the stream.
Usage of this function is discouraged in favor of the ``stream``
@ -333,6 +338,7 @@ __all__ = [
"default_stream",
"memory_stats",
"get_device_capability",
"empty_cache",
"set_device",
"set_stream",
"stream",