[OpenReg] Add _lazy_init and rng_state support for OpenReg (#151914)

As the title states.

**Changes**:
- Add get_rng_state & set_rng_state support for OpenReg (see the usage sketch below)
- Add _lazy_init support for OpenReg
- Remove redundant code in cuda/Module.cpp
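
For context, a minimal usage sketch of the new Python surface, based on the test added in this PR (it assumes the pytorch_openreg test extension is installed and has already registered the `openreg` device module):

```python
import torch
import pytorch_openreg  # noqa: F401  # registers the "openreg" (PrivateUse1) backend

state = torch.openreg.get_rng_state(0)   # snapshot the default generator of device 0
# ... run work that draws random numbers on the openreg device ...
torch.openreg.set_rng_state(state, 0)    # restore it for reproducible results
```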
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151914
Approved by: https://github.com/albanD
Author: FFFrog
Date: 2025-05-04 12:45:07 +08:00
Committed by: PyTorch MergeBot
Parent: c8bac51ec1
Commit: fd8fd01d25
4 changed files with 78 additions and 12 deletions


@@ -44,19 +44,47 @@ def _create_module():
         return torch.accelerator.current_device_index()

     def get_rng_state(device):
-        return torch.empty(4, 4, device="openreg")
+        if isinstance(device, str):
+            device = torch.device(device)
+        elif isinstance(device, int):
+            device = torch.device("openreg", device)
+        idx = device.index
+        if idx is None:
+            idx = current_device()
+        default_generator = pytorch_openreg._C._get_default_generator(idx)
+        return default_generator.get_state()

     def set_rng_state(new_state, device):
-        pass
+        if isinstance(device, str):
+            device = torch.device(device)
+        elif isinstance(device, int):
+            device = torch.device("openreg", device)
+        idx = device.index
+        if idx is None:
+            idx = current_device()
+        default_generator = pytorch_openreg._C._get_default_generator(idx)
+        default_generator.set_state(new_state)
+
+    def is_initialized():
+        return module._initialized
+
+    def _lazy_init():
+        if is_initialized():
+            return
+        pytorch_openreg._C._init()
+        module._initialized = True
+
+    module.is_available = is_available  # type: ignore[assignment]
+
+    module._initialized = False  # type: ignore[assignment]
+    module._lazy_init = _lazy_init  # type: ignore[assignment]
+    module.is_initialized = is_initialized  # type: ignore[assignment]

     module.device = device  # type: ignore[assignment]
     module.device_count = device_count  # type: ignore[assignment]
-    module.is_available = is_available  # type: ignore[assignment]
     module.current_device = current_device  # type: ignore[assignment]
     module.get_rng_state = get_rng_state  # type: ignore[assignment]
     module.set_rng_state = set_rng_state  # type: ignore[assignment]
-    module._lazy_init = lambda: None  # type: ignore[assignment]
-    module.is_initialized = lambda: True  # type: ignore[assignment]

     return module
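
A short sketch of how the helpers added in the hunk above behave (illustrative only; it assumes the pytorch_openreg extension is importable and that the current openreg device is index 0):

```python
import torch
import pytorch_openreg  # noqa: F401

# get_rng_state/set_rng_state accept an int, a device string, or a torch.device;
# a device without an index falls back to current_device().
s_int = torch.openreg.get_rng_state(0)
s_str = torch.openreg.get_rng_state("openreg:0")
s_dev = torch.openreg.get_rng_state(torch.device("openreg", 0))

# _lazy_init() initializes the backend at most once and flips the module flag.
if not torch.openreg.is_initialized():
    torch.openreg._lazy_init()  # calls pytorch_openreg._C._init()
assert torch.openreg.is_initialized()
```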


@@ -1,12 +1,45 @@
 #include "OpenReg.h"

-static struct PyModuleDef openreg_C_module = {
-    PyModuleDef_HEAD_INIT,
-    "pytorch_openreg._C",
-    nullptr,
-    -1,
+#include <ATen/Context.h>
+
+#include <torch/csrc/Exceptions.h>
+#include <torch/csrc/utils.h>
+#include <torch/csrc/utils/object_ptr.h>
+#include <torch/csrc/utils/python_numbers.h>
+
+static PyObject* _initExtension(PyObject* self, PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  at::globalContext().lazyInitDevice(c10::DeviceType::PrivateUse1);
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) {
+  HANDLE_TH_ERRORS
+  TORCH_CHECK(
+      THPUtils_checkLong(arg),
+      "_get_default_generator expects an int, but got ",
+      THPUtils_typename(arg));
+  auto idx = static_cast<int>(THPUtils_unpackLong(arg));
+  return THPGenerator_initDefaultGenerator(
+      at::globalContext().defaultGenerator(
+          c10::Device(c10::DeviceType::PrivateUse1, idx)));
+  END_HANDLE_TH_ERRORS
+}
+
+static PyMethodDef methods[] = {
+    {"_init", _initExtension, METH_NOARGS, nullptr},
+    {"_get_default_generator", _getDefaultGenerator, METH_O, nullptr},
+    {nullptr, nullptr, 0, nullptr}
 };

+static struct PyModuleDef openreg_C_module =
+    {PyModuleDef_HEAD_INIT, "pytorch_openreg._C", nullptr, -1, methods};
+
 PyMODINIT_FUNC PyInit__C(void) {
   PyObject* mod = PyModule_Create(&openreg_C_module);
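
Roughly how the Python module shown earlier consumes these two bindings (a sketch of the call chain, not additional code in this PR):

```python
import pytorch_openreg

pytorch_openreg._C._init()                          # lazily initializes the PrivateUse1 device
gen = pytorch_openreg._C._get_default_generator(0)  # torch.Generator for device index 0
state = gen.get_state()                             # what torch.openreg.get_rng_state(0) returns
gen.set_state(state)                                # what torch.openreg.set_rng_state(state, 0) does
```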


@@ -77,6 +77,11 @@ class TestOpenReg(TestCase):
         self.assertEqual(generator.device.type, "openreg")
         self.assertEqual(generator.device.index, 1)

+    # TODO(FFFrog): Add more check for rng_state
+    def test_rng_state(self):
+        state = torch.openreg.get_rng_state(0)
+        torch.openreg.set_rng_state(state, 0)
+
     @skipIfTorchDynamo("unsupported aten.is_pinned.default")
     def test_pin_memory(self):
         cpu_a = torch.randn(10)
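
Per the TODO in the added test, one possible follow-up check (hypothetical, not part of this PR; it assumes the state tensor round-trips through set_state/get_state unchanged):

```python
def test_rng_state_round_trip(self):
    state = torch.openreg.get_rng_state(0)
    torch.openreg.set_rng_state(state, 0)
    # Assumption: restoring a state and reading it back yields identical bytes.
    self.assertEqual(torch.openreg.get_rng_state(0), state)
```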


@@ -1522,10 +1522,10 @@ static PyObject* THCPModule_initExtension(PyObject* self, PyObject* noargs) {
   auto num_gpus = c10::cuda::device_count();
   auto default_cuda_generators = PyTuple_New(static_cast<Py_ssize_t>(num_gpus));
   for (const auto i : c10::irange(num_gpus)) {
-    auto cast_gen = (THPGenerator*)THPGenerator_initDefaultGenerator(
+    auto cast_gen = THPGenerator_initDefaultGenerator(
         at::cuda::detail::getDefaultCUDAGenerator(i));
     // This reference is meant to be given away, so no need to incref here.
-    PyTuple_SetItem(default_cuda_generators, i, (PyObject*)cast_gen);
+    PyTuple_SetItem(default_cuda_generators, i, cast_gen);
   }
   set_module_attr("default_generators", default_cuda_generators);
   bindGetDeviceProperties(m);