[user-streams] Make device-agnostic streams weakref compatible (#164304)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164304
Approved by: https://github.com/williamwen42, https://github.com/colesbury
ghstack dependencies: #162903, #164343, #164344, #164507, #162901
This commit is contained in:
Michael Lazos 2025-10-28 21:36:15 -07:00 committed by PyTorch MergeBot
parent c5701d0ab5
commit bfc2050db9
3 changed files with 10 additions and 1 deletions

View File

@ -1,4 +1,5 @@
# Owner(s): ["module: dynamo"]
import weakref
import torch
import torch._dynamo.test_case
@ -15,6 +16,10 @@ class TestStreams(torch._dynamo.test_case.TestCase):
def tearDownClass(cls):
super().tearDownClass()
def test_stream_weakref(self):
s = torch.Stream()
weakref.ref(s)
@requires_cuda
def test_run_opcheck(self):
from torch._dynamo.variables.streams import fork_stream, join_stream

View File

@ -95,6 +95,7 @@ static PyObject* THPStream_pynew(
self->device_index = static_cast<int64_t>(stream_opt->device_index());
self->device_type = static_cast<int64_t>(stream_opt->device_type());
self->context = nullptr;
self->weakreflist = nullptr;
return static_cast<PyObject*>(ptr.release());
END_HANDLE_TH_ERRORS
@ -114,11 +115,13 @@ PyObject* THPStream_Wrap(const c10::Stream& stream) {
self->device_index = static_cast<int64_t>(stream.device_index());
self->device_type = static_cast<int64_t>(stream.device_type());
self->context = nullptr;
self->weakreflist = nullptr;
return ptr.release();
END_HANDLE_TH_ERRORS
}
static void THPStream_dealloc(THPStream* self) {
PyObject_ClearWeakRefs((PyObject*)self);
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}
@ -444,7 +447,7 @@ static PyTypeObject THPStreamType = {
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
THPStream_richcompare, /* tp_richcompare */
0, /* tp_weaklistoffset */
offsetof(THPStream, weakreflist), /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
// NOLINTNEXTLINE(*const-cast)

View File

@ -13,6 +13,7 @@ struct THPStream {
int64_t device_index;
// Used to switch stream context management, initialized lazily.
PyObject* context;
PyObject* weakreflist;
};
extern TORCH_API PyTypeObject* THPStreamClass;