[MPS] Add optional minor argument to is_macos13_or_newer (#95065)

Will be needed if one wants to make accurate XFAIL validation

I.e. `torch.backends.mps.is_macos13_or_newer()` will return True if PyTorch is running on MacOS 13.0 or newer, `torch.backends.mps.is_macos13_or_newer(1)` will return True if running on MacOS 13.1 or newer and `torch.backends.mps.is_macos13_or_newer(2)` will return True  if running on MacOS 13.2 or newer

Do not use 13.3 check as `@available` does not really work for shared libraries

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95065
Approved by: https://github.com/albanD
This commit is contained in:
Nikita Shulga 2023-02-17 18:30:20 +00:00 committed by PyTorch MergeBot
parent c43e88665a
commit 5de3ead712
6 changed files with 23 additions and 12 deletions

View File

@ -28,7 +28,7 @@ struct TORCH_API MPSHooksInterface {
return false;
}
virtual bool isOnMacOS13orNewer() const {
virtual bool isOnMacOS13orNewer(unsigned minor = 0) const {
AT_ERROR("MPS backend is not available.");
}

View File

@ -17,8 +17,18 @@ bool MPSHooks::hasMPS() const {
return at::mps::is_available();
}
bool MPSHooks::isOnMacOS13orNewer() const {
return at::mps::is_macos_13_or_newer();
bool MPSHooks::isOnMacOS13orNewer(unsigned minor) const {
switch (minor) {
case 0:
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_0_PLUS);
case 1:
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_1_PLUS);
case 2:
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
default:
TORCH_WARN("Can't check whether running on 13.",minor,"+ returning one for 13.2+");
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
}
}
Allocator* MPSHooks::getMPSDeviceAllocator() const {

View File

@ -13,7 +13,7 @@ struct MPSHooks : public at::MPSHooksInterface {
MPSHooks(at::MPSHooksArgs) {}
void initMPS() const override;
bool hasMPS() const override;
bool isOnMacOS13orNewer() const override;
bool isOnMacOS13orNewer(unsigned minor) const override;
Allocator* getMPSDeviceAllocator() const override;
const Generator& getDefaultMPSGenerator() const override;
void deviceSynchronize() const override;

View File

@ -1207,7 +1207,7 @@ def _mps_setMemoryFraction(fraction: _float) -> None: ...
def _mps_currentAllocatedMemory() -> _int: ...
def _mps_driverAllocatedMemory() -> _int: ...
def _mps_is_available() -> _bool: ...
def _mps_is_on_macos_13_or_newer() -> _bool: ...
def _mps_is_on_macos_13_or_newer(minor: _int) -> _bool: ...
# Defined in torch/csrc/cuda/Module.cpp
def _cuda_getCurrentStream(device: _int) -> Tuple: ...

View File

@ -19,9 +19,9 @@ def is_available() -> bool:
@_lru_cache()
def is_macos13_or_newer() -> bool:
def is_macos13_or_newer(minor: int = 0) -> bool:
r"""Returns a bool indicating whether MPS is running on MacOS 13 or newer."""
return torch._C._mps_is_on_macos_13_or_newer()
return torch._C._mps_is_on_macos_13_or_newer(minor)
# Register prims as implementation of var_mean and group_norm

View File

@ -59,11 +59,12 @@ static PyObject* MPSModule_isAvailable(PyObject* _unused, PyObject* noargs) {
END_HANDLE_TH_ERRORS
}
static PyObject* MPSModule_isMacOS13orNewer(
PyObject* _unused,
PyObject* noargs) {
static PyObject* MPSModule_isMacOS13orNewer(PyObject* _unused, PyObject* args) {
HANDLE_TH_ERRORS
if (at::detail::getMPSHooks().isOnMacOS13orNewer()) {
THPUtils_assert(
THPUtils_checkLong(args), "invalid argument to isOnMacOS13orNewer()");
auto minor = THPUtils_unpackUInt32(args);
if (at::detail::getMPSHooks().isOnMacOS13orNewer(minor)) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
@ -124,7 +125,7 @@ static struct PyMethodDef _MPSModule_methods[] = {
{"_mps_is_available", MPSModule_isAvailable, METH_NOARGS, nullptr},
{"_mps_is_on_macos_13_or_newer",
MPSModule_isMacOS13orNewer,
METH_NOARGS,
METH_O,
nullptr},
{"_mps_get_default_generator",
MPSModule_getDefaultMPSGenerator,