mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[user-streams] Make device-agnostic streams weakref compatible (#164304)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164304 Approved by: https://github.com/williamwen42, https://github.com/colesbury ghstack dependencies: #162903, #164343, #164344, #164507, #162901
This commit is contained in:
parent
c5701d0ab5
commit
bfc2050db9
|
|
@ -1,4 +1,5 @@
|
||||||
# Owner(s): ["module: dynamo"]
|
# Owner(s): ["module: dynamo"]
|
||||||
|
import weakref
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch._dynamo.test_case
|
import torch._dynamo.test_case
|
||||||
|
|
@ -15,6 +16,10 @@ class TestStreams(torch._dynamo.test_case.TestCase):
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
super().tearDownClass()
|
super().tearDownClass()
|
||||||
|
|
||||||
|
def test_stream_weakref(self):
|
||||||
|
s = torch.Stream()
|
||||||
|
weakref.ref(s)
|
||||||
|
|
||||||
@requires_cuda
|
@requires_cuda
|
||||||
def test_run_opcheck(self):
|
def test_run_opcheck(self):
|
||||||
from torch._dynamo.variables.streams import fork_stream, join_stream
|
from torch._dynamo.variables.streams import fork_stream, join_stream
|
||||||
|
|
|
||||||
|
|
@ -95,6 +95,7 @@ static PyObject* THPStream_pynew(
|
||||||
self->device_index = static_cast<int64_t>(stream_opt->device_index());
|
self->device_index = static_cast<int64_t>(stream_opt->device_index());
|
||||||
self->device_type = static_cast<int64_t>(stream_opt->device_type());
|
self->device_type = static_cast<int64_t>(stream_opt->device_type());
|
||||||
self->context = nullptr;
|
self->context = nullptr;
|
||||||
|
self->weakreflist = nullptr;
|
||||||
|
|
||||||
return static_cast<PyObject*>(ptr.release());
|
return static_cast<PyObject*>(ptr.release());
|
||||||
END_HANDLE_TH_ERRORS
|
END_HANDLE_TH_ERRORS
|
||||||
|
|
@ -114,11 +115,13 @@ PyObject* THPStream_Wrap(const c10::Stream& stream) {
|
||||||
self->device_index = static_cast<int64_t>(stream.device_index());
|
self->device_index = static_cast<int64_t>(stream.device_index());
|
||||||
self->device_type = static_cast<int64_t>(stream.device_type());
|
self->device_type = static_cast<int64_t>(stream.device_type());
|
||||||
self->context = nullptr;
|
self->context = nullptr;
|
||||||
|
self->weakreflist = nullptr;
|
||||||
return ptr.release();
|
return ptr.release();
|
||||||
END_HANDLE_TH_ERRORS
|
END_HANDLE_TH_ERRORS
|
||||||
}
|
}
|
||||||
|
|
||||||
static void THPStream_dealloc(THPStream* self) {
|
static void THPStream_dealloc(THPStream* self) {
|
||||||
|
PyObject_ClearWeakRefs((PyObject*)self);
|
||||||
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
|
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -444,7 +447,7 @@ static PyTypeObject THPStreamType = {
|
||||||
nullptr, /* tp_traverse */
|
nullptr, /* tp_traverse */
|
||||||
nullptr, /* tp_clear */
|
nullptr, /* tp_clear */
|
||||||
THPStream_richcompare, /* tp_richcompare */
|
THPStream_richcompare, /* tp_richcompare */
|
||||||
0, /* tp_weaklistoffset */
|
offsetof(THPStream, weakreflist), /* tp_weaklistoffset */
|
||||||
nullptr, /* tp_iter */
|
nullptr, /* tp_iter */
|
||||||
nullptr, /* tp_iternext */
|
nullptr, /* tp_iternext */
|
||||||
// NOLINTNEXTLINE(*const-cast)
|
// NOLINTNEXTLINE(*const-cast)
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ struct THPStream {
|
||||||
int64_t device_index;
|
int64_t device_index;
|
||||||
// Used to switch stream context management, initialized lazily.
|
// Used to switch stream context management, initialized lazily.
|
||||||
PyObject* context;
|
PyObject* context;
|
||||||
|
PyObject* weakreflist;
|
||||||
};
|
};
|
||||||
extern TORCH_API PyTypeObject* THPStreamClass;
|
extern TORCH_API PyTypeObject* THPStreamClass;
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user