mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[User-streams] Make torch.Event weakref compatible (#164522)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164522 Approved by: https://github.com/williamwen42 ghstack dependencies: #162903, #164343, #164344, #164507, #162901, #164304
This commit is contained in:
parent
bfc2050db9
commit
cde81e92b9
|
|
@ -234,27 +234,6 @@ class InPlaceCompilationTests(TestCase):
|
||||||
with self.assertRaises(IndexError):
|
with self.assertRaises(IndexError):
|
||||||
fn(torch.randn(10), 99)
|
fn(torch.randn(10), 99)
|
||||||
|
|
||||||
def test_list_bad_weakref(self):
|
|
||||||
import weakref
|
|
||||||
|
|
||||||
a = torch.Event()
|
|
||||||
with self.assertRaises(TypeError):
|
|
||||||
weakref.ref(a)
|
|
||||||
|
|
||||||
@torch.compile(backend="eager")
|
|
||||||
class Mod(torch.nn.Module):
|
|
||||||
def __init__(self, event):
|
|
||||||
super().__init__()
|
|
||||||
self.event = event
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return x * int(self.event.query())
|
|
||||||
|
|
||||||
e = torch.Event()
|
|
||||||
m = Mod(e)
|
|
||||||
a = torch.randn(10)
|
|
||||||
self.assertEqual(m(a), a)
|
|
||||||
|
|
||||||
|
|
||||||
# The private variants of the below functions are extensively tested
|
# The private variants of the below functions are extensively tested
|
||||||
# So as long as the signatures match we're good
|
# So as long as the signatures match we're good
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,10 @@ class TestStreams(torch._dynamo.test_case.TestCase):
|
||||||
s = torch.Stream()
|
s = torch.Stream()
|
||||||
weakref.ref(s)
|
weakref.ref(s)
|
||||||
|
|
||||||
|
def test_event_weakref(self):
|
||||||
|
e = torch.Event()
|
||||||
|
weakref.ref(e)
|
||||||
|
|
||||||
@requires_cuda
|
@requires_cuda
|
||||||
def test_run_opcheck(self):
|
def test_run_opcheck(self):
|
||||||
from torch._dynamo.variables.streams import fork_stream, join_stream
|
from torch._dynamo.variables.streams import fork_stream, join_stream
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,7 @@ static PyObject* THPEvent_pynew(
|
||||||
}
|
}
|
||||||
|
|
||||||
THPEvent* self = reinterpret_cast<THPEvent*>(ptr.get());
|
THPEvent* self = reinterpret_cast<THPEvent*>(ptr.get());
|
||||||
|
self->weakreflist = nullptr;
|
||||||
|
|
||||||
// TODO: blocking and interprocess are not supported yet. To support them, the
|
// TODO: blocking and interprocess are not supported yet. To support them, the
|
||||||
// flag system of c10::Event needs to be refactored. C10::Event should also
|
// flag system of c10::Event needs to be refactored. C10::Event should also
|
||||||
|
|
@ -73,6 +74,7 @@ PyObject* THPEvent_new(c10::DeviceType device_type, c10::EventFlag flag) {
|
||||||
auto self = THPObjectPtr{type->tp_alloc(type, 0)};
|
auto self = THPObjectPtr{type->tp_alloc(type, 0)};
|
||||||
TORCH_CHECK(self, "Failed to allocate memory for Event");
|
TORCH_CHECK(self, "Failed to allocate memory for Event");
|
||||||
auto self_ = reinterpret_cast<THPEvent*>(self.get());
|
auto self_ = reinterpret_cast<THPEvent*>(self.get());
|
||||||
|
self_->weakreflist = nullptr;
|
||||||
new (&self_->event) c10::Event(device_type, flag);
|
new (&self_->event) c10::Event(device_type, flag);
|
||||||
return self.release();
|
return self.release();
|
||||||
}
|
}
|
||||||
|
|
@ -82,6 +84,7 @@ static void THPEvent_dealloc(THPEvent* self) {
|
||||||
pybind11::gil_scoped_release no_gil{};
|
pybind11::gil_scoped_release no_gil{};
|
||||||
self->event.~Event();
|
self->event.~Event();
|
||||||
}
|
}
|
||||||
|
PyObject_ClearWeakRefs((PyObject*)self);
|
||||||
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
|
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -282,7 +285,8 @@ static PyMethodDef THPEvent_methods[] = {
|
||||||
{"synchronize", THPEvent_synchronize, METH_NOARGS, nullptr},
|
{"synchronize", THPEvent_synchronize, METH_NOARGS, nullptr},
|
||||||
{"ipc_handle", THPEvent_ipc_handle, METH_NOARGS, nullptr},
|
{"ipc_handle", THPEvent_ipc_handle, METH_NOARGS, nullptr},
|
||||||
{nullptr}};
|
{nullptr}};
|
||||||
|
#pragma GCC diagnostic push
|
||||||
|
#pragma GCC diagnostic ignored "-Winvalid-offsetof"
|
||||||
PyTypeObject THPEventType = {
|
PyTypeObject THPEventType = {
|
||||||
PyVarObject_HEAD_INIT(nullptr, 0)
|
PyVarObject_HEAD_INIT(nullptr, 0)
|
||||||
"torch.Event", /* tp_name */
|
"torch.Event", /* tp_name */
|
||||||
|
|
@ -308,7 +312,7 @@ PyTypeObject THPEventType = {
|
||||||
nullptr, /* tp_traverse */
|
nullptr, /* tp_traverse */
|
||||||
nullptr, /* tp_clear */
|
nullptr, /* tp_clear */
|
||||||
nullptr, /* tp_richcompare */
|
nullptr, /* tp_richcompare */
|
||||||
0, /* tp_weaklistoffset */
|
offsetof(THPEvent, weakreflist), /* tp_weaklistoffset */
|
||||||
nullptr, /* tp_iter */
|
nullptr, /* tp_iter */
|
||||||
nullptr, /* tp_iternext */
|
nullptr, /* tp_iternext */
|
||||||
THPEvent_methods, /* tp_methods */
|
THPEvent_methods, /* tp_methods */
|
||||||
|
|
@ -323,6 +327,7 @@ PyTypeObject THPEventType = {
|
||||||
nullptr, /* tp_alloc */
|
nullptr, /* tp_alloc */
|
||||||
THPEvent_pynew, /* tp_new */
|
THPEvent_pynew, /* tp_new */
|
||||||
};
|
};
|
||||||
|
#pragma GCC diagnostic pop
|
||||||
|
|
||||||
void THPEvent_init(PyObject* module) {
|
void THPEvent_init(PyObject* module) {
|
||||||
THPEventClass = &THPEventType;
|
THPEventClass = &THPEventType;
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@
|
||||||
struct TORCH_API THPEvent {
|
struct TORCH_API THPEvent {
|
||||||
PyObject_HEAD
|
PyObject_HEAD
|
||||||
c10::Event event;
|
c10::Event event;
|
||||||
|
PyObject* weakreflist;
|
||||||
};
|
};
|
||||||
TORCH_API extern PyTypeObject* THPEventClass;
|
TORCH_API extern PyTypeObject* THPEventClass;
|
||||||
TORCH_API extern PyTypeObject THPEventType;
|
TORCH_API extern PyTypeObject THPEventType;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user