[MTIA] Support torch.cuda.get_device_capability equivalent API on MTIA (#135889)

Summary:
Mirror `get_device_capability` on MTIA per https://fburl.com/gdoc/p4lo5avn

At the moment, both the major and minor versions are just 0.
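
For context, a minimal usage sketch of the new Python API (assuming a PyTorch build with MTIA support; per the note above, current devices simply report (0, 0)):

    import torch

    # Guard on availability so the default hooks, which raise, are never hit
    # on builds without an MTIA backend.
    if torch.mtia.is_available():
        # Query the current device; an explicit device index can also be
        # passed, mirroring torch.cuda.get_device_capability.
        major, minor = torch.mtia.get_device_capability()
        print(f"MTIA device capability: {major}.{minor}")  # currently "0.0"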

Test Plan:
Unit test: `buck2 test //mtia/host_runtime/torch_mtia/tests:test_torch_mtia_api`

https://www.internalfb.com/intern/testinfra/testconsole/testrun/1688850109958190/

Differential Revision: D62595296

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135889
Approved by: https://github.com/egienvalue
Authored by Trung Truong on 2024-09-17 17:42:56 +00:00; committed by PyTorch MergeBot
parent 8e5bb356e0
commit cc365fdd7b
5 changed files with 25 additions and 0 deletions


@@ -104,6 +104,11 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return nullptr;
  }

  virtual PyObject* getDeviceCapability(DeviceIndex device) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return nullptr;
  }
};
struct TORCH_API MTIAHooksArgs {};


@@ -19,6 +19,7 @@ The MTIA backend is implemented out of tree; only interfaces are defined
    is_available
    is_initialized
    memory_stats
    get_device_capability
    set_device
    set_stream
    stream


@@ -1784,6 +1784,7 @@ def _mtia_getCurrentStream(device: _int) -> Stream: ...
def _mtia_setCurrentStream(stream: Stream) -> None: ...
def _mtia_getDefaultStream(device: _int) -> Stream: ...
def _mtia_memoryStats(device: _int) -> Dict[str, Any]: ...
def _mtia_getDeviceCapability(device: _int) -> Tuple[_int, _int]: ...
# Defined in torch/csrc/mps/Module.cpp


@@ -80,6 +80,12 @@ void initModule(PyObject* module) {
        at::detail::getMTIAHooks().memoryStats(device_index);
    return py::reinterpret_steal<py::object>(raw_pyobject);
  });

  m.def("_mtia_getDeviceCapability", [](c10::DeviceIndex device_index) {
    PyObject* raw_pyobject =
        at::detail::getMTIAHooks().getDeviceCapability(device_index);
    return py::reinterpret_steal<py::object>(raw_pyobject);
  });
}
} // namespace mtia


@@ -166,6 +166,17 @@ def memory_stats(device: Optional[_device_t] = None) -> Dict[str, Any]:
    return torch._C._mtia_memoryStats(_get_device_index(device, optional=True))


def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int]:
    r"""Return capability of a given device as a tuple of (major version, minor version).

    Args:
        device (torch.device or int, optional): selected device. Returns
            capability of the current device, given by current_device(),
            if device is None (default).
    """
    return torch._C._mtia_getDeviceCapability(_get_device_index(device, optional=True))


def set_stream(stream: Stream):
    r"""Set the current stream. This is a wrapper API to set the stream.
    Usage of this function is discouraged in favor of the ``stream``
@@ -323,6 +334,7 @@ __all__ = [
    "current_stream",
    "default_stream",
    "memory_stats",
    "get_device_capability",
    "set_device",
    "set_stream",
    "stream",