pytorch/torch/csrc/Stream.h
Luca Wehrstedt 24bf15fe8d Support record_stream in dispatch mode (#99529)
Summary:
Calling `t.record_stream(s)` while a `TorchDispatchMode` was active threw an error because PyTorch could not convert a `c10::Stream` back into a Python object. This change fixes that.

Fixes https://github.com/pytorch/pytorch/issues/94403

Test Plan: Added a unit test

Differential Revision: D45117566

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99529
Approved by: https://github.com/albanD
2023-04-21 07:17:19 +00:00
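The header below stores a stream on the Python side as three `int64_t` fields (`stream_id`, `device_type`, `device_index`), so converting a `c10::Stream` back into a Python object means copying those components out and being able to rebuild an equivalent stream from them. The following is a minimal, self-contained sketch of that round trip, assuming a stream is fully identified by those three values. `MockStream`, `StreamFields`, `to_fields`, `from_fields`, and `same_stream` are hypothetical stand-ins, not c10 or PyTorch APIs. The real `THPStream_Wrap` allocates a Python object of type `THPStreamClass` instead of returning a plain struct.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical stand-in for c10::Stream: a stream is identified by the
// device it lives on plus a per-device stream id.
struct MockStream {
  std::int8_t device_type;   // illustrative; c10 uses 0 for CPU, 1 for CUDA
  std::int8_t device_index;
  std::int64_t id;
};

// Hypothetical mirror of the THPStream payload (without PyObject_HEAD).
struct StreamFields {
  std::int64_t stream_id;
  std::int64_t device_type;
  std::int64_t device_index;
};

// C++ -> Python direction: capture every component of the stream.
StreamFields to_fields(const MockStream& s) {
  return StreamFields{s.id, s.device_type, s.device_index};
}

// Python -> C++ direction: rebuild the stream from the stored fields.
MockStream from_fields(const StreamFields& f) {
  // The stored fields are wider than the mock's types; check they still fit.
  assert(f.device_type >= std::numeric_limits<std::int8_t>::min() &&
         f.device_type <= std::numeric_limits<std::int8_t>::max());
  assert(f.device_index >= std::numeric_limits<std::int8_t>::min() &&
         f.device_index <= std::numeric_limits<std::int8_t>::max());
  return MockStream{static_cast<std::int8_t>(f.device_type),
                    static_cast<std::int8_t>(f.device_index), f.stream_id};
}

// Two streams are the same if all three identifying components match.
bool same_stream(const MockStream& a, const MockStream& b) {
  return a.device_type == b.device_type &&
         a.device_index == b.device_index && a.id == b.id;
}
```

For example, passing `MockStream{1, 0, 42}` through `to_fields` and then `from_fields` yields a stream for which `same_stream` returns true.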

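#ifndef THP_STREAM_INC
#define THP_STREAM_INC

#include <c10/core/Stream.h>
#include <c10/macros/Export.h>
#include <torch/csrc/python_headers.h>

// Python-side stream object: the three fields together identify a c10::Stream.
struct THPStream {
  PyObject_HEAD
  int64_t stream_id;
  int64_t device_type;
  int64_t device_index;
};

// Python type object for streams; may be null until THPStream_init has run.
extern TORCH_API PyTypeObject* THPStreamClass;

// Initializes the Python stream type for `module`.
void THPStream_init(PyObject* module);

// Returns false if the stream type is not initialized; otherwise defers to
// PyObject_IsInstance.
inline bool THPStream_Check(PyObject* obj) {
  return THPStreamClass && PyObject_IsInstance(obj, (PyObject*)THPStreamClass);
}

// Converts a c10::Stream into a Python stream object (C++ -> Python).
PyObject* THPStream_Wrap(const c10::Stream& stream);

#endif // THP_STREAM_INC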
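`THPStream_Check` tests `THPStreamClass` for null before calling `PyObject_IsInstance`, so it returns false instead of passing a null type if it runs before the stream type has been set up. A minimal sketch of the same guard without the CPython API follows. `TypeObject`, `Object`, `g_stream_class`, `init_stream_class`, `is_instance`, and `is_stream` are hypothetical names for a toy single-inheritance object model, not PyTorch or CPython APIs.

```cpp
#include <cassert>

// Hypothetical miniature object model standing in for PyTypeObject/PyObject.
struct TypeObject {
  const char* name;
  const TypeObject* base;  // single inheritance; nullptr for a root type
};

struct Object {
  const TypeObject* type;
};

// Stand-in for THPStreamClass: null until the init function runs.
const TypeObject* g_stream_class = nullptr;

// Stand-in for THPStream_init: records the stream type.
void init_stream_class(const TypeObject* cls) {
  assert(cls != nullptr && "init must register a real type");
  g_stream_class = cls;
}

// Simplified isinstance(): walk the object's type and its bases.
bool is_instance(const Object& obj, const TypeObject* cls) {
  for (const TypeObject* t = obj.type; t != nullptr; t = t->base) {
    if (t == cls) return true;
  }
  return false;
}

// Same shape as THPStream_Check: short-circuit while the class is unset.
bool is_stream(const Object& obj) {
  return g_stream_class != nullptr && is_instance(obj, g_stream_class);
}
```

An object of a subtype of the stream type is rejected before `init_stream_class` runs and accepted afterwards. This toy version has no error path, whereas `PyObject_IsInstance` returns -1 with an exception set when the check itself fails.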