mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[MPS] Add Python Module Bindings for the MPS backend (#94417)"
This reverts commitbeb4f5bf39. Reverted https://github.com/pytorch/pytorch/pull/94417 on behalf of https://github.com/huydhn due to Sorry for reverting your PR, but it seems to break MacOS test in trunkbae397ec63
This commit is contained in:
parent
77d9e36b0a
commit
4fe365774a
|
|
@ -28,10 +28,6 @@ struct TORCH_API MPSHooksInterface {
|
|||
return false;
|
||||
}
|
||||
|
||||
virtual bool isOnMacOS13orNewer() const {
|
||||
AT_ERROR("MPS backend is not available.");
|
||||
}
|
||||
|
||||
virtual const Generator& getDefaultMPSGenerator() const {
|
||||
AT_ERROR("Cannot get default MPS generator without MPS backend.");
|
||||
}
|
||||
|
|
@ -39,10 +35,6 @@ struct TORCH_API MPSHooksInterface {
|
|||
virtual Allocator* getMPSDeviceAllocator() const {
|
||||
AT_ERROR("MPSDeviceAllocator requires MPS.");
|
||||
}
|
||||
|
||||
virtual void deviceSynchronize() const {
|
||||
AT_ERROR("Cannot synchronize MPS device without MPS backend.");
|
||||
}
|
||||
};
|
||||
|
||||
struct TORCH_API MPSHooksArgs {};
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ class TORCH_API MPSDevice {
|
|||
|
||||
TORCH_API bool is_available();
|
||||
TORCH_API bool is_macos_13_or_newer(MacOSVersion version = MacOSVersion::MACOS_VER_13_0_PLUS);
|
||||
TORCH_API void device_synchronize();
|
||||
|
||||
TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
|
||||
|
||||
} // namespace mps
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
#include <c10/util/CallOnce.h>
|
||||
|
||||
#include <ATen/mps/MPSDevice.h>
|
||||
#include <ATen/mps/MPSStream.h>
|
||||
#include <ATen/mps/MPSAllocatorInterface.h>
|
||||
#include <ATen/mps/IndexKernels.h>
|
||||
|
||||
|
|
@ -123,9 +122,5 @@ bool is_macos_13_or_newer(MacOSVersion version) {
|
|||
return MPSDevice::getInstance()->isMacOS13Plus(version);
|
||||
}
|
||||
|
||||
void device_synchronize() {
|
||||
getDefaultMPSStream()->synchronize(SyncType::COMMIT_AND_WAIT);
|
||||
}
|
||||
|
||||
} // namespace mps
|
||||
} // namespace at
|
||||
|
|
|
|||
|
|
@ -16,10 +16,6 @@ bool MPSHooks::hasMPS() const {
|
|||
return at::mps::is_available();
|
||||
}
|
||||
|
||||
bool MPSHooks::isOnMacOS13orNewer() const {
|
||||
return at::mps::is_macos_13_or_newer();
|
||||
}
|
||||
|
||||
Allocator* MPSHooks::getMPSDeviceAllocator() const {
|
||||
return at::mps::GetMPSAllocator();
|
||||
}
|
||||
|
|
@ -28,10 +24,6 @@ const Generator& MPSHooks::getDefaultMPSGenerator() const {
|
|||
return at::mps::detail::getDefaultMPSGenerator();
|
||||
}
|
||||
|
||||
void MPSHooks::deviceSynchronize() const {
|
||||
at::mps::device_synchronize();
|
||||
}
|
||||
|
||||
using at::MPSHooksRegistry;
|
||||
using at::RegistererMPSHooksRegistry;
|
||||
|
||||
|
|
|
|||
|
|
@ -13,10 +13,8 @@ struct MPSHooks : public at::MPSHooksInterface {
|
|||
MPSHooks(at::MPSHooksArgs) {}
|
||||
void initMPS() const override;
|
||||
bool hasMPS() const override;
|
||||
bool isOnMacOS13orNewer() const override;
|
||||
Allocator* getMPSDeviceAllocator() const override;
|
||||
const Generator& getDefaultMPSGenerator() const override;
|
||||
void deviceSynchronize() const override;
|
||||
};
|
||||
|
||||
}} // at::mps
|
||||
|
|
|
|||
|
|
@ -822,7 +822,6 @@ libtorch_python_core_sources = [
|
|||
"torch/csrc/dynamo/guards.cpp",
|
||||
"torch/csrc/dynamo/init.cpp",
|
||||
"torch/csrc/functorch/init.cpp",
|
||||
"torch/csrc/mps/Module.cpp",
|
||||
"torch/csrc/jit/backends/backend_init.cpp",
|
||||
"torch/csrc/jit/python/init.cpp",
|
||||
"torch/csrc/jit/passes/onnx.cpp",
|
||||
|
|
|
|||
|
|
@ -81,7 +81,6 @@ Features described in this documentation are classified by release status:
|
|||
torch.autograd <autograd>
|
||||
torch.library <library>
|
||||
cuda
|
||||
mps
|
||||
torch.backends <backends>
|
||||
torch.distributed <distributed>
|
||||
torch.distributed.algorithms.join <distributed.algorithms.join>
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
torch.mps
|
||||
===================================
|
||||
.. automodule:: torch.mps
|
||||
.. currentmodule:: torch.mps
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
synchronize
|
||||
get_rng_state
|
||||
set_rng_state
|
||||
manual_seed
|
||||
seed
|
||||
|
|
@ -5853,45 +5853,6 @@ class TestNLLLoss(TestCase):
|
|||
mps_x = torch.randn(5, device='mps', generator=g_mps)
|
||||
self.assertEqual(mps_x, mps_y)
|
||||
|
||||
def test_default_mps_generator(self):
|
||||
# manual seeding on the "default" MPS generator using
|
||||
# the global torch.manual_seed()
|
||||
torch.manual_seed(230)
|
||||
mps_x = torch.randn(5, device='mps')
|
||||
# manual seeding using torch.mps.manual_seed()
|
||||
# which should set the "default" MPS generator
|
||||
# like the global torch.manual_seed()
|
||||
torch.mps.manual_seed(230)
|
||||
mps_y = torch.randn(5, device='mps')
|
||||
# seed values were the same, so the random tensor contents should match
|
||||
self.assertEqual(mps_x, mps_y)
|
||||
|
||||
# save the default generator's state to restore it later
|
||||
g_state = torch.mps.get_rng_state()
|
||||
|
||||
# generate random numbers without seeding
|
||||
mps_x = torch.randn(5, device='mps')
|
||||
# in this case, the random results must differ from the last generated random results
|
||||
self.assertNotEqual(mps_x, mps_y)
|
||||
|
||||
# restore the previously saved state, and the results should match again
|
||||
torch.mps.set_rng_state(g_state)
|
||||
mps_x = torch.randn(5, device='mps')
|
||||
self.assertEqual(mps_x, mps_y)
|
||||
|
||||
def test_device_synchronize(self):
|
||||
# just running some ops each followed by a synchronize to wait for
|
||||
# MPS stream to finish running each of them
|
||||
net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
|
||||
.to(device='mps', dtype=torch.float)
|
||||
|
||||
x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
|
||||
torch.mps.synchronize()
|
||||
x = net1(x)
|
||||
torch.mps.synchronize()
|
||||
x.backward(torch.randn_like(x))
|
||||
torch.mps.synchronize()
|
||||
|
||||
# Test random_.to and random_.from
|
||||
def test_random(self):
|
||||
def helper(shape, low, high, dtype=torch.int32):
|
||||
|
|
|
|||
|
|
@ -903,6 +903,8 @@ def _disabled_torch_function_impl(func: Callable, types: Iterable[Type], args: T
|
|||
def _disabled_torch_dispatch_impl(func: Callable, types: Iterable[Type], args: Tuple, kwargs: Dict) -> Any: ... # THPModule_disable_dispatch_function
|
||||
def _get_linalg_preferred_backend() -> torch._C._LinalgBackend: ...
|
||||
def _set_linalg_preferred_backend(arg: torch._C._LinalgBackend): ...
|
||||
def _is_mps_available() -> _bool: ...
|
||||
def _is_mps_on_macos_13_or_newer() -> _bool: ...
|
||||
class _LinalgBackend:
|
||||
Default: _LinalgBackend
|
||||
Cusolver: _LinalgBackend
|
||||
|
|
@ -1198,12 +1200,6 @@ class _TensorBase(metaclass=_TensorMeta):
|
|||
# Defined in torch/csrc/multiprocessing/init.cpp
|
||||
def _multiprocessing_init() -> None: ...
|
||||
|
||||
# Defined in torch/csrc/mps/Module.cpp
|
||||
def _mps_synchronize() -> None: ...
|
||||
def _mps_get_default_generator() -> Generator: ...
|
||||
def _is_mps_available() -> _bool: ...
|
||||
def _is_mps_on_macos_13_or_newer() -> _bool: ...
|
||||
|
||||
# Defined in torch/csrc/cuda/Module.cpp
|
||||
def _cuda_getCurrentStream(device: _int) -> Tuple: ...
|
||||
def _cuda_getCurrentRawStream(device: _int) -> _int: ...
|
||||
|
|
|
|||
|
|
@ -60,7 +60,6 @@
|
|||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
#include <torch/csrc/lazy/python/init.h>
|
||||
#include <torch/csrc/monitor/python_init.h>
|
||||
#include <torch/csrc/mps/Module.h>
|
||||
#include <torch/csrc/multiprocessing/init.h>
|
||||
#include <torch/csrc/onnx/init.h>
|
||||
#include <torch/csrc/profiler/python/init.h>
|
||||
|
|
@ -88,6 +87,10 @@
|
|||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(USE_MPS)
|
||||
#include <ATen/mps/MPSDevice.h>
|
||||
#endif
|
||||
|
||||
#if defined(USE_VALGRIND)
|
||||
#include <callgrind.h>
|
||||
#endif
|
||||
|
|
@ -1268,7 +1271,6 @@ PyObject* initModule() {
|
|||
THPUtils_addPyMethodDefs(methods, DataLoaderMethods);
|
||||
THPUtils_addPyMethodDefs(methods, torch::autograd::python_functions());
|
||||
THPUtils_addPyMethodDefs(methods, torch::multiprocessing::python_functions());
|
||||
THPUtils_addPyMethodDefs(methods, torch::mps::python_functions());
|
||||
#ifdef USE_CUDA
|
||||
THPUtils_addPyMethodDefs(methods, THCPModule_methods());
|
||||
#endif
|
||||
|
|
@ -1591,6 +1593,15 @@ Call this whenever a new thread is created in order to propagate values from
|
|||
|
||||
ASSERT_TRUE(set_module_attr("has_cuda", has_cuda));
|
||||
ASSERT_TRUE(set_module_attr("has_mps", has_mps));
|
||||
py_module.def("_is_mps_available", []() { return at::hasMPS(); });
|
||||
py_module.def("_is_mps_on_macos_13_or_newer", []() {
|
||||
#ifdef USE_MPS
|
||||
return at::mps::is_macos_13_or_newer();
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
});
|
||||
|
||||
ASSERT_TRUE(
|
||||
set_module_attr("has_mkldnn", at::hasMKLDNN() ? Py_True : Py_False));
|
||||
|
||||
|
|
|
|||
|
|
@ -1,68 +0,0 @@
|
|||
#include <ATen/ATen.h>
|
||||
#include <torch/csrc/Generator.h>
|
||||
#include <torch/csrc/python_headers.h>
|
||||
#include <torch/csrc/utils/python_numbers.h>
|
||||
|
||||
namespace torch {
|
||||
namespace mps {
|
||||
|
||||
static PyObject* MPSModule_getDefaultMPSGenerator(
|
||||
PyObject* _unused,
|
||||
PyObject* noargs) {
|
||||
HANDLE_TH_ERRORS
|
||||
return THPGenerator_initDefaultGenerator(
|
||||
at::detail::getMPSHooks().getDefaultMPSGenerator());
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* MPSModule_isAvailable(PyObject* _unused, PyObject* noargs) {
|
||||
HANDLE_TH_ERRORS
|
||||
if (at::detail::getMPSHooks().hasMPS()) {
|
||||
Py_RETURN_TRUE;
|
||||
} else {
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* MPSModule_isMacOS13orNewer(
|
||||
PyObject* _unused,
|
||||
PyObject* noargs) {
|
||||
HANDLE_TH_ERRORS
|
||||
if (at::detail::getMPSHooks().isOnMacOS13orNewer()) {
|
||||
Py_RETURN_TRUE;
|
||||
} else {
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* MPSModule_synchronize(PyObject* _unused, PyObject* noargs) {
|
||||
HANDLE_TH_ERRORS
|
||||
at::detail::getMPSHooks().deviceSynchronize();
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(modernize-avoid-c-arrays,
|
||||
// cppcoreguidelines-avoid-non-const-global-variables,
|
||||
// cppcoreguidelines-avoid-c-arrays)
|
||||
static struct PyMethodDef _MPSModule_methods[] = {
|
||||
{"_mps_synchronize", MPSModule_synchronize, METH_NOARGS, nullptr},
|
||||
{"_is_mps_available", MPSModule_isAvailable, METH_NOARGS, nullptr},
|
||||
{"_is_mps_on_macos_13_or_newer",
|
||||
MPSModule_isMacOS13orNewer,
|
||||
METH_NOARGS,
|
||||
nullptr},
|
||||
{"_mps_get_default_generator",
|
||||
MPSModule_getDefaultMPSGenerator,
|
||||
METH_NOARGS,
|
||||
nullptr},
|
||||
{nullptr}};
|
||||
|
||||
PyMethodDef* python_functions() {
|
||||
return _MPSModule_methods;
|
||||
}
|
||||
|
||||
} // namespace mps
|
||||
} // namespace torch
|
||||
|
|
@ -1,11 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <torch/csrc/python_headers.h>
|
||||
|
||||
namespace torch {
|
||||
namespace mps {
|
||||
|
||||
PyMethodDef* python_functions();
|
||||
|
||||
} // namespace mps
|
||||
} // namespace torch
|
||||
|
|
@ -1,53 +0,0 @@
|
|||
r"""
|
||||
This package enables an interface for accessing MPS backend in python
|
||||
"""
|
||||
import torch
|
||||
from .. import Tensor
|
||||
|
||||
_default_mps_generator: torch._C.Generator = None # type: ignore[assignment]
|
||||
|
||||
# local helper function (not public or exported)
|
||||
def _get_default_mps_generator() -> torch._C.Generator:
|
||||
global _default_mps_generator
|
||||
if _default_mps_generator is None:
|
||||
_default_mps_generator = torch._C._mps_get_default_generator()
|
||||
return _default_mps_generator
|
||||
|
||||
def synchronize() -> None:
|
||||
r"""Waits for all kernels in all streams on a MPS device to complete."""
|
||||
return torch._C._mps_synchronize()
|
||||
|
||||
def get_rng_state() -> Tensor:
|
||||
r"""Returns the random number generator state as a ByteTensor."""
|
||||
return _get_default_mps_generator().get_state()
|
||||
|
||||
def set_rng_state(new_state: Tensor) -> None:
|
||||
r"""Sets the random number generator state.
|
||||
|
||||
Args:
|
||||
new_state (torch.ByteTensor): The desired state
|
||||
"""
|
||||
new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
|
||||
_get_default_mps_generator().set_state(new_state_copy)
|
||||
|
||||
def manual_seed(seed: int) -> None:
|
||||
r"""Sets the seed for generating random numbers.
|
||||
|
||||
Args:
|
||||
seed (int): The desired seed.
|
||||
"""
|
||||
# the torch.mps.manual_seed() can be called from the global
|
||||
# torch.manual_seed() in torch/random.py. So we need to make
|
||||
# sure mps is available (otherwise we just return without
|
||||
# erroring out)
|
||||
if not torch._C._is_mps_available():
|
||||
return
|
||||
seed = int(seed)
|
||||
_get_default_mps_generator().manual_seed(seed)
|
||||
|
||||
def seed() -> None:
|
||||
r"""Sets the seed for generating random numbers to a random number."""
|
||||
_get_default_mps_generator().seed()
|
||||
|
||||
__all__ = [
|
||||
'get_rng_state', 'manual_seed', 'seed', 'set_rng_state', 'synchronize']
|
||||
|
|
@ -39,9 +39,6 @@ def manual_seed(seed) -> torch._C.Generator:
|
|||
if not torch.cuda._is_in_bad_fork():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
import torch.mps
|
||||
torch.mps.manual_seed(seed)
|
||||
|
||||
return default_generator.manual_seed(seed)
|
||||
|
||||
|
||||
|
|
@ -55,9 +52,6 @@ def seed() -> int:
|
|||
if not torch.cuda._is_in_bad_fork():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
import torch.mps
|
||||
torch.mps.manual_seed(seed)
|
||||
|
||||
return seed
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user