[User-streams] Make torch.Event weakref compatible (#164522)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164522
Approved by: https://github.com/williamwen42
ghstack dependencies: #164304
This commit is contained in:
Michael Lazos 2025-10-29 11:57:18 -07:00 committed by PyTorch MergeBot
parent c3047938a0
commit c54e2c5b41
4 changed files with 13 additions and 23 deletions

View File

@ -234,27 +234,6 @@ class InPlaceCompilationTests(TestCase):
with self.assertRaises(IndexError):
fn(torch.randn(10), 99)
def test_list_bad_weakref(self):
import weakref
a = torch.Event()
with self.assertRaises(TypeError):
weakref.ref(a)
@torch.compile(backend="eager")
class Mod(torch.nn.Module):
def __init__(self, event):
super().__init__()
self.event = event
def forward(self, x):
return x * int(self.event.query())
e = torch.Event()
m = Mod(e)
a = torch.randn(10)
self.assertEqual(m(a), a)
# The private variants of the below functions are extensively tested
# So as long as the signatures match we're good

View File

@ -21,6 +21,11 @@ class TestStreams(torch._dynamo.test_case.TestCase):
s = torch.Stream()
weakref.ref(s)
@requires_cuda
def test_event_weakref(self):
e = torch.Event()
weakref.ref(e)
@requires_cuda
def test_run_opcheck(self):
from torch._dynamo.variables.streams import fork_stream, join_stream

View File

@ -49,6 +49,7 @@ static PyObject* THPEvent_pynew(
}
THPEvent* self = reinterpret_cast<THPEvent*>(ptr.get());
self->weakreflist = nullptr;
// TODO: blocking and interprocess are not supported yet. To support them, the
// flag system of c10::Event needs to be refactored. C10::Event should also
@ -73,6 +74,7 @@ PyObject* THPEvent_new(c10::DeviceType device_type, c10::EventFlag flag) {
auto self = THPObjectPtr{type->tp_alloc(type, 0)};
TORCH_CHECK(self, "Failed to allocate memory for Event");
auto self_ = reinterpret_cast<THPEvent*>(self.get());
self_->weakreflist = nullptr;
new (&self_->event) c10::Event(device_type, flag);
return self.release();
}
@ -82,6 +84,7 @@ static void THPEvent_dealloc(THPEvent* self) {
pybind11::gil_scoped_release no_gil{};
self->event.~Event();
}
PyObject_ClearWeakRefs((PyObject*)self);
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}
@ -282,7 +285,8 @@ static PyMethodDef THPEvent_methods[] = {
{"synchronize", THPEvent_synchronize, METH_NOARGS, nullptr},
{"ipc_handle", THPEvent_ipc_handle, METH_NOARGS, nullptr},
{nullptr}};
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Winvalid-offsetof"
PyTypeObject THPEventType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch.Event", /* tp_name */
@ -308,7 +312,7 @@ PyTypeObject THPEventType = {
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
nullptr, /* tp_richcompare */
0, /* tp_weaklistoffset */
offsetof(THPEvent, weakreflist), /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
THPEvent_methods, /* tp_methods */
@ -323,6 +327,7 @@ PyTypeObject THPEventType = {
nullptr, /* tp_alloc */
THPEvent_pynew, /* tp_new */
};
#pragma GCC diagnostic pop
void THPEvent_init(PyObject* module) {
THPEventClass = &THPEventType;

View File

@ -7,6 +7,7 @@
struct TORCH_API THPEvent {
PyObject_HEAD
c10::Event event;
PyObject* weakreflist;
};
TORCH_API extern PyTypeObject* THPEventClass;
TORCH_API extern PyTypeObject THPEventType;