Ports Streams to ATen (#8997)

Summary:
This PR moves the THCStream logic (from both the THCStream and THCState APIs) to ATen. In particular, it:

+ Creates a new (THC free) at::CUDAStream class and API
+ Extends the at::Context API to expose it
+ Stubs the current THCStream and THCState APIs to use it
+ Updates THC to no longer violate stream encapsulation (stream.hpp is dead)
+ Adds an ATen cpp test of the API
+ Bonus: Removes some debug spew in test_nn.py

The new API has several advantages over the old one:

(1) It comes with an easy to use RAII, the CUDAStream. CUDAStreams have the expected copy and move semantics and are implicitly convertible to cudaStream_t.
(2) It does not depend on THCState, THCThreadLocal, or CUDA (thanks to goldsborough for suggesting the dynamic registration technique)
(3) It provides one consistent API/place for all stream operations, instead of having them split between THCStream and THCState
(4) The internals are completely encapsulated, unlike the historic THCStream
(5) It has getAndRetain semantics, which are safer than the historic gets (which allowed a gap between acquisition and retention)

There are a couple things this PR does not do, however, which are left for future work:

- It leaves the c10d:CUDAStream class as a THCStream wrapper (which now really wraps an at::CUDAStream).
- It leaves historic users of THCStream mostly untouched, except where they violated encapsulation (by using stream.hpp). A couple forward declarations were also changed.

I hope this PR allows easy usage of streams from ATen and is a useful pattern for porting more of the THCState API.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/8997

Differential Revision: D8683375

Pulled By: soumith

fbshipit-source-id: 2e48ad85f1f9c8817684fe63a267938e80eafdcf
This commit is contained in:
mruberry 2018-07-08 16:12:42 -07:00 committed by Facebook Github Bot
parent 75919b4e18
commit d6f21fc663
73 changed files with 585 additions and 239 deletions

View File

@ -0,0 +1,183 @@
#include "ATen/CUDAStream.h"
#include "ATen/Error.h"
#include "ATen/detail/CUDAHooksInterface.h"
#include <mutex>
// Internal implementation is entirely hidden
struct CUDAStreamInternals {
bool is_destructible;
std::atomic<int> refcount;
int64_t device; // Note: cudaGetDevice works with int32_t, not int64_t
cudaStream_t stream;
};
namespace at {
namespace detail {
/*
* Stream state
*/
static constexpr cudaStream_t DEFAULT_STREAM = 0;
static std::once_flag init_flag;
static int64_t num_gpus;
static CUDAStreamInternals* default_streams;
static thread_local CUDAStreamInternals** current_streams = nullptr;
// Creates a(n indestructible) default stream for each device
// Note: the default stream on each device is signified by a zero
// value for the pointer, and so is not actually created as usual.
// In particular, we don't need to switch devices when creating the
// streams.
static void initDefaultCUDAStreams() {
num_gpus = getCUDAHooks().getNumGPUs();
default_streams = (CUDAStreamInternals*) malloc(num_gpus * sizeof(CUDAStreamInternals));
for (auto i = decltype(num_gpus){0}; i < num_gpus; ++i) {
default_streams[i].is_destructible = false;
default_streams[i].refcount = 0;
default_streams[i].device = i;
default_streams[i].stream = DEFAULT_STREAM;
}
}
// Init front-end to ensure initialization only occurs once
static void initCUDAStreamsOnce() {
// Inits default streams (once, globally)
std::call_once(init_flag, initDefaultCUDAStreams);
// Inits current streams (thread local) to default streams
if (current_streams) return;
current_streams = (CUDAStreamInternals**) malloc(num_gpus * sizeof(CUDAStreamInternals*));
for (auto i = decltype(num_gpus){0}; i < num_gpus; ++i) {
current_streams[i] = &default_streams[i];
}
}
/*
* Pointer-based stream API
*/
// Helper to return the current device
static inline int64_t current_device() {
int cur_device;
DynamicCUDAInterface::get_device(&cur_device);
return cur_device;
}
// Helper to verify the GPU index is valid
static inline void check_gpu(int64_t device) {
AT_CHECK(device >= 0 && device < num_gpus);
}
CUDAStreamInternals* CUDAStream_getDefaultStreamOnDevice(int64_t device) {
initCUDAStreamsOnce();
check_gpu(device);
return &default_streams[device];
}
CUDAStreamInternals* CUDAStream_getDefaultStream() {
return CUDAStream_getDefaultStreamOnDevice(current_device());
}
// Creates (and retains) and new cuda stream
CUDAStreamInternals* CUDAStream_createAndRetainWithOptions(int32_t flags, int32_t priority) {
CUDAStreamInternals* internals = (CUDAStreamInternals*) malloc(sizeof(CUDAStreamInternals));
internals->is_destructible = true;
internals->refcount = 1;
internals->device = current_device();
DynamicCUDAInterface::cuda_stream_create_with_priority(&internals->stream, flags, priority);
return internals;
}
// Note: despite not being "unsafe," is using these methods in a multithreaded
// environment then the caller must be sure that streams are valid
// when they're requested. These methods will throw an error if an
// invalid stream is requested.
CUDAStreamInternals* CUDAStream_getAndRetainCurrentStreamOnDevice(int64_t device) {
initCUDAStreamsOnce();
check_gpu(device);
auto cur = current_streams[device];
AT_CHECK(CUDAStream_retain(cur));
return cur;
}
CUDAStreamInternals* CUDAStream_getAndRetainCurrentStream() {
return CUDAStream_getAndRetainCurrentStreamOnDevice(current_device());
}
// Note: these unsafe methods do not retain the stream before returning it.
// This is unsafe behavior and these methods SHOULD NOT BE USED.
// They are here only for legacy compatibility.
CUDAStreamInternals* CUDAStream_getCurrentStreamOnDeviceUnsafe(int64_t device) {
initCUDAStreamsOnce();
check_gpu(device);
return current_streams[device];
}
CUDAStreamInternals* CUDAStream_getCurrentStreamUnsafe() {
return CUDAStream_getCurrentStreamOnDeviceUnsafe(current_device());
}
void CUDAStream_setStreamOnDevice(int64_t device, CUDAStreamInternals* ptr) {
initCUDAStreamsOnce();
check_gpu(device);
AT_CHECK(ptr);
AT_CHECK(ptr->device == device);
AT_CHECK(CUDAStream_retain(ptr));
CUDAStream_free(current_streams[device]);
current_streams[device] = ptr;
}
void CUDAStream_setStream(CUDAStreamInternals* ptr) {
CUDAStream_setStreamOnDevice(current_device(), ptr);
}
// Getters
cudaStream_t CUDAStream_stream(CUDAStreamInternals* ptr) {
AT_CHECK(ptr);
return ptr->stream;
}
int64_t CUDAStream_device(CUDAStreamInternals* ptr) {
AT_CHECK(ptr);
return ptr->device;
}
// Memory management
// Note: only destructible (non-default) streams are ref counted
bool CUDAStream_retain(CUDAStreamInternals* ptr) {
AT_CHECK(ptr);
if (ptr->is_destructible) return(++ptr->refcount > 1);
return true;
}
void CUDAStream_free(CUDAStreamInternals*& ptr) {
if (ptr && ptr->stream && ptr->is_destructible && --ptr->refcount <= 0) {
AT_CHECK(ptr->refcount == 0);
DynamicCUDAInterface::cuda_stream_destroy(ptr->stream);
free(ptr);
ptr = nullptr;
}
}
} // namespace detail
/*
* CUDAStream functions
*/
// Copy constructor
CUDAStream::CUDAStream(const CUDAStream& other) {
AT_CHECK(other.internals_);
AT_CHECK(detail::CUDAStream_retain(other.internals_));
internals_ = other.internals_;
}
// Move constructor
CUDAStream::CUDAStream(CUDAStream&& other) {
AT_CHECK(other.internals_);
std::swap(internals_, other.internals_);
}
} // namespace at

View File

@ -0,0 +1,95 @@
#pragma once
#include <cstdint>
#include <utility>
/*
* A CUDA stream interface with no CUDA build dependency.
*
* Includes the CUDAStream RAII class and a pointer-based stream API.
*
* The ATen Context interface should be preferred when working with streams.
*/
// Forward-declares cudaStream_t to avoid depending on CUDA in CPU builds
// Note: this is the internal CUDA runtime typedef for cudaStream_t
struct CUstream_st;
typedef struct CUstream_st* cudaStream_t;
// Forward-declares internals
struct CUDAStreamInternals;
namespace at {
namespace detail {
// Pointer-based API (for internal use)
// Note: ATen/Context is preferred to work with streams safely
CUDAStreamInternals* CUDAStream_getDefaultStreamOnDevice(int64_t device);
CUDAStreamInternals* CUDAStream_getDefaultStream();
CUDAStreamInternals* CUDAStream_createAndRetainWithOptions(int32_t flags, int32_t priority);
CUDAStreamInternals* CUDAStream_getAndRetainCurrentStreamOnDevice(int64_t device);
CUDAStreamInternals* CUDAStream_getAndRetainCurrentStream();
// Note: these Unsafe gets should NEVER be used and are only here for legacy
// purposes. Once those uses are gone they should be removed.
CUDAStreamInternals* CUDAStream_getCurrentStreamOnDeviceUnsafe(int64_t device);
CUDAStreamInternals* CUDAStream_getCurrentStreamUnsafe();
void CUDAStream_setStreamOnDevice(int64_t device, CUDAStreamInternals* internals);
void CUDAStream_setStream(CUDAStreamInternals* internals);
cudaStream_t CUDAStream_stream(CUDAStreamInternals*);
int64_t CUDAStream_device(CUDAStreamInternals*);
bool CUDAStream_retain(CUDAStreamInternals*);
void CUDAStream_free(CUDAStreamInternals*&);
} // namespace detail
// RAII for a CUDA stream
// Allows use as a cudaStream_t, copying, moving, and metadata access.
struct CUDAStream {
// Constants
static constexpr int32_t DEFAULT_FLAGS = 1; // = cudaStreamNonBlocking;
static constexpr int32_t DEFAULT_PRIORITY = 0;
// Constructors
CUDAStream() = default;
CUDAStream(CUDAStreamInternals* internals) : internals_{internals} { }
// Destructor
~CUDAStream() { detail::CUDAStream_free(internals_); }
// Copy constructor
CUDAStream(const CUDAStream& other);
// Move constructor
CUDAStream(CUDAStream&& other);
// Assignment operator
CUDAStream& operator=(CUDAStream other) {
std::swap(internals_, other.internals_);
return *this;
}
// Implicit conversion to cudaStream_t
operator cudaStream_t() const { return detail::CUDAStream_stream(internals_); }
// Less than operator (to allow use in sets)
friend bool operator<(const CUDAStream& left, const CUDAStream& right) {
return left.internals_ < right.internals_;
}
// Getters
int64_t device() const { return detail::CUDAStream_device(internals_); }
cudaStream_t stream() const { return detail::CUDAStream_stream(internals_); }
CUDAStreamInternals* internals() const { return internals_; }
private:
CUDAStreamInternals* internals_ = nullptr;
};
} // namespace at

View File

@ -7,6 +7,7 @@
#include "ATen/Utils.h"
#include "ATen/Error.h"
#include "ATen/detail/CUDAHooksInterface.h"
#include "ATen/CUDAStream.h"
#include <memory>
#include <mutex>
@ -78,12 +79,41 @@ public:
return thc_state.get();
}
cudaStream_t getCurrentCUDAStream() const {
return detail::getCUDAHooks().getCurrentCUDAStream(thc_state.get());
CUDAStream createCUDAStream() const {
return detail::CUDAStream_createAndRetainWithOptions(
CUDAStream::DEFAULT_FLAGS
, CUDAStream::DEFAULT_PRIORITY
);
}
cudaStream_t getCurrentCUDAStreamOnDevice(int64_t device) const {
return detail::getCUDAHooks().getCurrentCUDAStreamOnDevice(thc_state.get(), device);
CUDAStream createCUDAStreamWithOptions(int32_t flags, int32_t priority) const {
return detail::CUDAStream_createAndRetainWithOptions(flags, priority);
}
CUDAStream getDefaultCUDAStream() const {
return detail::CUDAStream_getDefaultStream();
}
CUDAStream getDefaultCUDAStreamOnDevice(int64_t device) const {
return detail::CUDAStream_getDefaultStreamOnDevice(device);
}
CUDAStream getCurrentCUDAStream() const {
return detail::CUDAStream_getAndRetainCurrentStream();
}
CUDAStream getCurrentCUDAStreamOnDevice(int64_t device) const {
return detail::CUDAStream_getAndRetainCurrentStreamOnDevice(device);
}
void setCurrentCUDAStream(CUDAStream stream) const {
return detail::CUDAStream_setStream(stream.internals());
}
void setCurrentCUDAStreamOnDevice(int64_t device, CUDAStream stream) const {
return detail::CUDAStream_setStreamOnDevice(device, stream.internals());
}
#ifndef __HIP_PLATFORM_HCC__
cusparseHandle_t getCurrentCUDASparseHandle() const {
return detail::getCUDAHooks().getCurrentCUDASparseHandle(thc_state.get());

View File

@ -49,12 +49,30 @@ void unchecked_set_device(int32_t device) {
(void)return_code;
}
void cuda_stream_create_with_priority(
cudaStream_t* pStream
, int32_t flags
, int32_t priority) {
#ifndef __HIP_PLATFORM_HCC__
check_status(cudaStreamCreateWithPriority(pStream, flags, priority));
#else
check_status(cudaStreamCreateWithFlags(pStream, flags));
#endif
}
void cuda_stream_destroy(cudaStream_t stream) {
check_status(cudaStreamDestroy(stream));
}
struct DynamicCUDAInterfaceSetter {
DynamicCUDAInterfaceSetter() {
at::detail::DynamicCUDAInterface::set_device = set_device;
at::detail::DynamicCUDAInterface::get_device = get_device;
at::detail::DynamicCUDAInterface::unchecked_set_device =
unchecked_set_device;
at::detail::DynamicCUDAInterface::cuda_stream_create_with_priority =
cuda_stream_create_with_priority;
at::detail::DynamicCUDAInterface::cuda_stream_destroy = cuda_stream_destroy;
}
};
@ -97,14 +115,6 @@ bool CUDAHooks::hasCuDNN() const {
return AT_CUDNN_ENABLED();
}
cudaStream_t CUDAHooks::getCurrentCUDAStream(THCState* thc_state) const {
return THCState_getCurrentStream(thc_state);
}
cudaStream_t CUDAHooks::getCurrentCUDAStreamOnDevice(
THCState* thc_state,
int64_t device) const {
return THCState_getCurrentStreamOnDevice(thc_state, device);
}
#ifndef __HIP_PLATFORM_HCC__
cusparseHandle_t CUDAHooks::getCurrentCUDASparseHandle(THCState* thc_state) const {
return THCState_getCurrentSparseHandle(thc_state);

View File

@ -14,11 +14,9 @@ struct CUDAHooks : public at::CUDAHooksInterface {
std::unique_ptr<Generator> initCUDAGenerator(Context*) const override;
bool hasCUDA() const override;
bool hasCuDNN() const override;
cudaStream_t getCurrentCUDAStream(THCState*) const override;
#ifndef __HIP_PLATFORM_HCC__
cusparseHandle_t getCurrentCUDASparseHandle(THCState*) const override;
#endif
cudaStream_t getCurrentCUDAStreamOnDevice(THCState*, int64_t device) const override;
struct cudaDeviceProp* getCurrentDeviceProperties(THCState*) const override;
struct cudaDeviceProp* getDeviceProperties(THCState*, int device) const override;
int64_t current_device() const override;

View File

@ -27,11 +27,28 @@ void default_unchecked_set_device(int32_t) {
"before CUDA library was loaded");
}
void default_cuda_stream_create_with_priority(cudaStream_t*, int32_t, int32_t) {
AT_ERROR(
"DynamicCUDAInterface::cuda_stream_create_with_priority called "
"before CUDA library was loaded");
}
void default_cuda_stream_destroy(cudaStream_t) {
AT_ERROR(
"DynamicCUDAInterface::cuda_stream_destroy called "
"before CUDA library was loaded");
}
// Default the static members of DynamicCUDAInterface.
void (*DynamicCUDAInterface::set_device)(int32_t) = default_set_device;
void (*DynamicCUDAInterface::get_device)(int32_t*) = default_get_device;
void (*DynamicCUDAInterface::unchecked_set_device)(int32_t) =
default_unchecked_set_device;
void (*DynamicCUDAInterface::cuda_stream_create_with_priority)(cudaStream_t*, int32_t, int32_t)
= default_cuda_stream_create_with_priority;
void (*DynamicCUDAInterface::cuda_stream_destroy)(cudaStream_t)
= default_cuda_stream_destroy;
const CUDAHooksInterface& getCUDAHooks() {
static std::unique_ptr<CUDAHooksInterface> cuda_hooks;

View File

@ -12,9 +12,9 @@
// Forward declare these CUDA types here to avoid including CUDA headers in
// ATen headers, which would make ATen always require CUDA to build.
struct THCState;
struct cudaDeviceProp;
struct CUstream_st;
typedef struct CUstream_st* cudaStream_t;
struct cudaDeviceProp;
#ifndef __HIP_PLATFORM_HCC__
// pyHIPIFY rewrites this as:
@ -89,21 +89,12 @@ struct AT_API CUDAHooksInterface {
return false;
}
virtual cudaStream_t getCurrentCUDAStream(THCState*) const {
AT_ERROR("Cannot getCurrentCUDAStream() without ATen_cuda library. ", CUDA_HELP);
}
#ifndef __HIP_PLATFORM_HCC__
virtual cusparseHandle_t getCurrentCUDASparseHandle(THCState*) const {
AT_ERROR("Cannot getCurrentCUDASparseHandle() without ATen_cuda library. ", CUDA_HELP);
}
#endif
virtual cudaStream_t getCurrentCUDAStreamOnDevice(THCState*, int64_t device)
const {
AT_ERROR("Cannot getCurrentCUDAStream() without ATen_cuda library. ", CUDA_HELP);
}
virtual struct cudaDeviceProp* getCurrentDeviceProperties(THCState*) const {
AT_ERROR("Cannot getCurrentDeviceProperties() without ATen_cuda library. ", CUDA_HELP);
}
@ -184,6 +175,8 @@ struct AT_API DynamicCUDAInterface {
static void (*set_device)(int32_t);
static void (*get_device)(int32_t*);
static void (*unchecked_set_device)(int32_t);
static void (*cuda_stream_create_with_priority)(cudaStream_t*, int32_t, int32_t);
static void (*cuda_stream_destroy)(cudaStream_t);
};
} // namespace detail
} // namespace at

View File

@ -23,7 +23,8 @@ list(APPEND ATen_CPU_TEST_SRCS
list(APPEND ATen_CUDA_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/integer_divider_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_rng_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/apply_test.cpp)
${CMAKE_CURRENT_SOURCE_DIR}/apply_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/stream_test.cpp)
if (CUDNN_FOUND)
list(APPEND ATen_CUDA_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_test.cpp)

View File

@ -0,0 +1,103 @@
#define CATCH_CONFIG_MAIN
#include "catch.hpp"
#include "ATen/ATen.h"
#include "cuda_runtime.h"
#include <thread>
#include <functional>
/*
Tests related to ATen streams.
*/
TEST_CASE("Copying and Moving Streams", "Verifies streams are live through copying and moving") {
int32_t device = -1;
cudaStream_t cuda_stream;
// Tests that copying works as expected and preserves the stream
at::CUDAStream copyStream;
{
auto s = at::globalContext().createCUDAStream();
device = s.device();
cuda_stream = s.stream();
copyStream = s;
REQUIRE(copyStream.internals() == s.internals());
REQUIRE(copyStream.device() == device);
REQUIRE(copyStream.stream() == cuda_stream);
}
REQUIRE(copyStream.internals());
REQUIRE(copyStream.device() == device);
REQUIRE(copyStream.stream() == cuda_stream);
// Tests that moving works as expected and preserves the stream
at::CUDAStream moveStream;
{
auto s = at::globalContext().createCUDAStream();
device = s.device();
cuda_stream = s.stream();
moveStream = std::move(s);
REQUIRE(moveStream.device() == device);
REQUIRE(moveStream.stream() == cuda_stream);
}
REQUIRE(moveStream.internals());
REQUIRE(moveStream.device() == device);
REQUIRE(moveStream.stream() == cuda_stream);
}
TEST_CASE("Getting and Setting Streams", "Verifies streams are set properly") {
at::CUDAStream myStream = at::globalContext().createCUDAStream();
// Sets and gets
at::globalContext().setCurrentCUDAStream(myStream);
at::CUDAStream curStream = at::globalContext().getCurrentCUDAStream();
REQUIRE(myStream == curStream);
// Gets, sets, and gets default stream
at::CUDAStream defaultStream = at::globalContext().getDefaultCUDAStream();
at::globalContext().setCurrentCUDAStream(defaultStream);
curStream = at::globalContext().getCurrentCUDAStream();
REQUIRE(defaultStream != myStream);
REQUIRE(curStream == defaultStream);
}
TEST_CASE("Stream API retain/free", "Ensures streams are destroyed properly") {
auto ptr = at::detail::CUDAStream_createAndRetainWithOptions(
at::CUDAStream::DEFAULT_FLAGS
, at::CUDAStream::DEFAULT_PRIORITY);
at::detail::CUDAStream_free(ptr);
REQUIRE(ptr == nullptr);
}
void thread_fun(at::CUDAStream& cur_thread_stream) {
auto new_stream = at::globalContext().createCUDAStream();
at::globalContext().setCurrentCUDAStream(new_stream);
cur_thread_stream = at::globalContext().getCurrentCUDAStream();
REQUIRE(cur_thread_stream == new_stream);
}
TEST_CASE("Multithread Getting and Setting", "Ensures streams are thread local") {
at::CUDAStream s0, s1;
std::thread t0{thread_fun, std::ref(s0)};
std::thread t1{thread_fun, std::ref(s1)};
t0.join();
t1.join();
at::CUDAStream cur_stream = at::globalContext().getCurrentCUDAStream();
at::CUDAStream default_stream = at::globalContext().getDefaultCUDAStream();
REQUIRE(cur_stream == default_stream);
REQUIRE(cur_stream != s0);
REQUIRE(cur_stream != s1);
REQUIRE(s0 != s1);
}

View File

@ -12,7 +12,7 @@ foreach(THC_TYPE Byte Char Short Int Long Half Float Double)
foreach(THC_FILE TensorSort TensorMathCompareT TensorMathPointwise TensorMathCompare TensorMathReduce TensorMasked)
if(NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/generated/THC${THC_FILE}${THC_TYPE}.cu")
FILE(WRITE "${CMAKE_CURRENT_SOURCE_DIR}/generated/THC${THC_FILE}${THC_TYPE}.cu"
"#include \"../THC${THC_FILE}.cuh\"\n#include \"THCTensor.hpp\"\n#include \"THCStream.hpp\"\n#include \"../generic/THC${THC_FILE}.cu\"\n#include \"../THCGenerate${THC_TYPE}Type.h\"\n")
"#include \"../THC${THC_FILE}.cuh\"\n#include \"THCTensor.hpp\"\n#include \"../generic/THC${THC_FILE}.cu\"\n#include \"../THCGenerate${THC_TYPE}Type.h\"\n")
endif()
LIST(APPEND extra_src "${CMAKE_CURRENT_SOURCE_DIR}/generated/THC${THC_FILE}${THC_TYPE}.cu")
endforeach()
@ -114,7 +114,6 @@ INSTALL(FILES
# See Note [TH abstraction violation]
THCGenerator.hpp
THCTensor.hpp
THCStream.hpp
THCStorage.hpp
DESTINATION "${ATEN_INSTALL_INCLUDE_SUBDIR}/THC")

View File

@ -1,5 +1,4 @@
#include "THCCachingAllocator.h"
#include "THCStream.hpp"
#include <cuda_runtime_api.h>
#include <algorithm>
@ -36,7 +35,6 @@
// work.
//
namespace {
typedef std::shared_ptr<THCStream> THCStreamPtr;
@ -302,7 +300,7 @@ struct THCCachingAllocator
if (!block) {
THError("invalid device pointer: %p", ptr);
}
if (stream->stream == block->stream) {
if (THCStream_stream(stream) == block->stream) {
// ignore uses on the allocation stream, since those don't require any
// special synchronization
return;
@ -440,14 +438,14 @@ struct THCCachingAllocator
for (auto it = streams.begin(); it != streams.end(); ++it) {
auto& stream = *it;
err = cudaSetDevice(stream->device);
err = cudaSetDevice(THCStream_device(stream.get()));
if (err != cudaSuccess) break;
cudaEvent_t event;
err = cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
if (err != cudaSuccess) break;
err = cudaEventRecord(event, stream->stream);
err = cudaEventRecord(event, THCStream_stream(stream.get()));
if (err != cudaSuccess) break;
block->event_count++;

View File

@ -1,5 +1,5 @@
#include "THCCachingHostAllocator.h"
#include "THCStream.hpp"
#include "THCStream.h"
#include <cuda_runtime_api.h>
#include <deque>
@ -10,7 +10,6 @@
#include <unordered_map>
#include <utility>
namespace {
typedef std::shared_ptr<THCStream> THCStreamPtr;
@ -228,14 +227,14 @@ struct HostAllocator
for (auto it = streams.begin(); it != streams.end(); ++it) {
auto& stream = *it;
err = cudaSetDevice(stream->device);
err = cudaSetDevice(THCStream_device(stream.get()));
if (err != cudaSuccess) break;
cudaEvent_t event;
err = cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
if (err != cudaSuccess) break;
err = cudaEventRecord(event, stream->stream);
err = cudaEventRecord(event, THCStream_stream(stream.get()));
if (err != cudaSuccess) break;
block.event_count++;

View File

@ -1,12 +1,13 @@
#include "THCGeneral.h"
#include "THCStream.hpp"
#include "TH.h"
#include "THCAllocator.h"
#include "THCCachingHostAllocator.h"
#include "THCStream.h"
#include "THCThreadLocal.h"
#include "THCTensorRandom.h"
#include "THCGeneral.hpp"
#include "ATen/CUDAStream.h"
#include <stdlib.h>
#include <stdint.h>
@ -75,11 +76,6 @@ void THCudaInit(THCState* state)
int device = 0;
THCudaCheck(cudaGetDevice(&device));
/* Start in the default stream on the current device */
state->currentStreams = (THCThreadLocal*) malloc(numDevices * sizeof(THCThreadLocal));
for (int i = 0; i < numDevices; ++i) {
state->currentStreams[i] = THCThreadLocal_alloc();
}
state->currentPerDeviceBlasHandle = THCThreadLocal_alloc();
state->currentPerDeviceSparseHandle = THCThreadLocal_alloc();
@ -180,8 +176,6 @@ void THCudaShutdown(THCState* state)
free(res->blasHandles);
free(res->sparseHandles);
THCStream_free((THCStream*)THCThreadLocal_get(state->currentStreams[dev]));
THCThreadLocal_free(state->currentStreams[dev]);
}
free(state->resourcesPerDevice);
if (state->cudaDeviceAllocator->emptyCache) {
@ -190,7 +184,6 @@ void THCudaShutdown(THCState* state)
if (state->cudaHostAllocator == &THCCachingHostAllocator) {
THCCachingHostAllocator_emptyCache();
}
free(state->currentStreams);
THCThreadLocal_free(state->currentPerDeviceBlasHandle);
THCudaCheck(cudaSetDevice(prevDev));
@ -431,51 +424,30 @@ cusparseHandle_t THCState_getDeviceSparseHandle(THCState *state, int device, int
return res->sparseHandles[handle - 1];
}
THCStream* THCState_getStreamOnDevice(THCState* state, int device)
{
THCThreadLocal local = state->currentStreams[device];
THCStream* stream = (THCStream*)THCThreadLocal_get(local);
if (!stream) {
stream = THCStream_defaultStream(device);
THCStream_retain(stream);
THCThreadLocal_set(local, stream);
}
return stream;
THCStream* THCState_getStreamOnDevice(THCState* state, int device) {
return at::detail::CUDAStream_getCurrentStreamOnDeviceUnsafe(device);
}
void THCState_setStreamOnDevice(THCState *state, int device, THCStream *stream)
{
THAssert(stream);
if (stream->device != device) {
THError("invalid stream; expected stream for device %d, but was on %d",
device, stream->device);
}
THCStream_retain(stream);
THCThreadLocal local = state->currentStreams[device];
THCStream_free((THCStream*)THCThreadLocal_get(local));
THCThreadLocal_set(local, stream);
void THCState_setStreamOnDevice(THCState *state, int device, THCStream *stream) {
at::detail::CUDAStream_setStreamOnDevice(device, stream);
}
cudaStream_t THCState_getCurrentStreamOnDevice(THCState *state, int device)
{
THCStream* stream = THCState_getStreamOnDevice(state, device);
THAssert(stream);
return stream->stream;
cudaStream_t THCState_getCurrentStreamOnDevice(THCState *state, int device) {
return at::detail::CUDAStream_stream(
at::detail::CUDAStream_getCurrentStreamOnDeviceUnsafe(device));
}
cudaStream_t THCState_getCurrentStream(THCState *state)
{
/* This is called at the point of kernel execution.
For some debugging code or improperly instrumented kernels,
`state` is null */
if (state) {
int device;
THCudaCheck(cudaGetDevice(&device));
return THCState_getCurrentStreamOnDevice(state, device);
} else {
/* assume default stream */
return NULL;
}
cudaStream_t THCState_getCurrentStream(THCState *state) {
return at::detail::CUDAStream_stream(
at::detail::CUDAStream_getCurrentStreamUnsafe());
}
THCStream* THCState_getStream(THCState *state) {
return at::detail::CUDAStream_getCurrentStreamUnsafe();
}
void THCState_setStream(THCState *state, THCStream *stream) {
at::detail::CUDAStream_setStream(stream);
}
cublasHandle_t THCState_getCurrentBlasHandle(THCState *state)
@ -528,20 +500,6 @@ int THCState_getCurrentSparseHandleIndex(THCState *state)
return (int) (intptr_t) value;
}
THCStream* THCState_getStream(THCState *state)
{
int device;
THCudaCheck(cudaGetDevice(&device));
return THCState_getStreamOnDevice(state, device);
}
void THCState_setStream(THCState *state, THCStream *stream)
{
int device;
THCudaCheck(cudaGetDevice(&device));
THCState_setStreamOnDevice(state, device, stream);
}
void THCState_setCurrentBlasHandleIndex(THCState *state, int handle)
{
if (handle > state->numUserBlasHandles || handle <= 0)
@ -732,8 +690,7 @@ void THCudaHostFree(THCState *state, void *ptr)
return allocator->free(NULL, ptr);
}
void THCudaHostRecord(THCState *state, void *ptr)
{
void THCudaHostRecord(THCState *state, void *ptr) {
if (state->cudaHostAllocator == &THCCachingHostAllocator) {
THCStream* stream = THCState_getStream(state);
THCCachingHostAllocator_recordEvent(ptr, stream);

View File

@ -45,7 +45,7 @@
#endif
struct THCRNGState; /* Random number generator state. */
typedef struct THCStream THCStream;
typedef struct CUDAStreamInternals THCStream;
typedef struct THCState THCState;
struct THCState;
@ -110,8 +110,9 @@ THC_API int THCState_getNumDevices(THCState* state);
/* Stream API */
THC_API cudaStream_t THCState_getCurrentStreamOnDevice(THCState *state, int device);
THC_API cudaStream_t THCState_getCurrentStream(THCState *state);
THC_API struct THCStream* THCState_getStream(THCState *state);
THC_API void THCState_setStream(THCState *state, struct THCStream* stream);
THC_API THCStream* THCState_getStream(THCState *state);
THC_API void THCState_setStream(THCState *state, THCStream* stream);
THC_API THCStream* THCState_getStreamOnDevice(THCState* state, int device);
THC_API void THCState_setStreamOnDevice(THCState *state, int device, THCStream *stream);

View File

@ -27,8 +27,6 @@ struct THCState {
/* Index of the current selected sparse handle. The actual sparse handle used
depends on the current device. */
THCThreadLocal/*<int>*/ currentPerDeviceSparseHandle;
/* Array of thread locals containing the current stream for each device */
THCThreadLocal* currentStreams;
/* Table of enabled peer-to-peer access between directed pairs of GPUs.
If i accessing allocs on j is enabled, p2pAccess[i][j] is 1; 0 otherwise. */

View File

@ -1,62 +1,32 @@
#include "THCStream.hpp"
#include "THCStream.h"
#include "ATen/CUDAStream.h"
#include <mutex>
#include <cuda_runtime_api.h>
#define MAX_DEVICES 256
static THCStream default_streams[MAX_DEVICES];
static void initialize_default_streams()
{
for (int i = 0; i < MAX_DEVICES; i++) {
default_streams[i].device = i;
}
THC_API THCStream* THCStream_defaultStream(int device) {
return at::detail::CUDAStream_getDefaultStreamOnDevice(device);
}
THCStream* THCStream_new(int flags)
{
THCStream* self = (THCStream*) malloc(sizeof(THCStream));
self->refcount = 1;
THCudaCheck(cudaGetDevice(&self->device));
THCudaCheck(cudaStreamCreateWithFlags(&self->stream, flags));
return self;
THC_API THCStream* THCStream_new(int flags) {
return THCStream_newWithPriority(flags, at::CUDAStream::DEFAULT_PRIORITY);
}
THC_API THCStream* THCStream_defaultStream(int device)
{
// default streams aren't refcounted
THAssert(device >= 0 && device < MAX_DEVICES);
std::once_flag once;
std::call_once(once, &initialize_default_streams);
return &default_streams[device];
THC_API THCStream* THCStream_newWithPriority(int flags, int priority) {
return at::detail::CUDAStream_createAndRetainWithOptions(flags, priority);
}
THC_API cudaStream_t THCStream_stream(THCStream* self) { return self->stream; }
THC_API int THCStream_device(THCStream* self) { return self->device; }
THCStream* THCStream_newWithPriority(int flags, int priority)
{
THCStream* self = (THCStream*) malloc(sizeof(THCStream));
self->refcount = 1;
THCudaCheck(cudaGetDevice(&self->device));
THCudaCheck(cudaStreamCreateWithPriority(&self->stream, flags, priority));
return self;
THC_API cudaStream_t THCStream_stream(THCStream* stream) {
return at::detail::CUDAStream_stream(stream);
}
void THCStream_free(THCStream* self)
{
if (!self || !self->stream) {
return;
}
if (--self->refcount == 0) {
THCudaCheckWarn(cudaStreamDestroy(self->stream));
free(self);
}
THC_API int THCStream_device(THCStream* stream) {
return at::detail::CUDAStream_device(stream);
}
void THCStream_retain(THCStream* self)
{
if (self->stream) {
self->refcount++;
}
THC_API void THCStream_retain(THCStream* stream) {
at::detail::CUDAStream_retain(stream);
}
THC_API void THCStream_free(THCStream* stream) {
at::detail::CUDAStream_free(stream);
}

View File

@ -1,17 +1,26 @@
#ifndef THC_STREAM_INC
#define THC_STREAM_INC
#include <cuda_runtime_api.h>
#include "THCGeneral.h"
struct THCStream;
/*
* Note: legacy API.
*
* Stream usage should be done through ATen/Context.h.
*/
typedef struct CUDAStreamInternals THCStream;
THC_API THCStream* THCStream_new(int flags);
THC_API cudaStream_t THCStream_stream(THCStream* self);
THC_API int THCStream_device(THCStream* self);
// Stream creation
THC_API THCStream* THCStream_defaultStream(int device);
THC_API THCStream* THCStream_new(int flags);
THC_API THCStream* THCStream_newWithPriority(int flags, int priority);
THC_API void THCStream_free(THCStream* self);
THC_API void THCStream_retain(THCStream* self);
// Getters
THC_API cudaStream_t THCStream_stream(THCStream*);
THC_API int THCStream_device(THCStream*);
// Memory management
THC_API void THCStream_retain(THCStream*);
THC_API void THCStream_free(THCStream*);
#endif // THC_STREAM_INC

View File

@ -1,14 +0,0 @@
#pragma once
// STOP!!! Thinking of including this header directly? Please
// read Note [TH abstraction violation]
#include <atomic>
#include "THCStream.h"
struct THCStream
{
cudaStream_t stream;
int device;
std::atomic<int> refcount;
};

View File

@ -1,6 +1,6 @@
#include "THCTensorCopy.h"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "THCCachingHostAllocator.h"
#include "generic/THCTensorCopy.cpp"

View File

@ -6,7 +6,7 @@
#include "THCTensorMath.cuh"
#include "THCThrustAllocator.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include <thrust/copy.h>
#include <thrust/count.h>

View File

@ -1,5 +1,5 @@
#include "../THCTensorMasked.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMasked.cu"
#include "../THCGenerateByteType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMasked.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMasked.cu"
#include "../THCGenerateCharType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMasked.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMasked.cu"
#include "../THCGenerateDoubleType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMasked.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMasked.cu"
#include "../THCGenerateFloatType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMasked.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMasked.cu"
#include "../THCGenerateHalfType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMasked.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMasked.cu"
#include "../THCGenerateIntType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMasked.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMasked.cu"
#include "../THCGenerateLongType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMasked.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMasked.cu"
#include "../THCGenerateShortType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathCompare.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathCompare.cu"
#include "../THCGenerateByteType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathCompare.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathCompare.cu"
#include "../THCGenerateCharType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathCompare.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathCompare.cu"
#include "../THCGenerateDoubleType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathCompare.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathCompare.cu"
#include "../THCGenerateFloatType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathCompare.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathCompare.cu"
#include "../THCGenerateHalfType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathCompare.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathCompare.cu"
#include "../THCGenerateIntType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathCompare.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathCompare.cu"
#include "../THCGenerateLongType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathCompare.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathCompare.cu"
#include "../THCGenerateShortType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathCompareT.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathCompareT.cu"
#include "../THCGenerateByteType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathCompareT.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathCompareT.cu"
#include "../THCGenerateCharType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathCompareT.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathCompareT.cu"
#include "../THCGenerateDoubleType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathCompareT.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathCompareT.cu"
#include "../THCGenerateFloatType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathCompareT.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathCompareT.cu"
#include "../THCGenerateHalfType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathCompareT.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathCompareT.cu"
#include "../THCGenerateIntType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathCompareT.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathCompareT.cu"
#include "../THCGenerateLongType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathCompareT.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathCompareT.cu"
#include "../THCGenerateShortType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathPointwise.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathPointwise.cu"
#include "../THCGenerateByteType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathPointwise.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathPointwise.cu"
#include "../THCGenerateCharType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathPointwise.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathPointwise.cu"
#include "../THCGenerateDoubleType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathPointwise.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathPointwise.cu"
#include "../THCGenerateFloatType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathPointwise.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathPointwise.cu"
#include "../THCGenerateHalfType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathPointwise.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathPointwise.cu"
#include "../THCGenerateIntType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathPointwise.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathPointwise.cu"
#include "../THCGenerateLongType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathPointwise.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathPointwise.cu"
#include "../THCGenerateShortType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathReduce.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathReduce.cu"
#include "../THCGenerateByteType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathReduce.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathReduce.cu"
#include "../THCGenerateCharType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathReduce.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathReduce.cu"
#include "../THCGenerateDoubleType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathReduce.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathReduce.cu"
#include "../THCGenerateFloatType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathReduce.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathReduce.cu"
#include "../THCGenerateHalfType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathReduce.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathReduce.cu"
#include "../THCGenerateIntType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathReduce.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathReduce.cu"
#include "../THCGenerateLongType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorMathReduce.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorMathReduce.cu"
#include "../THCGenerateShortType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorSort.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorSort.cu"
#include "../THCGenerateByteType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorSort.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorSort.cu"
#include "../THCGenerateCharType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorSort.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorSort.cu"
#include "../THCGenerateDoubleType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorSort.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorSort.cu"
#include "../THCGenerateFloatType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorSort.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorSort.cu"
#include "../THCGenerateHalfType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorSort.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorSort.cu"
#include "../THCGenerateIntType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorSort.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorSort.cu"
#include "../THCGenerateLongType.h"

View File

@ -1,5 +1,5 @@
#include "../THCTensorSort.cuh"
#include "THCTensor.hpp"
#include "THCStream.hpp"
#include "THCStream.h"
#include "../generic/THCTensorSort.cu"
#include "../THCGenerateShortType.h"

View File

@ -129,7 +129,7 @@ void THCTensor_(copyAsyncCPU)(THCState *state, THCTensor *self, struct THTensor
THTensor_(data)(src),
THTensor_(nElement)(src) * sizeof(real),
cudaMemcpyHostToDevice,
stream->stream));
THCStream_stream(stream)));
THCudaCheck(THCCachingHostAllocator_recordEvent(THStorage_(data)(src->storage), stream));
@ -160,7 +160,7 @@ void THTensor_(copyAsyncCuda)(THCState *state, THTensor *self, struct THCTensor
THCTensor_(data)(state, src),
THCTensor_(nElement)(state, src) * sizeof(real),
cudaMemcpyDeviceToHost,
stream->stream));
THCStream_stream(stream)));
THCudaCheck(THCCachingHostAllocator_recordEvent(THCStorage_(data)(state, src->storage), stream));

View File

@ -185,7 +185,7 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
// Template Declarations for dim = 1, 2, 3, 4
#define HANDLE_CASE(DIMS) \
CatArrayBatchedCopy<real, unsigned int, DIMS><<<catGrid, applyBlock, 0, stream->stream>>>(data, d_inputs, param, dimension, param.outputStride[dimension]);
CatArrayBatchedCopy<real, unsigned int, DIMS><<<catGrid, applyBlock, 0, THCStream_stream(stream)>>>(data, d_inputs, param, dimension, param.outputStride[dimension]);
// Now we loop
offset = 0;
@ -210,7 +210,7 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
stackInputs,
j * sizeof(CatArrInputTensor<real, unsigned int>),
cudaMemcpyHostToDevice,
stream->stream));
THCStream_stream(stream)));
THCudaHostRecord(state, stackInputs);
THCudaHostFree(state, stackInputs);

View File

@ -25,6 +25,9 @@ fi
if [[ -x ./apply_test ]]; then
./apply_test
fi
if [[ -x ./stream_test ]]; then
./stream_test
fi
if [ "$VALGRIND" == "ON" ]
then
valgrind --suppressions="$VALGRIND_SUP" --error-exitcode=1 ./basic "[cpu]"

View File

@ -7,10 +7,7 @@
#include <sstream>
#include <ATen/ATen.h>
#include <THC/THC.h>
// See Note [TH abstraction violation]
// - Used to access 'stream' member
#include <THC/THCStream.hpp>
#include <THC/THCStream.h>
namespace torch { namespace cuda { namespace nccl {
@ -194,7 +191,7 @@ void broadcast(TensorList tensors, const stream_list& streams, const comm_list&
for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; i++) {
device_guard.set_index(tensors[i].get_device());
// TODO: use current stream
const auto stream = (streams.empty() || !streams[i]) ? NULL : streams[i]->stream;
const auto stream = (streams.empty() || !streams[i]) ? NULL : THCStream_stream(streams[i]);
CHECK(ncclBcast(tensors[i].data_ptr(), numel, data_type, 0, comms[i], stream));
}
#else

View File

@ -1,13 +1,12 @@
#pragma once
typedef struct CUDAStreamInternals THCStream;
#include <algorithm>
#include <cuda.h>
#include <cuda_runtime.h>
// Forward declaration
struct THCStream;
namespace c10d {
// RAII wrapper for current CUDA device.