- This PR is a prerequisite for the upcoming Memory Leak Detection PR.
- Enable global manual seeding via `torch.manual_seed()` + test case.
- Add `torch.mps.synchronize()` to wait for the MPS stream to finish + test case.
- Enable the following Python interfaces for MPS: `torch.mps.[get_rng_state(), set_rng_state(), synchronize(), manual_seed(), seed()]`.
- Add test cases in `test_mps.py`.
- Add `mps.rst` to document the `torch.mps` module.
- Fix the failure with `test_public_bindings.py`.

Description of new files added:
- `torch/csrc/mps/Module.cpp`: implements `torch._C` module functions for `torch.mps` and `torch.backends.mps`.
- `torch/mps/__init__.py`: implements Python bindings for the `torch.mps` module.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94417
Approved by: https://github.com/albanD
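A minimal usage sketch of the interfaces listed above, assuming a machine where the MPS backend is available:

import torch

# Global manual seeding now also seeds the MPS generator.
torch.manual_seed(42)

if torch.backends.mps.is_available():
    # Per-backend RNG control from the new torch.mps module.
    state = torch.mps.get_rng_state()
    torch.mps.manual_seed(42)
    torch.mps.set_rng_state(state)

    x = torch.randn(8, device="mps")
    # Wait for the MPS stream to finish all queued work.
    torch.mps.synchronize()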
69 lines | 1.7 KiB | C++
#include <ATen/ATen.h>
#include <torch/csrc/Generator.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/python_numbers.h>

namespace torch {
namespace mps {

// Returns the default MPS random number generator, which backs the
// torch.mps seeding and RNG-state APIs.
static PyObject* MPSModule_getDefaultMPSGenerator(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  return THPGenerator_initDefaultGenerator(
      at::detail::getMPSHooks().getDefaultMPSGenerator());
  END_HANDLE_TH_ERRORS
}

// Reports whether an MPS device is available on this system.
static PyObject* MPSModule_isAvailable(PyObject* _unused, PyObject* noargs) {
  HANDLE_TH_ERRORS
  if (at::detail::getMPSHooks().hasMPS()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

// Reports whether the host is running macOS 13 or newer.
static PyObject* MPSModule_isMacOS13orNewer(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  if (at::detail::getMPSHooks().isOnMacOS13orNewer()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

// Blocks until all queued work on the MPS device has completed.
static PyObject* MPSModule_synchronize(PyObject* _unused, PyObject* noargs) {
  HANDLE_TH_ERRORS
  at::detail::getMPSHooks().deviceSynchronize();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

// Method table exposing the functions above as private torch._C bindings.
// NOLINTNEXTLINE(modernize-avoid-c-arrays,
// cppcoreguidelines-avoid-non-const-global-variables,
// cppcoreguidelines-avoid-c-arrays)
static struct PyMethodDef _MPSModule_methods[] = {
    {"_mps_synchronize", MPSModule_synchronize, METH_NOARGS, nullptr},
    {"_is_mps_available", MPSModule_isAvailable, METH_NOARGS, nullptr},
    {"_is_mps_on_macos_13_or_newer",
     MPSModule_isMacOS13orNewer,
     METH_NOARGS,
     nullptr},
    {"_mps_get_default_generator",
     MPSModule_getDefaultMPSGenerator,
     METH_NOARGS,
     nullptr},
    {nullptr}};

// Returns the method table for registration on the torch._C module.
PyMethodDef* python_functions() {
  return _MPSModule_methods;
}

} // namespace mps
} // namespace torch
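For context, a minimal sketch of how the private bindings registered in `_MPSModule_methods` might be wrapped into the public `torch.mps` / `torch.backends.mps` API (the actual `torch/mps/__init__.py` may differ):

import torch

def synchronize() -> None:
    # Forwards to the "_mps_synchronize" binding registered above,
    # blocking until the MPS stream has finished.
    torch._C._mps_synchronize()

def is_available() -> bool:
    # Forwards to the "_is_mps_available" binding registered above.
    return torch._C._is_mps_available()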