diff --git a/test/dynamo/test_compile.py b/test/dynamo/test_compile.py index 7df0ba2f1d3..1f7290c51dd 100644 --- a/test/dynamo/test_compile.py +++ b/test/dynamo/test_compile.py @@ -234,27 +234,6 @@ class InPlaceCompilationTests(TestCase): with self.assertRaises(IndexError): fn(torch.randn(10), 99) - def test_list_bad_weakref(self): - import weakref - - a = torch.Event() - with self.assertRaises(TypeError): - weakref.ref(a) - - @torch.compile(backend="eager") - class Mod(torch.nn.Module): - def __init__(self, event): - super().__init__() - self.event = event - - def forward(self, x): - return x * int(self.event.query()) - - e = torch.Event() - m = Mod(e) - a = torch.randn(10) - self.assertEqual(m(a), a) - # The private variants of the below functions are extensively tested # So as long as the signatures match we're good diff --git a/test/dynamo/test_streams.py b/test/dynamo/test_streams.py index 694b505b0c1..807a4d71dc1 100644 --- a/test/dynamo/test_streams.py +++ b/test/dynamo/test_streams.py @@ -20,6 +20,10 @@ class TestStreams(torch._dynamo.test_case.TestCase): s = torch.Stream() weakref.ref(s) + def test_event_weakref(self): + e = torch.Event() + weakref.ref(e) + @requires_cuda def test_run_opcheck(self): from torch._dynamo.variables.streams import fork_stream, join_stream diff --git a/torch/csrc/Event.cpp b/torch/csrc/Event.cpp index 319eee8a41c..fd7d72228fc 100644 --- a/torch/csrc/Event.cpp +++ b/torch/csrc/Event.cpp @@ -49,6 +49,7 @@ static PyObject* THPEvent_pynew( } THPEvent* self = reinterpret_cast(ptr.get()); + self->weakreflist = nullptr; // TODO: blocking and interprocess are not supported yet. To support them, the // flag system of c10::Event needs to be refactored. C10::Event should also @@ -73,6 +74,7 @@ PyObject* THPEvent_new(c10::DeviceType device_type, c10::EventFlag flag) { auto self = THPObjectPtr{type->tp_alloc(type, 0)}; TORCH_CHECK(self, "Failed to allocate memory for Event"); auto self_ = reinterpret_cast(self.get()); + self_->weakreflist = nullptr; new (&self_->event) c10::Event(device_type, flag); return self.release(); } @@ -82,6 +84,7 @@ static void THPEvent_dealloc(THPEvent* self) { pybind11::gil_scoped_release no_gil{}; self->event.~Event(); } + PyObject_ClearWeakRefs((PyObject*)self); Py_TYPE(self)->tp_free(reinterpret_cast(self)); } @@ -282,7 +285,8 @@ static PyMethodDef THPEvent_methods[] = { {"synchronize", THPEvent_synchronize, METH_NOARGS, nullptr}, {"ipc_handle", THPEvent_ipc_handle, METH_NOARGS, nullptr}, {nullptr}}; - +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Winvalid-offsetof" PyTypeObject THPEventType = { PyVarObject_HEAD_INIT(nullptr, 0) "torch.Event", /* tp_name */ @@ -308,7 +312,7 @@ PyTypeObject THPEventType = { nullptr, /* tp_traverse */ nullptr, /* tp_clear */ nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ + offsetof(THPEvent, weakreflist), /* tp_weaklistoffset */ nullptr, /* tp_iter */ nullptr, /* tp_iternext */ THPEvent_methods, /* tp_methods */ @@ -323,6 +327,7 @@ PyTypeObject THPEventType = { nullptr, /* tp_alloc */ THPEvent_pynew, /* tp_new */ }; +#pragma GCC diagnostic pop void THPEvent_init(PyObject* module) { THPEventClass = &THPEventType; diff --git a/torch/csrc/Event.h b/torch/csrc/Event.h index 3bbc7d37939..7dfc7bb426d 100644 --- a/torch/csrc/Event.h +++ b/torch/csrc/Event.h @@ -7,6 +7,7 @@ struct TORCH_API THPEvent { PyObject_HEAD c10::Event event; + PyObject* weakreflist; }; TORCH_API extern PyTypeObject* THPEventClass; TORCH_API extern PyTypeObject THPEventType;