Move CUDAStreamInternals inside detail namespace. (#14109)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14109

Previously it was at the top level, because the author was under
the impression that you could only refer to top-level C++ names
from C, but this is not true; you just need to make a stub struct
conditioned on __cplusplus.

Reviewed By: smessmer

Differential Revision: D13104694

fbshipit-source-id: ecb7ae6dcfa4ab4e062aad7a886937dca15fd1b2
This commit is contained in:
Edward Yang 2018-11-19 17:01:34 -08:00 committed by Facebook Github Bot
parent e58bbbac18
commit 50b914aeeb
12 changed files with 61 additions and 52 deletions

View File

@ -20,21 +20,21 @@ cudaDeviceProp* getDeviceProperties(int64_t device) {
CUDAStream getStreamFromPool(
const bool isHighPriority
, int64_t device) {
return CUDAStream(detail::CUDAStream_getStreamFromPool(isHighPriority, device));
return CUDAStream(impl::CUDAStream_getStreamFromPool(isHighPriority, device));
}
CUDAStream getDefaultCUDAStream(int64_t device) {
return CUDAStream(detail::CUDAStream_getDefaultStream(device));
return CUDAStream(impl::CUDAStream_getDefaultStream(device));
}
CUDAStream getCurrentCUDAStream(int64_t device) {
return CUDAStream(detail::CUDAStream_getCurrentStream(device));
return CUDAStream(impl::CUDAStream_getCurrentStream(device));
}
void setCurrentCUDAStream(CUDAStream stream) {
detail::CUDAStream_setStream(stream.internals());
impl::CUDAStream_setStream(stream.internals());
}
void uncheckedSetCurrentCUDAStream(CUDAStream stream) {
detail::CUDAStream_uncheckedSetStream(stream.internals());
impl::CUDAStream_uncheckedSetStream(stream.internals());
}
Allocator* getCUDADeviceAllocator() {

View File

@ -56,7 +56,7 @@ struct CUDAGuard {
private:
/// The guard for the current device.
c10::impl::InlineDeviceGuard<at::cuda::detail::CUDAGuardImpl> guard_;
c10::impl::InlineDeviceGuard<at::cuda::impl::CUDAGuardImpl> guard_;
};
/// A variant of OptionalDeviceGuard that is specialized for CUDA. See
@ -108,7 +108,7 @@ struct OptionalCUDAGuard {
void reset() { guard_.reset(); }
private:
c10::impl::InlineOptionalDeviceGuard<at::cuda::detail::CUDAGuardImpl> guard_;
c10::impl::InlineOptionalDeviceGuard<at::cuda::impl::CUDAGuardImpl> guard_;
};
/// A variant of StreamGuard that is specialized for CUDA. See CUDAGuard
@ -165,7 +165,7 @@ struct CUDAStreamGuard {
Device original_device() const { return guard_.original_device(); }
private:
c10::impl::InlineStreamGuard<at::cuda::detail::CUDAGuardImpl> guard_;
c10::impl::InlineStreamGuard<at::cuda::impl::CUDAGuardImpl> guard_;
};
/// A variant of OptionalStreamGuard that is specialized for CUDA. See CUDAGuard
@ -228,7 +228,7 @@ struct OptionalCUDAStreamGuard {
void reset() { guard_.reset(); }
private:
c10::impl::InlineOptionalStreamGuard<at::cuda::detail::CUDAGuardImpl> guard_;
c10::impl::InlineOptionalStreamGuard<at::cuda::impl::CUDAGuardImpl> guard_;
};
} // namespace cuda

View File

@ -10,6 +10,11 @@
#include <vector>
#include <array>
namespace at {
namespace cuda {
namespace impl {
// Internal implementation is entirely hidden
// Note: CUDAStreamInternals doubles for a THCStream
struct CUDAStreamInternals {
@ -24,11 +29,6 @@ struct CUDAStreamInternals {
cudaStream_t stream = nullptr;
};
namespace at {
namespace cuda {
namespace detail {
// Global stream state and constants
static int64_t num_gpus = -1;
static constexpr int kStreamsPerPoolBits = 5;
@ -298,26 +298,26 @@ int64_t CUDAStream_device(const CUDAStreamInternals* ptr) {
return ptr->device;
}
} // namespace detail
} // namespace impl
CUDAStream::CUDAStream(const CUDAStreamInternals* ptr)
: stream_(c10::Device(DeviceType::CUDA, detail::CUDAStream_device(ptr)), detail::CUDAStream_getStreamId(ptr)) {
CUDAStream::CUDAStream(const impl::CUDAStreamInternals* ptr)
: stream_(c10::Device(DeviceType::CUDA, impl::CUDAStream_device(ptr)), impl::CUDAStream_getStreamId(ptr)) {
}
// See Note [StreamId assignment]
CUDAStreamInternals* CUDAStream::internals() const {
impl::CUDAStreamInternals* CUDAStream::internals() const {
c10::DeviceIndex device_index = stream_.device_index();
detail::StreamIdType st = detail::streamIdType(stream_.id());
size_t si = detail::streamIdIndex(stream_.id());
impl::StreamIdType st = impl::streamIdType(stream_.id());
size_t si = impl::streamIdIndex(stream_.id());
switch (st) {
case detail::StreamIdType::DEFAULT:
case impl::StreamIdType::DEFAULT:
AT_ASSERTM(si == 0, "Unrecognized stream ", stream_,
" (I think this should be the default stream, but I got a non-zero index ", si, ")");
return &detail::default_streams[device_index];
case detail::StreamIdType::LOW:
return &detail::low_priority_streams[device_index][si];
case detail::StreamIdType::HIGH:
return &detail::high_priority_streams[device_index][si];
return &impl::default_streams[device_index];
case impl::StreamIdType::LOW:
return &impl::low_priority_streams[device_index][si];
case impl::StreamIdType::HIGH:
return &impl::high_priority_streams[device_index][si];
default:
AT_ASSERTM(0, "Unrecognized stream ", stream_, " (I didn't recognize the stream type, ", st, ")");
}

View File

@ -51,12 +51,12 @@
* overlap the performance critical streams.
*/
struct CUDAStreamInternals;
namespace at {
namespace cuda {
namespace detail {
namespace impl {
struct CUDAStreamInternals;
// Pointer-based API (for internal use, backwards compatibility with C-based API)
AT_CUDA_API CUDAStreamInternals* CUDAStream_getDefaultStream(int64_t device = -1);
@ -73,7 +73,7 @@ AT_CUDA_API void CUDAStream_uncheckedSetStream(CUDAStreamInternals* internals);
AT_CUDA_API cudaStream_t CUDAStream_stream(const CUDAStreamInternals*);
AT_CUDA_API int64_t CUDAStream_device(const CUDAStreamInternals*);
} // namespace detail
} // namespace impl
// RAII for a CUDA stream
// Allows use as a cudaStream_t, copying, moving, and metadata access.
@ -81,7 +81,7 @@ struct AT_CUDA_API CUDAStream {
enum Unchecked { UNCHECKED };
explicit CUDAStream(const CUDAStreamInternals*);
explicit CUDAStream(const impl::CUDAStreamInternals*);
explicit CUDAStream(Stream stream) : stream_(stream) {
AT_CHECK(stream_.device_type() == DeviceType::CUDA);
@ -96,8 +96,8 @@ struct AT_CUDA_API CUDAStream {
// Getters
int64_t device_index() const { return stream_.device_index(); }
Device device() const { return Device(DeviceType::CUDA, device_index()); }
cudaStream_t stream() const { return detail::CUDAStream_stream(internals()); }
CUDAStreamInternals* internals() const;
cudaStream_t stream() const { return impl::CUDAStream_stream(internals()); }
impl::CUDAStreamInternals* internals() const;
Stream unwrap() const { return stream_; }

View File

@ -2,7 +2,7 @@
namespace at {
namespace cuda {
namespace detail {
namespace impl {
constexpr DeviceType CUDAGuardImpl::static_type;

View File

@ -10,7 +10,7 @@
namespace at {
namespace cuda {
namespace detail {
namespace impl {
struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
static constexpr DeviceType static_type = DeviceType::CUDA;
@ -52,4 +52,4 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
}
};
}}} // namespace at::cuda::detail
}}} // namespace at::cuda::impl

View File

@ -176,7 +176,7 @@ TEST(TestStream, CUDAGuardTest) {
TEST(TestStream, StreamPoolTest) {
std::vector<at::cuda::CUDAStream> streams{};
for (int i = 0; i < 200; ++i) {
streams.emplace_back(at::cuda::detail::CUDAStream_getStreamFromPool());
streams.emplace_back(at::cuda::impl::CUDAStream_getStreamFromPool());
}
std::unordered_set<cudaStream_t> stream_set{};

View File

@ -223,29 +223,29 @@ THCCudaResourcesPerDevice* THCState_getDeviceResourcePtr(
}
THCStream* THCState_getStreamOnDevice(THCState* state, int device) {
return at::cuda::detail::CUDAStream_getCurrentStream(device);
return at::cuda::impl::CUDAStream_getCurrentStream(device);
}
void THCState_setStreamOnDevice(THCState *state, int device, THCStream *stream) {
at::cuda::detail::CUDAStream_setStream(stream);
at::cuda::impl::CUDAStream_setStream(stream);
}
THC_API void THCState_setStream(THCState *state, THCStream* stream) {
at::cuda::detail::CUDAStream_setStream(stream);
at::cuda::impl::CUDAStream_setStream(stream);
}
cudaStream_t THCState_getCurrentStreamOnDevice(THCState *state, int device) {
return at::cuda::detail::CUDAStream_stream(
at::cuda::detail::CUDAStream_getCurrentStream(device));
return at::cuda::impl::CUDAStream_stream(
at::cuda::impl::CUDAStream_getCurrentStream(device));
}
cudaStream_t THCState_getCurrentStream(THCState *state) {
return at::cuda::detail::CUDAStream_stream(
at::cuda::detail::CUDAStream_getCurrentStream());
return at::cuda::impl::CUDAStream_stream(
at::cuda::impl::CUDAStream_getCurrentStream());
}
THCStream* THCState_getStream(THCState *state) {
return at::cuda::detail::CUDAStream_getCurrentStream();
return at::cuda::impl::CUDAStream_getCurrentStream();
}
cublasHandle_t THCState_getCurrentBlasHandle(THCState *state)

View File

@ -8,9 +8,14 @@
#undef log2
#undef expm1
#ifdef __cplusplus
#include <ATen/cuda/CUDAStream.h>
#endif
#include "cuda.h"
#include "cuda_runtime.h"
#include "cublas_v2.h"
#include "cusparse.h"
#cmakedefine USE_MAGMA
@ -44,7 +49,12 @@
#endif
struct THCRNGState; /* Random number generator state. */
typedef struct CUDAStreamInternals THCStream;
/* THCStream is the name THC code uses for ATen's CUDA stream internals. */
#ifdef __cplusplus
/* C++ translation units alias the namespaced ATen type directly. */
typedef at::cuda::impl::CUDAStreamInternals THCStream;
#else
/* C cannot spell a namespaced C++ name, so declare a stub struct that is
 * only ever used through THCStream pointers. */
typedef struct at_cuda_impl_CUDAStreamInternals at_cuda_impl_CUDAStreamInternals;
typedef at_cuda_impl_CUDAStreamInternals THCStream;
#endif
typedef struct THCState THCState;
struct THCState;

View File

@ -2,19 +2,19 @@
#include "ATen/cuda/CUDAStream.h"
THC_API THCStream* THCStream_defaultStream(int device) {
return at::cuda::detail::CUDAStream_getDefaultStream(device);
return at::cuda::impl::CUDAStream_getDefaultStream(device);
}
THC_API THCStream* THCStream_new() {
return at::cuda::detail::CUDAStream_getStreamFromPool();
return at::cuda::impl::CUDAStream_getStreamFromPool();
}
THC_API cudaStream_t THCStream_stream(THCStream* stream) {
return at::cuda::detail::CUDAStream_stream(stream);
return at::cuda::impl::CUDAStream_stream(stream);
}
THC_API int THCStream_device(THCStream* stream) {
return at::cuda::detail::CUDAStream_device(stream);
return at::cuda::impl::CUDAStream_device(stream);
}
THC_API void THCStream_retain(THCStream* stream) { }

View File

@ -8,7 +8,6 @@
*
* Stream usage should be done through ATen/cuda/CUDAContext.h.
*/
typedef struct CUDAStreamInternals THCStream;
// Stream creation
THC_API THCStream* THCStream_defaultStream(int device);

View File

@ -36,7 +36,7 @@ static PyObject * THCPStream_pynew(PyTypeObject *type, PyObject *args, PyObject
stream = (THCStream*) cdata;
} else {
const bool isHighPriority = priority < 0 ? true : false;
stream = at::cuda::detail::CUDAStream_getStreamFromPool(isHighPriority);
stream = at::cuda::impl::CUDAStream_getStreamFromPool(isHighPriority);
}
THCPStream* self = (THCPStream *)ptr.get();