mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[user-streams] Make device-agnostic streams weakref compatible (#164304)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164304 Approved by: https://github.com/williamwen42, https://github.com/colesbury ghstack dependencies: #162903, #164343, #164344, #164507, #162901
This commit is contained in:
parent
c5701d0ab5
commit
bfc2050db9
|
|
@ -1,4 +1,5 @@
|
|||
# Owner(s): ["module: dynamo"]
|
||||
import weakref
|
||||
|
||||
import torch
|
||||
import torch._dynamo.test_case
|
||||
|
|
@ -15,6 +16,10 @@ class TestStreams(torch._dynamo.test_case.TestCase):
|
|||
def tearDownClass(cls):
|
||||
super().tearDownClass()
|
||||
|
||||
def test_stream_weakref(self):
|
||||
s = torch.Stream()
|
||||
weakref.ref(s)
|
||||
|
||||
@requires_cuda
|
||||
def test_run_opcheck(self):
|
||||
from torch._dynamo.variables.streams import fork_stream, join_stream
|
||||
|
|
|
|||
|
|
@ -95,6 +95,7 @@ static PyObject* THPStream_pynew(
|
|||
self->device_index = static_cast<int64_t>(stream_opt->device_index());
|
||||
self->device_type = static_cast<int64_t>(stream_opt->device_type());
|
||||
self->context = nullptr;
|
||||
self->weakreflist = nullptr;
|
||||
|
||||
return static_cast<PyObject*>(ptr.release());
|
||||
END_HANDLE_TH_ERRORS
|
||||
|
|
@ -114,11 +115,13 @@ PyObject* THPStream_Wrap(const c10::Stream& stream) {
|
|||
self->device_index = static_cast<int64_t>(stream.device_index());
|
||||
self->device_type = static_cast<int64_t>(stream.device_type());
|
||||
self->context = nullptr;
|
||||
self->weakreflist = nullptr;
|
||||
return ptr.release();
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static void THPStream_dealloc(THPStream* self) {
|
||||
PyObject_ClearWeakRefs((PyObject*)self);
|
||||
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
|
||||
}
|
||||
|
||||
|
|
@ -444,7 +447,7 @@ static PyTypeObject THPStreamType = {
|
|||
nullptr, /* tp_traverse */
|
||||
nullptr, /* tp_clear */
|
||||
THPStream_richcompare, /* tp_richcompare */
|
||||
0, /* tp_weaklistoffset */
|
||||
offsetof(THPStream, weakreflist), /* tp_weaklistoffset */
|
||||
nullptr, /* tp_iter */
|
||||
nullptr, /* tp_iternext */
|
||||
// NOLINTNEXTLINE(*const-cast)
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ struct THPStream {
|
|||
int64_t device_index;
|
||||
// Used to switch stream context management, initialized lazily.
|
||||
PyObject* context;
|
||||
PyObject* weakreflist;
|
||||
};
|
||||
extern TORCH_API PyTypeObject* THPStreamClass;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user