mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MTIA] Support torch.mtia.empty_cache() (#141533)
Summary: As title Test Plan: Passed a local unit test: `buck2 test //mtia/host_runtime/torch_mtia/tests:test_torch_mtia_api` https://www.internalfb.com/intern/testinfra/testrun/4785074861101240 Reviewed By: nautsimon Differential Revision: D66481778 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141533 Approved by: https://github.com/nautsimon
This commit is contained in:
parent
f35bb55256
commit
d70b7029c8
|
|
@ -109,6 +109,11 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
|
||||||
FAIL_MTIAHOOKS_FUNC(__func__);
|
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
virtual void emptyCache() const {
|
||||||
|
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||||
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TORCH_API MTIAHooksArgs {};
|
struct TORCH_API MTIAHooksArgs {};
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ The MTIA backend is implemented out of the tree, only interfaces are be defined
|
||||||
is_initialized
|
is_initialized
|
||||||
memory_stats
|
memory_stats
|
||||||
get_device_capability
|
get_device_capability
|
||||||
|
empty_cache
|
||||||
set_device
|
set_device
|
||||||
set_stream
|
set_stream
|
||||||
stream
|
stream
|
||||||
|
|
|
||||||
|
|
@ -1789,7 +1789,7 @@ def _mtia_setCurrentStream(stream: Stream) -> None: ...
|
||||||
def _mtia_getDefaultStream(device: _int) -> Stream: ...
|
def _mtia_getDefaultStream(device: _int) -> Stream: ...
|
||||||
def _mtia_memoryStats(device: _int) -> Dict[str, Any]: ...
|
def _mtia_memoryStats(device: _int) -> Dict[str, Any]: ...
|
||||||
def _mtia_getDeviceCapability(device: _int) -> Tuple[_int, _int]: ...
|
def _mtia_getDeviceCapability(device: _int) -> Tuple[_int, _int]: ...
|
||||||
|
def _mtia_emptyCache() -> None: ...
|
||||||
|
|
||||||
# Defined in torch/csrc/mps/Module.cpp
|
# Defined in torch/csrc/mps/Module.cpp
|
||||||
def _mps_deviceSynchronize() -> None: ...
|
def _mps_deviceSynchronize() -> None: ...
|
||||||
|
|
|
||||||
|
|
@ -85,6 +85,8 @@ void initModule(PyObject* module) {
|
||||||
at::detail::getMTIAHooks().getDeviceCapability(device_index);
|
at::detail::getMTIAHooks().getDeviceCapability(device_index);
|
||||||
return py::reinterpret_steal<py::object>(raw_pyobject);
|
return py::reinterpret_steal<py::object>(raw_pyobject);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
m.def("_mtia_emptyCache", []() { at::detail::getMTIAHooks().emptyCache(); });
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace torch::mtia
|
} // namespace torch::mtia
|
||||||
|
|
|
||||||
|
|
@ -175,6 +175,11 @@ def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int]
|
||||||
return torch._C._mtia_getDeviceCapability(_get_device_index(device, optional=True))
|
return torch._C._mtia_getDeviceCapability(_get_device_index(device, optional=True))
|
||||||
|
|
||||||
|
|
||||||
|
def empty_cache() -> None:
|
||||||
|
r"""Empty the MTIA device cache."""
|
||||||
|
return torch._C._mtia_emptyCache()
|
||||||
|
|
||||||
|
|
||||||
def set_stream(stream: Stream):
|
def set_stream(stream: Stream):
|
||||||
r"""Set the current stream.This is a wrapper API to set the stream.
|
r"""Set the current stream.This is a wrapper API to set the stream.
|
||||||
Usage of this function is discouraged in favor of the ``stream``
|
Usage of this function is discouraged in favor of the ``stream``
|
||||||
|
|
@ -333,6 +338,7 @@ __all__ = [
|
||||||
"default_stream",
|
"default_stream",
|
||||||
"memory_stats",
|
"memory_stats",
|
||||||
"get_device_capability",
|
"get_device_capability",
|
||||||
|
"empty_cache",
|
||||||
"set_device",
|
"set_device",
|
||||||
"set_stream",
|
"set_stream",
|
||||||
"stream",
|
"stream",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user