mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[MPS] Add Python Module Bindings for the MPS backend (#94417)"
This reverts commitbeb4f5bf39. Reverted https://github.com/pytorch/pytorch/pull/94417 on behalf of https://github.com/huydhn due to Sorry for reverting your PR, but it seems to break MacOS test in trunkbae397ec63
This commit is contained in:
parent
77d9e36b0a
commit
4fe365774a
|
|
@ -28,10 +28,6 @@ struct TORCH_API MPSHooksInterface {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual bool isOnMacOS13orNewer() const {
|
|
||||||
AT_ERROR("MPS backend is not available.");
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual const Generator& getDefaultMPSGenerator() const {
|
virtual const Generator& getDefaultMPSGenerator() const {
|
||||||
AT_ERROR("Cannot get default MPS generator without MPS backend.");
|
AT_ERROR("Cannot get default MPS generator without MPS backend.");
|
||||||
}
|
}
|
||||||
|
|
@ -39,10 +35,6 @@ struct TORCH_API MPSHooksInterface {
|
||||||
virtual Allocator* getMPSDeviceAllocator() const {
|
virtual Allocator* getMPSDeviceAllocator() const {
|
||||||
AT_ERROR("MPSDeviceAllocator requires MPS.");
|
AT_ERROR("MPSDeviceAllocator requires MPS.");
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual void deviceSynchronize() const {
|
|
||||||
AT_ERROR("Cannot synchronize MPS device without MPS backend.");
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TORCH_API MPSHooksArgs {};
|
struct TORCH_API MPSHooksArgs {};
|
||||||
|
|
|
||||||
|
|
@ -79,7 +79,7 @@ class TORCH_API MPSDevice {
|
||||||
|
|
||||||
TORCH_API bool is_available();
|
TORCH_API bool is_available();
|
||||||
TORCH_API bool is_macos_13_or_newer(MacOSVersion version = MacOSVersion::MACOS_VER_13_0_PLUS);
|
TORCH_API bool is_macos_13_or_newer(MacOSVersion version = MacOSVersion::MACOS_VER_13_0_PLUS);
|
||||||
TORCH_API void device_synchronize();
|
|
||||||
TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
|
TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
|
||||||
|
|
||||||
} // namespace mps
|
} // namespace mps
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@
|
||||||
#include <c10/util/CallOnce.h>
|
#include <c10/util/CallOnce.h>
|
||||||
|
|
||||||
#include <ATen/mps/MPSDevice.h>
|
#include <ATen/mps/MPSDevice.h>
|
||||||
#include <ATen/mps/MPSStream.h>
|
|
||||||
#include <ATen/mps/MPSAllocatorInterface.h>
|
#include <ATen/mps/MPSAllocatorInterface.h>
|
||||||
#include <ATen/mps/IndexKernels.h>
|
#include <ATen/mps/IndexKernels.h>
|
||||||
|
|
||||||
|
|
@ -123,9 +122,5 @@ bool is_macos_13_or_newer(MacOSVersion version) {
|
||||||
return MPSDevice::getInstance()->isMacOS13Plus(version);
|
return MPSDevice::getInstance()->isMacOS13Plus(version);
|
||||||
}
|
}
|
||||||
|
|
||||||
void device_synchronize() {
|
|
||||||
getDefaultMPSStream()->synchronize(SyncType::COMMIT_AND_WAIT);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mps
|
} // namespace mps
|
||||||
} // namespace at
|
} // namespace at
|
||||||
|
|
|
||||||
|
|
@ -16,10 +16,6 @@ bool MPSHooks::hasMPS() const {
|
||||||
return at::mps::is_available();
|
return at::mps::is_available();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MPSHooks::isOnMacOS13orNewer() const {
|
|
||||||
return at::mps::is_macos_13_or_newer();
|
|
||||||
}
|
|
||||||
|
|
||||||
Allocator* MPSHooks::getMPSDeviceAllocator() const {
|
Allocator* MPSHooks::getMPSDeviceAllocator() const {
|
||||||
return at::mps::GetMPSAllocator();
|
return at::mps::GetMPSAllocator();
|
||||||
}
|
}
|
||||||
|
|
@ -28,10 +24,6 @@ const Generator& MPSHooks::getDefaultMPSGenerator() const {
|
||||||
return at::mps::detail::getDefaultMPSGenerator();
|
return at::mps::detail::getDefaultMPSGenerator();
|
||||||
}
|
}
|
||||||
|
|
||||||
void MPSHooks::deviceSynchronize() const {
|
|
||||||
at::mps::device_synchronize();
|
|
||||||
}
|
|
||||||
|
|
||||||
using at::MPSHooksRegistry;
|
using at::MPSHooksRegistry;
|
||||||
using at::RegistererMPSHooksRegistry;
|
using at::RegistererMPSHooksRegistry;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,10 +13,8 @@ struct MPSHooks : public at::MPSHooksInterface {
|
||||||
MPSHooks(at::MPSHooksArgs) {}
|
MPSHooks(at::MPSHooksArgs) {}
|
||||||
void initMPS() const override;
|
void initMPS() const override;
|
||||||
bool hasMPS() const override;
|
bool hasMPS() const override;
|
||||||
bool isOnMacOS13orNewer() const override;
|
|
||||||
Allocator* getMPSDeviceAllocator() const override;
|
Allocator* getMPSDeviceAllocator() const override;
|
||||||
const Generator& getDefaultMPSGenerator() const override;
|
const Generator& getDefaultMPSGenerator() const override;
|
||||||
void deviceSynchronize() const override;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
}} // at::mps
|
}} // at::mps
|
||||||
|
|
|
||||||
|
|
@ -822,7 +822,6 @@ libtorch_python_core_sources = [
|
||||||
"torch/csrc/dynamo/guards.cpp",
|
"torch/csrc/dynamo/guards.cpp",
|
||||||
"torch/csrc/dynamo/init.cpp",
|
"torch/csrc/dynamo/init.cpp",
|
||||||
"torch/csrc/functorch/init.cpp",
|
"torch/csrc/functorch/init.cpp",
|
||||||
"torch/csrc/mps/Module.cpp",
|
|
||||||
"torch/csrc/jit/backends/backend_init.cpp",
|
"torch/csrc/jit/backends/backend_init.cpp",
|
||||||
"torch/csrc/jit/python/init.cpp",
|
"torch/csrc/jit/python/init.cpp",
|
||||||
"torch/csrc/jit/passes/onnx.cpp",
|
"torch/csrc/jit/passes/onnx.cpp",
|
||||||
|
|
|
||||||
|
|
@ -81,7 +81,6 @@ Features described in this documentation are classified by release status:
|
||||||
torch.autograd <autograd>
|
torch.autograd <autograd>
|
||||||
torch.library <library>
|
torch.library <library>
|
||||||
cuda
|
cuda
|
||||||
mps
|
|
||||||
torch.backends <backends>
|
torch.backends <backends>
|
||||||
torch.distributed <distributed>
|
torch.distributed <distributed>
|
||||||
torch.distributed.algorithms.join <distributed.algorithms.join>
|
torch.distributed.algorithms.join <distributed.algorithms.join>
|
||||||
|
|
|
||||||
|
|
@ -1,14 +0,0 @@
|
||||||
torch.mps
|
|
||||||
===================================
|
|
||||||
.. automodule:: torch.mps
|
|
||||||
.. currentmodule:: torch.mps
|
|
||||||
|
|
||||||
.. autosummary::
|
|
||||||
:toctree: generated
|
|
||||||
:nosignatures:
|
|
||||||
|
|
||||||
synchronize
|
|
||||||
get_rng_state
|
|
||||||
set_rng_state
|
|
||||||
manual_seed
|
|
||||||
seed
|
|
||||||
|
|
@ -5853,45 +5853,6 @@ class TestNLLLoss(TestCase):
|
||||||
mps_x = torch.randn(5, device='mps', generator=g_mps)
|
mps_x = torch.randn(5, device='mps', generator=g_mps)
|
||||||
self.assertEqual(mps_x, mps_y)
|
self.assertEqual(mps_x, mps_y)
|
||||||
|
|
||||||
def test_default_mps_generator(self):
|
|
||||||
# manual seeding on the "default" MPS generator using
|
|
||||||
# the global torch.manual_seed()
|
|
||||||
torch.manual_seed(230)
|
|
||||||
mps_x = torch.randn(5, device='mps')
|
|
||||||
# manual seeding using torch.mps.manual_seed()
|
|
||||||
# which should set the "default" MPS generator
|
|
||||||
# like the global torch.manual_seed()
|
|
||||||
torch.mps.manual_seed(230)
|
|
||||||
mps_y = torch.randn(5, device='mps')
|
|
||||||
# seed values were the same, so the random tensor contents should match
|
|
||||||
self.assertEqual(mps_x, mps_y)
|
|
||||||
|
|
||||||
# save the default generator's state to restore it later
|
|
||||||
g_state = torch.mps.get_rng_state()
|
|
||||||
|
|
||||||
# generate random numbers without seeding
|
|
||||||
mps_x = torch.randn(5, device='mps')
|
|
||||||
# in this case, the random results must differ from the last generated random results
|
|
||||||
self.assertNotEqual(mps_x, mps_y)
|
|
||||||
|
|
||||||
# restore the previously saved state, and the results should match again
|
|
||||||
torch.mps.set_rng_state(g_state)
|
|
||||||
mps_x = torch.randn(5, device='mps')
|
|
||||||
self.assertEqual(mps_x, mps_y)
|
|
||||||
|
|
||||||
def test_device_synchronize(self):
|
|
||||||
# just running some ops each followed by a synchronize to wait for
|
|
||||||
# MPS stream to finish running each of them
|
|
||||||
net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
|
|
||||||
.to(device='mps', dtype=torch.float)
|
|
||||||
|
|
||||||
x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
|
|
||||||
torch.mps.synchronize()
|
|
||||||
x = net1(x)
|
|
||||||
torch.mps.synchronize()
|
|
||||||
x.backward(torch.randn_like(x))
|
|
||||||
torch.mps.synchronize()
|
|
||||||
|
|
||||||
# Test random_.to and random_.from
|
# Test random_.to and random_.from
|
||||||
def test_random(self):
|
def test_random(self):
|
||||||
def helper(shape, low, high, dtype=torch.int32):
|
def helper(shape, low, high, dtype=torch.int32):
|
||||||
|
|
|
||||||
|
|
@ -903,6 +903,8 @@ def _disabled_torch_function_impl(func: Callable, types: Iterable[Type], args: T
|
||||||
def _disabled_torch_dispatch_impl(func: Callable, types: Iterable[Type], args: Tuple, kwargs: Dict) -> Any: ... # THPModule_disable_dispatch_function
|
def _disabled_torch_dispatch_impl(func: Callable, types: Iterable[Type], args: Tuple, kwargs: Dict) -> Any: ... # THPModule_disable_dispatch_function
|
||||||
def _get_linalg_preferred_backend() -> torch._C._LinalgBackend: ...
|
def _get_linalg_preferred_backend() -> torch._C._LinalgBackend: ...
|
||||||
def _set_linalg_preferred_backend(arg: torch._C._LinalgBackend): ...
|
def _set_linalg_preferred_backend(arg: torch._C._LinalgBackend): ...
|
||||||
|
def _is_mps_available() -> _bool: ...
|
||||||
|
def _is_mps_on_macos_13_or_newer() -> _bool: ...
|
||||||
class _LinalgBackend:
|
class _LinalgBackend:
|
||||||
Default: _LinalgBackend
|
Default: _LinalgBackend
|
||||||
Cusolver: _LinalgBackend
|
Cusolver: _LinalgBackend
|
||||||
|
|
@ -1198,12 +1200,6 @@ class _TensorBase(metaclass=_TensorMeta):
|
||||||
# Defined in torch/csrc/multiprocessing/init.cpp
|
# Defined in torch/csrc/multiprocessing/init.cpp
|
||||||
def _multiprocessing_init() -> None: ...
|
def _multiprocessing_init() -> None: ...
|
||||||
|
|
||||||
# Defined in torch/csrc/mps/Module.cpp
|
|
||||||
def _mps_synchronize() -> None: ...
|
|
||||||
def _mps_get_default_generator() -> Generator: ...
|
|
||||||
def _is_mps_available() -> _bool: ...
|
|
||||||
def _is_mps_on_macos_13_or_newer() -> _bool: ...
|
|
||||||
|
|
||||||
# Defined in torch/csrc/cuda/Module.cpp
|
# Defined in torch/csrc/cuda/Module.cpp
|
||||||
def _cuda_getCurrentStream(device: _int) -> Tuple: ...
|
def _cuda_getCurrentStream(device: _int) -> Tuple: ...
|
||||||
def _cuda_getCurrentRawStream(device: _int) -> _int: ...
|
def _cuda_getCurrentRawStream(device: _int) -> _int: ...
|
||||||
|
|
|
||||||
|
|
@ -60,7 +60,6 @@
|
||||||
#include <torch/csrc/jit/serialization/pickler.h>
|
#include <torch/csrc/jit/serialization/pickler.h>
|
||||||
#include <torch/csrc/lazy/python/init.h>
|
#include <torch/csrc/lazy/python/init.h>
|
||||||
#include <torch/csrc/monitor/python_init.h>
|
#include <torch/csrc/monitor/python_init.h>
|
||||||
#include <torch/csrc/mps/Module.h>
|
|
||||||
#include <torch/csrc/multiprocessing/init.h>
|
#include <torch/csrc/multiprocessing/init.h>
|
||||||
#include <torch/csrc/onnx/init.h>
|
#include <torch/csrc/onnx/init.h>
|
||||||
#include <torch/csrc/profiler/python/init.h>
|
#include <torch/csrc/profiler/python/init.h>
|
||||||
|
|
@ -88,6 +87,10 @@
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(USE_MPS)
|
||||||
|
#include <ATen/mps/MPSDevice.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
#if defined(USE_VALGRIND)
|
#if defined(USE_VALGRIND)
|
||||||
#include <callgrind.h>
|
#include <callgrind.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
@ -1268,7 +1271,6 @@ PyObject* initModule() {
|
||||||
THPUtils_addPyMethodDefs(methods, DataLoaderMethods);
|
THPUtils_addPyMethodDefs(methods, DataLoaderMethods);
|
||||||
THPUtils_addPyMethodDefs(methods, torch::autograd::python_functions());
|
THPUtils_addPyMethodDefs(methods, torch::autograd::python_functions());
|
||||||
THPUtils_addPyMethodDefs(methods, torch::multiprocessing::python_functions());
|
THPUtils_addPyMethodDefs(methods, torch::multiprocessing::python_functions());
|
||||||
THPUtils_addPyMethodDefs(methods, torch::mps::python_functions());
|
|
||||||
#ifdef USE_CUDA
|
#ifdef USE_CUDA
|
||||||
THPUtils_addPyMethodDefs(methods, THCPModule_methods());
|
THPUtils_addPyMethodDefs(methods, THCPModule_methods());
|
||||||
#endif
|
#endif
|
||||||
|
|
@ -1591,6 +1593,15 @@ Call this whenever a new thread is created in order to propagate values from
|
||||||
|
|
||||||
ASSERT_TRUE(set_module_attr("has_cuda", has_cuda));
|
ASSERT_TRUE(set_module_attr("has_cuda", has_cuda));
|
||||||
ASSERT_TRUE(set_module_attr("has_mps", has_mps));
|
ASSERT_TRUE(set_module_attr("has_mps", has_mps));
|
||||||
|
py_module.def("_is_mps_available", []() { return at::hasMPS(); });
|
||||||
|
py_module.def("_is_mps_on_macos_13_or_newer", []() {
|
||||||
|
#ifdef USE_MPS
|
||||||
|
return at::mps::is_macos_13_or_newer();
|
||||||
|
#else
|
||||||
|
return false;
|
||||||
|
#endif
|
||||||
|
});
|
||||||
|
|
||||||
ASSERT_TRUE(
|
ASSERT_TRUE(
|
||||||
set_module_attr("has_mkldnn", at::hasMKLDNN() ? Py_True : Py_False));
|
set_module_attr("has_mkldnn", at::hasMKLDNN() ? Py_True : Py_False));
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,68 +0,0 @@
|
||||||
#include <ATen/ATen.h>
|
|
||||||
#include <torch/csrc/Generator.h>
|
|
||||||
#include <torch/csrc/python_headers.h>
|
|
||||||
#include <torch/csrc/utils/python_numbers.h>
|
|
||||||
|
|
||||||
namespace torch {
|
|
||||||
namespace mps {
|
|
||||||
|
|
||||||
static PyObject* MPSModule_getDefaultMPSGenerator(
|
|
||||||
PyObject* _unused,
|
|
||||||
PyObject* noargs) {
|
|
||||||
HANDLE_TH_ERRORS
|
|
||||||
return THPGenerator_initDefaultGenerator(
|
|
||||||
at::detail::getMPSHooks().getDefaultMPSGenerator());
|
|
||||||
END_HANDLE_TH_ERRORS
|
|
||||||
}
|
|
||||||
|
|
||||||
static PyObject* MPSModule_isAvailable(PyObject* _unused, PyObject* noargs) {
|
|
||||||
HANDLE_TH_ERRORS
|
|
||||||
if (at::detail::getMPSHooks().hasMPS()) {
|
|
||||||
Py_RETURN_TRUE;
|
|
||||||
} else {
|
|
||||||
Py_RETURN_FALSE;
|
|
||||||
}
|
|
||||||
END_HANDLE_TH_ERRORS
|
|
||||||
}
|
|
||||||
|
|
||||||
static PyObject* MPSModule_isMacOS13orNewer(
|
|
||||||
PyObject* _unused,
|
|
||||||
PyObject* noargs) {
|
|
||||||
HANDLE_TH_ERRORS
|
|
||||||
if (at::detail::getMPSHooks().isOnMacOS13orNewer()) {
|
|
||||||
Py_RETURN_TRUE;
|
|
||||||
} else {
|
|
||||||
Py_RETURN_FALSE;
|
|
||||||
}
|
|
||||||
END_HANDLE_TH_ERRORS
|
|
||||||
}
|
|
||||||
|
|
||||||
static PyObject* MPSModule_synchronize(PyObject* _unused, PyObject* noargs) {
|
|
||||||
HANDLE_TH_ERRORS
|
|
||||||
at::detail::getMPSHooks().deviceSynchronize();
|
|
||||||
Py_RETURN_NONE;
|
|
||||||
END_HANDLE_TH_ERRORS
|
|
||||||
}
|
|
||||||
|
|
||||||
// NOLINTNEXTLINE(modernize-avoid-c-arrays,
|
|
||||||
// cppcoreguidelines-avoid-non-const-global-variables,
|
|
||||||
// cppcoreguidelines-avoid-c-arrays)
|
|
||||||
static struct PyMethodDef _MPSModule_methods[] = {
|
|
||||||
{"_mps_synchronize", MPSModule_synchronize, METH_NOARGS, nullptr},
|
|
||||||
{"_is_mps_available", MPSModule_isAvailable, METH_NOARGS, nullptr},
|
|
||||||
{"_is_mps_on_macos_13_or_newer",
|
|
||||||
MPSModule_isMacOS13orNewer,
|
|
||||||
METH_NOARGS,
|
|
||||||
nullptr},
|
|
||||||
{"_mps_get_default_generator",
|
|
||||||
MPSModule_getDefaultMPSGenerator,
|
|
||||||
METH_NOARGS,
|
|
||||||
nullptr},
|
|
||||||
{nullptr}};
|
|
||||||
|
|
||||||
PyMethodDef* python_functions() {
|
|
||||||
return _MPSModule_methods;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mps
|
|
||||||
} // namespace torch
|
|
||||||
|
|
@ -1,11 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <torch/csrc/python_headers.h>
|
|
||||||
|
|
||||||
namespace torch {
|
|
||||||
namespace mps {
|
|
||||||
|
|
||||||
PyMethodDef* python_functions();
|
|
||||||
|
|
||||||
} // namespace mps
|
|
||||||
} // namespace torch
|
|
||||||
|
|
@ -1,53 +0,0 @@
|
||||||
r"""
|
|
||||||
This package enables an interface for accessing MPS backend in python
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
from .. import Tensor
|
|
||||||
|
|
||||||
_default_mps_generator: torch._C.Generator = None # type: ignore[assignment]
|
|
||||||
|
|
||||||
# local helper function (not public or exported)
|
|
||||||
def _get_default_mps_generator() -> torch._C.Generator:
|
|
||||||
global _default_mps_generator
|
|
||||||
if _default_mps_generator is None:
|
|
||||||
_default_mps_generator = torch._C._mps_get_default_generator()
|
|
||||||
return _default_mps_generator
|
|
||||||
|
|
||||||
def synchronize() -> None:
|
|
||||||
r"""Waits for all kernels in all streams on a MPS device to complete."""
|
|
||||||
return torch._C._mps_synchronize()
|
|
||||||
|
|
||||||
def get_rng_state() -> Tensor:
|
|
||||||
r"""Returns the random number generator state as a ByteTensor."""
|
|
||||||
return _get_default_mps_generator().get_state()
|
|
||||||
|
|
||||||
def set_rng_state(new_state: Tensor) -> None:
|
|
||||||
r"""Sets the random number generator state.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
new_state (torch.ByteTensor): The desired state
|
|
||||||
"""
|
|
||||||
new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
|
|
||||||
_get_default_mps_generator().set_state(new_state_copy)
|
|
||||||
|
|
||||||
def manual_seed(seed: int) -> None:
|
|
||||||
r"""Sets the seed for generating random numbers.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seed (int): The desired seed.
|
|
||||||
"""
|
|
||||||
# the torch.mps.manual_seed() can be called from the global
|
|
||||||
# torch.manual_seed() in torch/random.py. So we need to make
|
|
||||||
# sure mps is available (otherwise we just return without
|
|
||||||
# erroring out)
|
|
||||||
if not torch._C._is_mps_available():
|
|
||||||
return
|
|
||||||
seed = int(seed)
|
|
||||||
_get_default_mps_generator().manual_seed(seed)
|
|
||||||
|
|
||||||
def seed() -> None:
|
|
||||||
r"""Sets the seed for generating random numbers to a random number."""
|
|
||||||
_get_default_mps_generator().seed()
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'get_rng_state', 'manual_seed', 'seed', 'set_rng_state', 'synchronize']
|
|
||||||
|
|
@ -39,9 +39,6 @@ def manual_seed(seed) -> torch._C.Generator:
|
||||||
if not torch.cuda._is_in_bad_fork():
|
if not torch.cuda._is_in_bad_fork():
|
||||||
torch.cuda.manual_seed_all(seed)
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
import torch.mps
|
|
||||||
torch.mps.manual_seed(seed)
|
|
||||||
|
|
||||||
return default_generator.manual_seed(seed)
|
return default_generator.manual_seed(seed)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -55,9 +52,6 @@ def seed() -> int:
|
||||||
if not torch.cuda._is_in_bad_fork():
|
if not torch.cuda._is_in_bad_fork():
|
||||||
torch.cuda.manual_seed_all(seed)
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
import torch.mps
|
|
||||||
torch.mps.manual_seed(seed)
|
|
||||||
|
|
||||||
return seed
|
return seed
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user