Revert "[MPS] Add Python Module Bindings for the MPS backend (#94417)"

This reverts commit beb4f5bf39.

Reverted https://github.com/pytorch/pytorch/pull/94417 on behalf of https://github.com/huydhn due to: Sorry for reverting your PR, but it seems to break the macOS test in trunk (bae397ec63)
PyTorch MergeBot 2023-02-11 05:24:44 +00:00
parent 77d9e36b0a
commit 4fe365774a
15 changed files with 16 additions and 225 deletions
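For context before the per-file hunks: this revert deletes the public torch.mps Python module and its torch._C bindings added in #94417. A usage sketch of the removed surface, reconstructed from the reverted torch/mps/__init__.py shown further down (assumes an MPS-enabled macOS build):

    import torch

    # Everything below is removed again by this revert.
    if torch._C._is_mps_available():       # availability probe (survives, relocated)
        torch.mps.manual_seed(230)         # seed the default MPS generator
        x = torch.randn(5, device="mps")
        state = torch.mps.get_rng_state()  # generator state as a ByteTensor
        torch.mps.set_rng_state(state)     # restore a saved state
        torch.mps.synchronize()            # wait for queued MPS kernels to finish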

View File

@@ -28,10 +28,6 @@ struct TORCH_API MPSHooksInterface {
return false;
}
virtual bool isOnMacOS13orNewer() const {
AT_ERROR("MPS backend is not available.");
}
virtual const Generator& getDefaultMPSGenerator() const {
AT_ERROR("Cannot get default MPS generator without MPS backend.");
}
@@ -39,10 +35,6 @@ struct TORCH_API MPSHooksInterface {
virtual Allocator* getMPSDeviceAllocator() const {
AT_ERROR("MPSDeviceAllocator requires MPS.");
}
virtual void deviceSynchronize() const {
AT_ERROR("Cannot synchronize MPS device without MPS backend.");
}
};
struct TORCH_API MPSHooksArgs {};
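The two stubs deleted above follow the usual hooks pattern: on a build where no MPS backend registers itself, the base-class implementation fires AT_ERROR, which reaches Python as a RuntimeError carrying that message. A hedged sketch of the pre-revert behavior on a build compiled without MPS:

    import torch
    import torch.mps  # pre-revert module; pure Python, importable on any build

    try:
        torch.mps.synchronize()  # routes through MPSHooksInterface::deviceSynchronize()
    except RuntimeError as err:
        print(err)  # "Cannot synchronize MPS device without MPS backend."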

View File

@@ -79,7 +79,7 @@ class TORCH_API MPSDevice {
TORCH_API bool is_available();
TORCH_API bool is_macos_13_or_newer(MacOSVersion version = MacOSVersion::MACOS_VER_13_0_PLUS);
TORCH_API void device_synchronize();
TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
} // namespace mps

View File

@@ -3,7 +3,6 @@
#include <c10/util/CallOnce.h>
#include <ATen/mps/MPSDevice.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/mps/IndexKernels.h>
@@ -123,9 +122,5 @@ bool is_macos_13_or_newer(MacOSVersion version) {
return MPSDevice::getInstance()->isMacOS13Plus(version);
}
void device_synchronize() {
getDefaultMPSStream()->synchronize(SyncType::COMMIT_AND_WAIT);
}
} // namespace mps
} // namespace at

View File

@@ -16,10 +16,6 @@ bool MPSHooks::hasMPS() const {
return at::mps::is_available();
}
bool MPSHooks::isOnMacOS13orNewer() const {
return at::mps::is_macos_13_or_newer();
}
Allocator* MPSHooks::getMPSDeviceAllocator() const {
return at::mps::GetMPSAllocator();
}
@@ -28,10 +24,6 @@ const Generator& MPSHooks::getDefaultMPSGenerator() const {
return at::mps::detail::getDefaultMPSGenerator();
}
void MPSHooks::deviceSynchronize() const {
at::mps::device_synchronize();
}
using at::MPSHooksRegistry;
using at::RegistererMPSHooksRegistry;

View File

@@ -13,10 +13,8 @@ struct MPSHooks : public at::MPSHooksInterface {
MPSHooks(at::MPSHooksArgs) {}
void initMPS() const override;
bool hasMPS() const override;
bool isOnMacOS13orNewer() const override;
Allocator* getMPSDeviceAllocator() const override;
const Generator& getDefaultMPSGenerator() const override;
void deviceSynchronize() const override;
};
}} // at::mps

View File

@@ -822,7 +822,6 @@ libtorch_python_core_sources = [
"torch/csrc/dynamo/guards.cpp",
"torch/csrc/dynamo/init.cpp",
"torch/csrc/functorch/init.cpp",
"torch/csrc/mps/Module.cpp",
"torch/csrc/jit/backends/backend_init.cpp",
"torch/csrc/jit/python/init.cpp",
"torch/csrc/jit/passes/onnx.cpp",

View File

@@ -81,7 +81,6 @@ Features described in this documentation are classified by release status:
torch.autograd <autograd>
torch.library <library>
cuda
mps
torch.backends <backends>
torch.distributed <distributed>
torch.distributed.algorithms.join <distributed.algorithms.join>

View File

@@ -1,14 +0,0 @@
torch.mps
===================================
.. automodule:: torch.mps
.. currentmodule:: torch.mps
.. autosummary::
:toctree: generated
:nosignatures:
synchronize
get_rng_state
set_rng_state
manual_seed
seed

View File

@@ -5853,45 +5853,6 @@ class TestNLLLoss(TestCase):
mps_x = torch.randn(5, device='mps', generator=g_mps)
self.assertEqual(mps_x, mps_y)
def test_default_mps_generator(self):
# manual seeding on the "default" MPS generator using
# the global torch.manual_seed()
torch.manual_seed(230)
mps_x = torch.randn(5, device='mps')
# manual seeding using torch.mps.manual_seed()
# which should set the "default" MPS generator
# like the global torch.manual_seed()
torch.mps.manual_seed(230)
mps_y = torch.randn(5, device='mps')
# seed values were the same, so the random tensor contents should match
self.assertEqual(mps_x, mps_y)
# save the default generator's state to restore it later
g_state = torch.mps.get_rng_state()
# generate random numbers without seeding
mps_x = torch.randn(5, device='mps')
# in this case, the random results must differ from the last generated random results
self.assertNotEqual(mps_x, mps_y)
# restore the previously saved state, and the results should match again
torch.mps.set_rng_state(g_state)
mps_x = torch.randn(5, device='mps')
self.assertEqual(mps_x, mps_y)
def test_device_synchronize(self):
# just running some ops each followed by a synchronize to wait for
# MPS stream to finish running each of them
net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
.to(device='mps', dtype=torch.float)
x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
torch.mps.synchronize()
x = net1(x)
torch.mps.synchronize()
x.backward(torch.randn_like(x))
torch.mps.synchronize()
# Test random_.to and random_.from
def test_random(self):
def helper(shape, low, high, dtype=torch.int32):
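The deleted tests above reduce to two invariants: seeding the default MPS generator (via either the global torch.manual_seed or torch.mps.manual_seed) reproduces the same draws, and generator state round-trips through get_rng_state/set_rng_state. A condensed sketch of that pattern, assuming an MPS device and the pre-revert module:

    import torch

    torch.mps.manual_seed(230)
    a = torch.randn(5, device="mps")
    state = torch.mps.get_rng_state()     # snapshot after the first draw
    b = torch.randn(5, device="mps")      # advances the generator
    torch.mps.set_rng_state(state)        # rewind to the snapshot
    c = torch.randn(5, device="mps")
    assert torch.equal(b.cpu(), c.cpu())  # same state, same numbers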

View File

@@ -903,6 +903,8 @@ def _disabled_torch_function_impl(func: Callable, types: Iterable[Type], args: T
def _disabled_torch_dispatch_impl(func: Callable, types: Iterable[Type], args: Tuple, kwargs: Dict) -> Any: ... # THPModule_disable_dispatch_function
def _get_linalg_preferred_backend() -> torch._C._LinalgBackend: ...
def _set_linalg_preferred_backend(arg: torch._C._LinalgBackend): ...
def _is_mps_available() -> _bool: ...
def _is_mps_on_macos_13_or_newer() -> _bool: ...
class _LinalgBackend:
Default: _LinalgBackend
Cusolver: _LinalgBackend
@@ -1198,12 +1200,6 @@ class _TensorBase(metaclass=_TensorMeta):
# Defined in torch/csrc/multiprocessing/init.cpp
def _multiprocessing_init() -> None: ...
# Defined in torch/csrc/mps/Module.cpp
def _mps_synchronize() -> None: ...
def _mps_get_default_generator() -> Generator: ...
def _is_mps_available() -> _bool: ...
def _is_mps_on_macos_13_or_newer() -> _bool: ...
# Defined in torch/csrc/cuda/Module.cpp
def _cuda_getCurrentStream(device: _int) -> Tuple: ...
def _cuda_getCurrentRawStream(device: _int) -> _int: ...
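Note the asymmetry in this stub file: _mps_synchronize and _mps_get_default_generator are deleted outright, while _is_mps_available and _is_mps_on_macos_13_or_newer move up and survive as inline bindings (see the torch/csrc Module.cpp hunk below). A minimal probe of the surviving calls; both are private internals and subject to change:

    import torch

    print(torch._C._is_mps_available())             # True only on an MPS-enabled build
    print(torch._C._is_mps_on_macos_13_or_newer())  # always False when built without MPS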

View File

@@ -60,7 +60,6 @@
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/lazy/python/init.h>
#include <torch/csrc/monitor/python_init.h>
#include <torch/csrc/mps/Module.h>
#include <torch/csrc/multiprocessing/init.h>
#include <torch/csrc/onnx/init.h>
#include <torch/csrc/profiler/python/init.h>
@@ -88,6 +87,10 @@
#endif
#endif
#if defined(USE_MPS)
#include <ATen/mps/MPSDevice.h>
#endif
#if defined(USE_VALGRIND)
#include <callgrind.h>
#endif
@@ -1268,7 +1271,6 @@ PyObject* initModule() {
THPUtils_addPyMethodDefs(methods, DataLoaderMethods);
THPUtils_addPyMethodDefs(methods, torch::autograd::python_functions());
THPUtils_addPyMethodDefs(methods, torch::multiprocessing::python_functions());
THPUtils_addPyMethodDefs(methods, torch::mps::python_functions());
#ifdef USE_CUDA
THPUtils_addPyMethodDefs(methods, THCPModule_methods());
#endif
@@ -1591,6 +1593,15 @@ Call this whenever a new thread is created in order to propagate values from
ASSERT_TRUE(set_module_attr("has_cuda", has_cuda));
ASSERT_TRUE(set_module_attr("has_mps", has_mps));
py_module.def("_is_mps_available", []() { return at::hasMPS(); });
py_module.def("_is_mps_on_macos_13_or_newer", []() {
#ifdef USE_MPS
return at::mps::is_macos_13_or_newer();
#else
return false;
#endif
});
ASSERT_TRUE(
set_module_attr("has_mkldnn", at::hasMKLDNN() ? Py_True : Py_False));

View File

@@ -1,68 +0,0 @@
#include <ATen/ATen.h>
#include <torch/csrc/Generator.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/python_numbers.h>
namespace torch {
namespace mps {
static PyObject* MPSModule_getDefaultMPSGenerator(
PyObject* _unused,
PyObject* noargs) {
HANDLE_TH_ERRORS
return THPGenerator_initDefaultGenerator(
at::detail::getMPSHooks().getDefaultMPSGenerator());
END_HANDLE_TH_ERRORS
}
static PyObject* MPSModule_isAvailable(PyObject* _unused, PyObject* noargs) {
HANDLE_TH_ERRORS
if (at::detail::getMPSHooks().hasMPS()) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
END_HANDLE_TH_ERRORS
}
static PyObject* MPSModule_isMacOS13orNewer(
PyObject* _unused,
PyObject* noargs) {
HANDLE_TH_ERRORS
if (at::detail::getMPSHooks().isOnMacOS13orNewer()) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
END_HANDLE_TH_ERRORS
}
static PyObject* MPSModule_synchronize(PyObject* _unused, PyObject* noargs) {
HANDLE_TH_ERRORS
at::detail::getMPSHooks().deviceSynchronize();
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
// NOLINTNEXTLINE(modernize-avoid-c-arrays,
// cppcoreguidelines-avoid-non-const-global-variables,
// cppcoreguidelines-avoid-c-arrays)
static struct PyMethodDef _MPSModule_methods[] = {
{"_mps_synchronize", MPSModule_synchronize, METH_NOARGS, nullptr},
{"_is_mps_available", MPSModule_isAvailable, METH_NOARGS, nullptr},
{"_is_mps_on_macos_13_or_newer",
MPSModule_isMacOS13orNewer,
METH_NOARGS,
nullptr},
{"_mps_get_default_generator",
MPSModule_getDefaultMPSGenerator,
METH_NOARGS,
nullptr},
{nullptr}};
PyMethodDef* python_functions() {
return _MPSModule_methods;
}
} // namespace mps
} // namespace torch
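The deleted extension module above exported exactly four METH_NOARGS entry points on torch._C. A sketch of how the two now-removed ones were consumed, mirroring what the deleted torch/mps/__init__.py did:

    import torch

    # Pre-revert only: this commit removes both names from torch._C.
    gen = torch._C._mps_get_default_generator()  # default MPS torch._C.Generator
    gen.manual_seed(230)
    torch._C._mps_synchronize()                  # COMMIT_AND_WAIT on the default stream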

View File

@@ -1,11 +0,0 @@
#pragma once
#include <torch/csrc/python_headers.h>
namespace torch {
namespace mps {
PyMethodDef* python_functions();
} // namespace mps
} // namespace torch

View File

@@ -1,53 +0,0 @@
r"""
This package enables an interface for accessing MPS backend in python
"""
import torch
from .. import Tensor
_default_mps_generator: torch._C.Generator = None # type: ignore[assignment]
# local helper function (not public or exported)
def _get_default_mps_generator() -> torch._C.Generator:
global _default_mps_generator
if _default_mps_generator is None:
_default_mps_generator = torch._C._mps_get_default_generator()
return _default_mps_generator
def synchronize() -> None:
r"""Waits for all kernels in all streams on a MPS device to complete."""
return torch._C._mps_synchronize()
def get_rng_state() -> Tensor:
r"""Returns the random number generator state as a ByteTensor."""
return _get_default_mps_generator().get_state()
def set_rng_state(new_state: Tensor) -> None:
r"""Sets the random number generator state.
Args:
new_state (torch.ByteTensor): The desired state
"""
new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
_get_default_mps_generator().set_state(new_state_copy)
def manual_seed(seed: int) -> None:
r"""Sets the seed for generating random numbers.
Args:
seed (int): The desired seed.
"""
# the torch.mps.manual_seed() can be called from the global
# torch.manual_seed() in torch/random.py. So we need to make
# sure mps is available (otherwise we just return without
# erroring out)
if not torch._C._is_mps_available():
return
seed = int(seed)
_get_default_mps_generator().manual_seed(seed)
def seed() -> None:
r"""Sets the seed for generating random numbers to a random number."""
_get_default_mps_generator().seed()
__all__ = [
'get_rng_state', 'manual_seed', 'seed', 'set_rng_state', 'synchronize']
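One detail worth noting in the deleted module: manual_seed returns silently when MPS is unavailable, because the global torch.manual_seed (next hunk) fans out to every backend unconditionally. A hypothetical standalone restatement of just that guard:

    import torch

    def mps_manual_seed(seed: int) -> None:
        # Hypothetical rewrite of the deleted guard: seeding must be a quiet
        # no-op on non-MPS builds, since torch.manual_seed() reaches this
        # path on every platform.
        if not torch._C._is_mps_available():
            return
        torch._C._mps_get_default_generator().manual_seed(int(seed))  # pre-revert binding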

View File

@@ -39,9 +39,6 @@ def manual_seed(seed) -> torch._C.Generator:
if not torch.cuda._is_in_bad_fork():
torch.cuda.manual_seed_all(seed)
import torch.mps
torch.mps.manual_seed(seed)
return default_generator.manual_seed(seed)
@@ -55,9 +52,6 @@ def seed() -> int:
if not torch.cuda._is_in_bad_fork():
torch.cuda.manual_seed_all(seed)
import torch.mps
torch.mps.manual_seed(seed)
return seed
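After these two hunks, the global seeding entry points fan out only to CUDA; the default MPS generator is no longer touched by torch.manual_seed or torch.seed. For comparison, a condensed sketch of the pre-revert fan-out (manual_seed_fanout is a hypothetical name; default_generator is the CPU generator):

    import torch

    def manual_seed_fanout(seed: int) -> torch._C.Generator:
        # Hypothetical condensation of the pre-revert torch.manual_seed() flow.
        seed = int(seed)
        if not torch.cuda._is_in_bad_fork():
            torch.cuda.manual_seed_all(seed)  # CUDA fan-out (kept)
        # torch.mps.manual_seed(seed)         # MPS fan-out (removed by this commit)
        return torch.default_generator.manual_seed(seed)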