mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPS] Add optional minor argument to is_macos13_or_newer (#95065)
Will be needed if one wants to make accurate XFAIL validation I.e. `torch.backends.mps.is_macos13_or_newer()` will return True if PyTorch is running on MacOS 13.0 or newer, `torch.backends.mps.is_macos13_or_newer(1)` will return True if running on MacOS 13.1 or newer and `torch.backends.mps.is_macos13_or_newer(2)` will return True if running on MacOS 13.2 or newer Do not use 13.3 check as `@available` does not really work for shared libraries Pull Request resolved: https://github.com/pytorch/pytorch/pull/95065 Approved by: https://github.com/albanD
This commit is contained in:
parent
c43e88665a
commit
5de3ead712
|
|
@ -28,7 +28,7 @@ struct TORCH_API MPSHooksInterface {
|
|||
return false;
|
||||
}
|
||||
|
||||
virtual bool isOnMacOS13orNewer() const {
|
||||
virtual bool isOnMacOS13orNewer(unsigned minor = 0) const {
|
||||
AT_ERROR("MPS backend is not available.");
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -17,8 +17,18 @@ bool MPSHooks::hasMPS() const {
|
|||
return at::mps::is_available();
|
||||
}
|
||||
|
||||
bool MPSHooks::isOnMacOS13orNewer() const {
|
||||
return at::mps::is_macos_13_or_newer();
|
||||
bool MPSHooks::isOnMacOS13orNewer(unsigned minor) const {
|
||||
switch (minor) {
|
||||
case 0:
|
||||
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_0_PLUS);
|
||||
case 1:
|
||||
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_1_PLUS);
|
||||
case 2:
|
||||
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
|
||||
default:
|
||||
TORCH_WARN("Can't check whether running on 13.",minor,"+ returning one for 13.2+");
|
||||
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
|
||||
}
|
||||
}
|
||||
|
||||
Allocator* MPSHooks::getMPSDeviceAllocator() const {
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ struct MPSHooks : public at::MPSHooksInterface {
|
|||
MPSHooks(at::MPSHooksArgs) {}
|
||||
void initMPS() const override;
|
||||
bool hasMPS() const override;
|
||||
bool isOnMacOS13orNewer() const override;
|
||||
bool isOnMacOS13orNewer(unsigned minor) const override;
|
||||
Allocator* getMPSDeviceAllocator() const override;
|
||||
const Generator& getDefaultMPSGenerator() const override;
|
||||
void deviceSynchronize() const override;
|
||||
|
|
|
|||
|
|
@ -1207,7 +1207,7 @@ def _mps_setMemoryFraction(fraction: _float) -> None: ...
|
|||
def _mps_currentAllocatedMemory() -> _int: ...
|
||||
def _mps_driverAllocatedMemory() -> _int: ...
|
||||
def _mps_is_available() -> _bool: ...
|
||||
def _mps_is_on_macos_13_or_newer() -> _bool: ...
|
||||
def _mps_is_on_macos_13_or_newer(minor: _int) -> _bool: ...
|
||||
|
||||
# Defined in torch/csrc/cuda/Module.cpp
|
||||
def _cuda_getCurrentStream(device: _int) -> Tuple: ...
|
||||
|
|
|
|||
|
|
@ -19,9 +19,9 @@ def is_available() -> bool:
|
|||
|
||||
|
||||
@_lru_cache()
|
||||
def is_macos13_or_newer() -> bool:
|
||||
def is_macos13_or_newer(minor: int = 0) -> bool:
|
||||
r"""Returns a bool indicating whether MPS is running on MacOS 13 or newer."""
|
||||
return torch._C._mps_is_on_macos_13_or_newer()
|
||||
return torch._C._mps_is_on_macos_13_or_newer(minor)
|
||||
|
||||
|
||||
# Register prims as implementation of var_mean and group_norm
|
||||
|
|
|
|||
|
|
@ -59,11 +59,12 @@ static PyObject* MPSModule_isAvailable(PyObject* _unused, PyObject* noargs) {
|
|||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* MPSModule_isMacOS13orNewer(
|
||||
PyObject* _unused,
|
||||
PyObject* noargs) {
|
||||
static PyObject* MPSModule_isMacOS13orNewer(PyObject* _unused, PyObject* args) {
|
||||
HANDLE_TH_ERRORS
|
||||
if (at::detail::getMPSHooks().isOnMacOS13orNewer()) {
|
||||
THPUtils_assert(
|
||||
THPUtils_checkLong(args), "invalid argument to isOnMacOS13orNewer()");
|
||||
auto minor = THPUtils_unpackUInt32(args);
|
||||
if (at::detail::getMPSHooks().isOnMacOS13orNewer(minor)) {
|
||||
Py_RETURN_TRUE;
|
||||
} else {
|
||||
Py_RETURN_FALSE;
|
||||
|
|
@ -124,7 +125,7 @@ static struct PyMethodDef _MPSModule_methods[] = {
|
|||
{"_mps_is_available", MPSModule_isAvailable, METH_NOARGS, nullptr},
|
||||
{"_mps_is_on_macos_13_or_newer",
|
||||
MPSModule_isMacOS13orNewer,
|
||||
METH_NOARGS,
|
||||
METH_O,
|
||||
nullptr},
|
||||
{"_mps_get_default_generator",
|
||||
MPSModule_getDefaultMPSGenerator,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user