mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[OpenReg] Add _lazy_init and rng_state support for OpenReg (#151914)
As the title stated. **Changes**: - Add get_rng_state & set_rng_state support for OpenReg - Add _lazy_init support for OpenReg - Remove redundant code for cuda/Module.cpp Pull Request resolved: https://github.com/pytorch/pytorch/pull/151914 Approved by: https://github.com/albanD
This commit is contained in:
parent
c8bac51ec1
commit
fd8fd01d25
|
|
@ -44,19 +44,47 @@ def _create_module():
|
|||
return torch.accelerator.current_device_index()
|
||||
|
||||
def get_rng_state(device):
|
||||
return torch.empty(4, 4, device="openreg")
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
elif isinstance(device, int):
|
||||
device = torch.device("openreg", device)
|
||||
idx = device.index
|
||||
if idx is None:
|
||||
idx = current_device()
|
||||
default_generator = pytorch_openreg._C._get_default_generator(idx)
|
||||
return default_generator.get_state()
|
||||
|
||||
def set_rng_state(new_state, device):
|
||||
pass
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
elif isinstance(device, int):
|
||||
device = torch.device("openreg", device)
|
||||
idx = device.index
|
||||
if idx is None:
|
||||
idx = current_device()
|
||||
default_generator = pytorch_openreg._C._get_default_generator(idx)
|
||||
default_generator.set_state(new_state)
|
||||
|
||||
def is_initialized():
|
||||
return module._initialized
|
||||
|
||||
def _lazy_init():
|
||||
if is_initialized():
|
||||
return
|
||||
pytorch_openreg._C._init()
|
||||
module._initialized = True
|
||||
|
||||
module.is_available = is_available # type: ignore[assignment]
|
||||
|
||||
module._initialized = False # type: ignore[assignment]
|
||||
module._lazy_init = _lazy_init # type: ignore[assignment]
|
||||
module.is_initialized = is_initialized # type: ignore[assignment]
|
||||
|
||||
module.device = device # type: ignore[assignment]
|
||||
module.device_count = device_count # type: ignore[assignment]
|
||||
module.is_available = is_available # type: ignore[assignment]
|
||||
module.current_device = current_device # type: ignore[assignment]
|
||||
module.get_rng_state = get_rng_state # type: ignore[assignment]
|
||||
module.set_rng_state = set_rng_state # type: ignore[assignment]
|
||||
module._lazy_init = lambda: None # type: ignore[assignment]
|
||||
module.is_initialized = lambda: True # type: ignore[assignment]
|
||||
|
||||
return module
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,45 @@
|
|||
#include "OpenReg.h"
|
||||
|
||||
static struct PyModuleDef openreg_C_module = {
|
||||
PyModuleDef_HEAD_INIT,
|
||||
"pytorch_openreg._C",
|
||||
nullptr,
|
||||
-1,
|
||||
#include <ATen/Context.h>
|
||||
|
||||
#include <torch/csrc/Exceptions.h>
|
||||
#include <torch/csrc/utils.h>
|
||||
#include <torch/csrc/utils/object_ptr.h>
|
||||
#include <torch/csrc/utils/python_numbers.h>
|
||||
|
||||
static PyObject* _initExtension(PyObject* self, PyObject* noargs) {
|
||||
HANDLE_TH_ERRORS
|
||||
|
||||
at::globalContext().lazyInitDevice(c10::DeviceType::PrivateUse1);
|
||||
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
TORCH_CHECK(
|
||||
THPUtils_checkLong(arg),
|
||||
"_get_default_generator expects an int, but got ",
|
||||
THPUtils_typename(arg));
|
||||
auto idx = static_cast<int>(THPUtils_unpackLong(arg));
|
||||
|
||||
return THPGenerator_initDefaultGenerator(
|
||||
at::globalContext().defaultGenerator(
|
||||
c10::Device(c10::DeviceType::PrivateUse1, idx)));
|
||||
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyMethodDef methods[] = {
|
||||
{"_init", _initExtension, METH_NOARGS, nullptr},
|
||||
{"_get_default_generator", _getDefaultGenerator, METH_O, nullptr},
|
||||
{nullptr, nullptr, 0, nullptr}
|
||||
};
|
||||
|
||||
static struct PyModuleDef openreg_C_module =
|
||||
{PyModuleDef_HEAD_INIT, "pytorch_openreg._C", nullptr, -1, methods};
|
||||
|
||||
PyMODINIT_FUNC PyInit__C(void) {
|
||||
PyObject* mod = PyModule_Create(&openreg_C_module);
|
||||
|
||||
|
|
|
|||
|
|
@ -77,6 +77,11 @@ class TestOpenReg(TestCase):
|
|||
self.assertEqual(generator.device.type, "openreg")
|
||||
self.assertEqual(generator.device.index, 1)
|
||||
|
||||
# TODO(FFFrog): Add more check for rng_state
|
||||
def test_rng_state(self):
|
||||
state = torch.openreg.get_rng_state(0)
|
||||
torch.openreg.set_rng_state(state, 0)
|
||||
|
||||
@skipIfTorchDynamo("unsupported aten.is_pinned.default")
|
||||
def test_pin_memory(self):
|
||||
cpu_a = torch.randn(10)
|
||||
|
|
|
|||
|
|
@ -1522,10 +1522,10 @@ static PyObject* THCPModule_initExtension(PyObject* self, PyObject* noargs) {
|
|||
auto num_gpus = c10::cuda::device_count();
|
||||
auto default_cuda_generators = PyTuple_New(static_cast<Py_ssize_t>(num_gpus));
|
||||
for (const auto i : c10::irange(num_gpus)) {
|
||||
auto cast_gen = (THPGenerator*)THPGenerator_initDefaultGenerator(
|
||||
auto cast_gen = THPGenerator_initDefaultGenerator(
|
||||
at::cuda::detail::getDefaultCUDAGenerator(i));
|
||||
// This reference is meant to be given away, so no need to incref here.
|
||||
PyTuple_SetItem(default_cuda_generators, i, (PyObject*)cast_gen);
|
||||
PyTuple_SetItem(default_cuda_generators, i, cast_gen);
|
||||
}
|
||||
set_module_attr("default_generators", default_cuda_generators);
|
||||
bindGetDeviceProperties(m);
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user