pytorch/torch/csrc/cuda/Stream.h
chengjun 5741de883a Define the record_stream method in native_functions.yaml (#44301)
Summary:
The record_stream method was hard-coded for CUDA devices. Define record_stream in native_functions.yaml to enable dynamic dispatch to different devices.
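As a rough illustration of what dynamic dispatch buys here, the following C++ registers a record_stream-like kernel per device type in a table and picks the kernel from the tensor's device at call time, so the caller no longer hard-codes CUDA. This is a hypothetical sketch, not PyTorch's dispatcher or the code generated from native_functions.yaml: `DeviceType`, `Tensor`, `Stream`, `RecordStreamDispatcher` and the `recorded_streams` set are invented for this example, and a registered kernel that inserts the stream id only loosely mimics marking the tensor's memory as in use on that stream.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <set>
#include <stdexcept>
#include <utility>

// Hypothetical, simplified device tags (not c10::DeviceType).
enum class DeviceType { CPU, CUDA };

// Hypothetical stand-ins for a stream and a tensor.
struct Stream {
  DeviceType device;
  int id;
};

struct Tensor {
  DeviceType device;
  std::set<int> recorded_streams;  // ids of streams this tensor was recorded on
};

using RecordStreamKernel = std::function<void(Tensor&, const Stream&)>;

// Maps each device type to its own kernel instead of calling CUDA code directly.
class RecordStreamDispatcher {
 public:
  void register_kernel(DeviceType device, RecordStreamKernel kernel) {
    assert(kernel && "kernel must be callable");
    table_[device] = std::move(kernel);
  }

  // Looks up the kernel for the tensor's device; throws if none is registered.
  void call(Tensor& tensor, const Stream& stream) const {
    auto it = table_.find(tensor.device);
    if (it == table_.end()) {
      throw std::runtime_error("record_stream: no kernel registered for this device");
    }
    it->second(tensor, stream);
  }

 private:
  std::map<DeviceType, RecordStreamKernel> table_;
};
```

A backend would call `register_kernel(DeviceType::CUDA, ...)` once at startup; a tensor on a device with no registered kernel makes `call` throw rather than silently running CUDA-specific code.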

Fixes https://github.com/pytorch/pytorch/issues/36556

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44301

Reviewed By: glaringlee

Differential Revision: D23763954

Pulled By: ezyang

fbshipit-source-id: e6d24f5e7892b56101fa858a6cad2abc5cdc4293
2020-10-13 09:15:22 -07:00

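#ifndef THCP_STREAM_INC
#define THCP_STREAM_INC

#include <torch/csrc/Stream.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/python_headers.h>
#include <THC/THC.h>

// Python object for a CUDA stream: extends the device-generic THPStream
// with the underlying at::cuda::CUDAStream.
struct THCPStream : THPStream {
  at::cuda::CUDAStream cuda_stream;
};

// Python type object for THCPStream; set when THCPStream_init runs.
extern PyObject *THCPStreamClass;

// Initializes the CUDA stream Python type and registers it on `module`.
void THCPStream_init(PyObject *module);

// True if `obj` is an instance of THCPStreamClass (or a subclass); false if
// the type object has not been set yet. Note that PyObject_IsInstance returns
// -1 on error, which converts to true in this expression.
inline bool THCPStream_Check(PyObject* obj) {
  return THCPStreamClass && PyObject_IsInstance(obj, THCPStreamClass);
}

#endif // THCP_STREAM_INC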
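The header follows a common CPython C-API pattern: a derived object struct that embeds the base object, plus a type check guarded by a null test on the type pointer. The sketch below is a plain C++ analogue for illustration only. All names are hypothetical, `dynamic_cast` stands in for `PyObject_IsInstance`, and a flag that stays false until an init function runs stands in for the `THCPStreamClass &&` null guard.

```cpp
#include <cassert>

// Hypothetical analogue of THPStream: the device-generic base object.
struct BaseStream {
  virtual ~BaseStream() = default;  // polymorphic, so dynamic_cast works
  long long handle = 0;             // placeholder for the generic stream fields
};

// Hypothetical analogue of THCPStream: adds a CUDA-specific member.
struct CudaStreamObject : BaseStream {
  int cuda_stream_id = 0;  // placeholder for the at::cuda::CUDAStream member
};

// Stands in for THCPStreamClass being null before initialization.
static bool g_cuda_stream_type_ready = false;

// Hypothetical analogue of THCPStream_init: makes the type available.
void cuda_stream_init() { g_cuda_stream_type_ready = true; }

// Hypothetical analogue of THCPStream_Check: false before init, otherwise
// true only for objects whose dynamic type is CudaStreamObject or derived.
bool cuda_stream_check(const BaseStream* obj) {
  return g_cuda_stream_type_ready &&
         dynamic_cast<const CudaStreamObject*>(obj) != nullptr;
}
```

The short-circuiting `&&` mirrors the original: the instance test is never attempted while the type is unregistered.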