[reland] Allow external CUDA streams to be set as current (#66324)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66324

Fixes https://github.com/pytorch/pytorch/issues/65822. Reland of https://github.com/pytorch/pytorch/pull/65914.

ghstack-source-id: 140105651

Test Plan: Added tests

Reviewed By: ngimel

Differential Revision: D31506134

fbshipit-source-id: ff56203a120befdb282e974309478ac11aa56652
This commit is contained in:
parent 355acfdebc
commit bc06eefebe
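The diff below adds tests for the new behaviour and reworks the c10 stream bookkeeping so that a stream created outside PyTorch can be installed as the current stream. A minimal usage sketch distilled from the added tests (error checking omitted; the exact headers to include are an assumption on my part, not taken from this diff):

    #include <cuda_runtime.h>

    #include <ATen/cuda/CUDAContext.h>  // assumed location of the at::cuda stream helpers
    #include <c10/cuda/CUDAGuard.h>     // assumed location of the CUDA guards

    int main() {
      at::cuda::CUDAGuard device_guard(0);

      // A stream created and owned by the caller, outside of PyTorch.
      cudaStream_t raw_stream;
      cudaStreamCreateWithPriority(&raw_stream, cudaStreamNonBlocking, -1);

      // Wrap it without transferring ownership; PyTorch does not manage its lifetime.
      at::cuda::CUDAStream ext =
          at::cuda::getStreamFromExternal(raw_stream, /*device_index=*/0);

      // This is the call the PR enables: making the external stream current, so
      // work launched through PyTorch on device 0 goes onto raw_stream.
      at::cuda::setCurrentCUDAStream(ext);

      // ... launch PyTorch work here ...

      // The caller remains responsible for destroying the stream.
      cudaStreamDestroy(raw_stream);
      return 0;
    }

The added tests below also exercise at::cuda::CUDAStreamGuard, which restores the previously current stream when the guard goes out of scope.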
@@ -11,6 +11,7 @@
 #include <cuda_runtime.h>
 
 #include <functional>
+#include <future>
 #include <thread>
 #include <unordered_set>
 
@@ -294,3 +295,141 @@ TEST(TestStream, GenericVirtualCUDAEventTest) {
   ASSERT_TRUE(event.query());
   ASSERT_TRUE(event.flag() == c10::EventFlag::PYTORCH_DEFAULT);
 }
+
+// Verifies external streams can be created and used
+TEST(TestStream, ExternalTest) {
+  if (!at::cuda::is_available())
+    return;
+  at::cuda::CUDAGuard device_guard(0);
+
+  cudaStream_t cuda_stream;
+  cudaStreamCreateWithPriority(&cuda_stream, cudaStreamNonBlocking, -1);
+
+  at::cuda::CUDAStream myStream =
+      at::cuda::getStreamFromExternal(cuda_stream, 0);
+
+  at::cuda::setCurrentCUDAStream(myStream);
+  at::cuda::CUDAStream curStream = at::cuda::getCurrentCUDAStream();
+
+  ASSERT_EQ_CUDA(curStream, myStream);
+  ASSERT_EQ_CUDA(curStream.stream(), cuda_stream);
+
+  cudaStreamDestroy(cuda_stream);
+}
+
+// Verifies different external streams can be used for different devices at the
+// same time
+TEST(TestStream, ExternalMultiDeviceTest) {
+  if (!at::cuda::is_available())
+    return;
+  if (at::cuda::getNumGPUs() < 2)
+    return;
+  cudaStream_t cuda_stream_0;
+  cudaStream_t cuda_stream_1;
+  {
+    at::cuda::CUDAGuard device_guard(0);
+    cudaStreamCreateWithPriority(&cuda_stream_0, cudaStreamNonBlocking, -1);
+  }
+  {
+    at::cuda::CUDAGuard device_guard(1);
+    cudaStreamCreateWithPriority(&cuda_stream_1, cudaStreamNonBlocking, -1);
+  }
+  at::cuda::CUDAStream myStream0 =
+      at::cuda::getStreamFromExternal(cuda_stream_0, 0);
+  at::cuda::CUDAStream myStream1 =
+      at::cuda::getStreamFromExternal(cuda_stream_1, 1);
+
+  at::cuda::setCurrentCUDAStream(myStream0);
+  ASSERT_EQ_CUDA(at::cuda::getCurrentCUDAStream(0), myStream0);
+  at::cuda::setCurrentCUDAStream(myStream1);
+  ASSERT_EQ_CUDA(at::cuda::getCurrentCUDAStream(0), myStream0);
+  ASSERT_EQ_CUDA(at::cuda::getCurrentCUDAStream(1), myStream1);
+
+  cudaStreamDestroy(cuda_stream_0);
+  cudaStreamDestroy(cuda_stream_1);
+}
+
+// Verifies external streams work with guards, even nested ones
+TEST(TestStream, ExternalGuardTest) {
+  if (!at::cuda::is_available())
+    return;
+  at::cuda::CUDAGuard device_guard(0);
+
+  cudaStream_t a_cuda_stream;
+  cudaStream_t another_cuda_stream;
+  cudaStreamCreateWithPriority(&a_cuda_stream, cudaStreamNonBlocking, -1);
+  cudaStreamCreateWithPriority(&another_cuda_stream, cudaStreamNonBlocking, -1);
+  at::cuda::CUDAStream myFirstStream =
+      at::cuda::getStreamFromExternal(a_cuda_stream, 0);
+  at::cuda::CUDAStream mySecondStream =
+      at::cuda::getStreamFromExternal(another_cuda_stream, 0);
+
+  at::cuda::CUDAStream originalStream = at::cuda::getCurrentCUDAStream();
+  {
+    at::cuda::CUDAStreamGuard outerGuard(myFirstStream);
+    ASSERT_EQ_CUDA(outerGuard.original_stream(), originalStream);
+    ASSERT_EQ_CUDA(outerGuard.current_stream(), myFirstStream);
+    ASSERT_EQ_CUDA(at::cuda::getCurrentCUDAStream(), myFirstStream);
+    {
+      at::cuda::CUDAStreamGuard innerGuard(mySecondStream);
+      ASSERT_EQ_CUDA(innerGuard.original_stream(), myFirstStream);
+      ASSERT_EQ_CUDA(innerGuard.current_stream(), mySecondStream);
+      ASSERT_EQ_CUDA(at::cuda::getCurrentCUDAStream(), mySecondStream);
+    }
+    ASSERT_EQ_CUDA(outerGuard.original_stream(), originalStream);
+    ASSERT_EQ_CUDA(outerGuard.current_stream(), myFirstStream);
+    ASSERT_EQ_CUDA(at::cuda::getCurrentCUDAStream(), myFirstStream);
+    outerGuard.reset_stream(mySecondStream);
+    ASSERT_EQ_CUDA(outerGuard.original_stream(), originalStream);
+    ASSERT_EQ_CUDA(outerGuard.current_stream(), mySecondStream);
+    ASSERT_EQ_CUDA(at::cuda::getCurrentCUDAStream(), mySecondStream);
+  }
+  ASSERT_EQ_CUDA(at::cuda::getCurrentCUDAStream(), originalStream);
+
+  cudaStreamDestroy(a_cuda_stream);
+  cudaStreamDestroy(another_cuda_stream);
+}
+
+// Verifies that different threads stage their external streams to different
+// places in memory and thus don't interfere
+TEST(TestStream, ExternalMultiThreadTest) {
+  if (!at::cuda::is_available())
+    return;
+  at::cuda::CUDAGuard device_guard(0);
+
+  cudaStream_t cuda_stream_a;
+  cudaStream_t cuda_stream_b;
+  cudaStreamCreateWithPriority(&cuda_stream_a, cudaStreamNonBlocking, -1);
+  cudaStreamCreateWithPriority(&cuda_stream_b, cudaStreamNonBlocking, -1);
+  at::cuda::CUDAStream myStreamA =
+      at::cuda::getStreamFromExternal(cuda_stream_a, 0);
+  at::cuda::CUDAStream myStreamB =
+      at::cuda::getStreamFromExternal(cuda_stream_b, 0);
+
+  std::promise<void> aToBProm;
+  std::promise<void> bToAProm;
+  c10::optional<at::cuda::CUDAStream> foundStream;
+
+  std::thread threadA([&]() {
+    at::cuda::CUDAGuard device_guard(0);
+    at::cuda::setCurrentCUDAStream(myStreamA);
+    aToBProm.set_value();
+    bToAProm.get_future().wait();
+    foundStream = at::cuda::getCurrentCUDAStream();
+  });
+
+  std::thread threadB([&]() {
+    at::cuda::CUDAGuard device_guard(0);
+    aToBProm.get_future().wait();
+    at::cuda::setCurrentCUDAStream(myStreamB);
+    bToAProm.set_value();
+  });
+
+  threadA.join();
+  threadB.join();
+
+  ASSERT_EQ_CUDA(*foundStream, myStreamA);
+
+  cudaStreamDestroy(cuda_stream_a);
+  cudaStreamDestroy(cuda_stream_b);
+}
@@ -4,7 +4,6 @@
 #include <c10/util/Exception.h>
 #include <c10/util/irange.h>
 
-#include <array>
 #include <atomic>
 #include <cstdint>
 #include <mutex>
@@ -16,28 +15,8 @@ namespace cuda {
 
 namespace {
 
-// Internal implementation that leaks the stream. It's not intended to be used
-// outside of this file.
-struct LeakyStreamInternals {
-  LeakyStreamInternals() = default;
-  C10_DISABLE_COPY_AND_ASSIGN(LeakyStreamInternals);
-
-  ~LeakyStreamInternals() {
-    // NB: this code is invoked only in the destruction of global variables
-    // (since we never shrink the corresponding vectors). At this point the CUDA
-    // runtime might be already destroyed and invoking cudaStreamDestroy leads
-    // to a crash. It's likely an issue in CUDA, but to be safe - let's just
-    // "forget" the destruction.
-
-    // if (stream) cudaStreamDestroy(stream);
-  }
-
-  DeviceIndex device_index = -1;
-  int32_t stream_id = -1;
-  cudaStream_t stream = nullptr;
-};
-
 // Global stream state and constants
+static std::once_flag init_flag;
 static DeviceIndex num_gpus = -1;
 static constexpr int kStreamsPerPoolBits = 5;
 static constexpr int kStreamsPerPool = 1 << kStreamsPerPoolBits;
@@ -45,12 +24,8 @@ static constexpr unsigned int kDefaultFlags = cudaStreamNonBlocking;
 static constexpr int kStreamTypeBits = 3;
 
 // Note: lower numbers are higher priorities, zero is default priority
-static int kHighPriority = -1;
-static int kLowPriority = 0;
+static constexpr int kHighPriority = -1;
+static constexpr int kLowPriority = 0;
 
-// Default streams
-static std::once_flag init_flag;
-static LeakyStreamInternals default_streams[C10_COMPILE_TIME_MAX_GPUS];
-
 // Non-default streams
 // Note: the number of CUDA devices is determined at run time,
@@ -60,16 +35,18 @@ static LeakyStreamInternals default_streams[C10_COMPILE_TIME_MAX_GPUS];
 // the low and high priority counters track, for each device, the next stream
 // in the pool to be returned when a stream is requested (round-robin fashion
 // , see the note in CUDAStream.h).
-//
-// unique_ptr<T[]> is used instead of vector<T> because T might be non-movable
-// and non-copyable.
+// The streams are "leaked": they are created but never destroyed because the
+// destruction of global variables could happen after the CUDA runtime has
+// already been destroyed and thus invoking cudaStreamDestroy could lead to a
+// crash. It's likely an issue in CUDA, but to be safe - let's just "forget"
+// the destruction.
 static std::once_flag device_flags[C10_COMPILE_TIME_MAX_GPUS];
 static std::atomic<uint32_t> low_priority_counters[C10_COMPILE_TIME_MAX_GPUS];
 static std::atomic<uint32_t> high_priority_counters[C10_COMPILE_TIME_MAX_GPUS];
-static std::array<LeakyStreamInternals, kStreamsPerPool>
-    low_priority_streams[C10_COMPILE_TIME_MAX_GPUS];
-static std::array<LeakyStreamInternals, kStreamsPerPool>
-    high_priority_streams[C10_COMPILE_TIME_MAX_GPUS];
+static cudaStream_t low_priority_streams[C10_COMPILE_TIME_MAX_GPUS]
+    [kStreamsPerPool];
+static cudaStream_t high_priority_streams[C10_COMPILE_TIME_MAX_GPUS]
+    [kStreamsPerPool];
 
 // Note [StreamId assignment]
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -155,60 +132,10 @@ StreamId makeStreamId(StreamIdType st, size_t si) {
       static_cast<StreamId>(st);
 }
 
-template <typename T, typename A>
-static bool pointer_within(const T* ptr, const A& arr) {
-  return std::greater_equal<const T*>()(ptr, arr.data()) &&
-      std::less<const T*>()(ptr, arr.data() + arr.size());
-}
-
-static StreamId CUDAStream_getStreamId(const LeakyStreamInternals* ptr) {
-  // Hypothetically, we could store the stream ID in the stream. But that
-  // introduces a degree of freedom which could lead to bugs (where we
-  // misnumber streams in the pool, or overwrite the number). Better
-  // to just compute it based on the metric that actually matters,
-  // which is how we map IDs back into the vectors.
-
-  DeviceIndex device_index = ptr->device_index;
-
-  // Check if it's the default stream
-  if (ptr == &default_streams[device_index]) {
-    return makeStreamId(StreamIdType::DEFAULT, 0);
-  }
-
-  // Check if it's a low priority stream
-  // NB: Because ptr may not necessarily lie within the array, we must use
-  // std::less and similar templates to avoid UB that arises when
-  // doing an operator< comparison.
-  if (pointer_within<LeakyStreamInternals>(
-          ptr, low_priority_streams[device_index])) {
-    return makeStreamId(
-        StreamIdType::LOW, ptr - low_priority_streams[device_index].data());
-  }
-
-  // Check if it's a high priority stream
-  if (pointer_within<LeakyStreamInternals>(
-          ptr, high_priority_streams[device_index])) {
-    return makeStreamId(
-        StreamIdType::HIGH, ptr - high_priority_streams[device_index].data());
-  }
-
-  TORCH_INTERNAL_ASSERT(
-      0,
-      "Could not compute stream ID for ",
-      ptr,
-      " on device ",
-      device_index,
-      " (something has gone horribly wrong!)");
-}
-
 // Thread-local current streams
-static thread_local LeakyStreamInternals** current_streams = nullptr;
+static thread_local std::unique_ptr<StreamId[]> current_streams = nullptr;
 
-// Populates global values and creates a default stream for each device.
-// Note: the default stream on each device is signified by a nullptr,
-// and so is not created as usual.
-// In particular, we don't need to switch devices when creating the
-// streams.
+// Populates global values.
 // Warning: this function must only be called once!
 static void initGlobalStreamState() {
   num_gpus = device_count();
@@ -220,13 +147,6 @@ static void initGlobalStreamState() {
       "max number of gpus expected (",
       C10_COMPILE_TIME_MAX_GPUS,
       "). Increase that and recompile.");
-
-  // Initializes default streams
-  for (const auto i : c10::irange(num_gpus)) {
-    default_streams[i].device_index = i;
-    low_priority_counters[i] = 0;
-    high_priority_counters[i] = 0;
-  }
 }
 
 // Creates the low and high priority stream pools for the specified device
@@ -240,14 +160,14 @@ static void initDeviceStreamState(DeviceIndex device_index) {
     auto& lowpri_stream = low_priority_streams[device_index][i];
     auto& hipri_stream = high_priority_streams[device_index][i];
 
-    lowpri_stream.device_index = device_index;
-    hipri_stream.device_index = device_index;
-
     C10_CUDA_CHECK(cudaStreamCreateWithPriority(
-        &lowpri_stream.stream, kDefaultFlags, kLowPriority));
+        &lowpri_stream, kDefaultFlags, kLowPriority));
     C10_CUDA_CHECK(cudaStreamCreateWithPriority(
-        &hipri_stream.stream, kDefaultFlags, kHighPriority));
+        &hipri_stream, kDefaultFlags, kHighPriority));
   }
+
+  low_priority_counters[device_index] = 0;
+  high_priority_counters[device_index] = 0;
 }
 
 // Init front-end to ensure initialization only occurs once
@@ -260,10 +180,9 @@ static void initCUDAStreamsOnce() {
   }
 
   // Inits current streams (thread local) to default streams
-  current_streams =
-      (LeakyStreamInternals**)malloc(num_gpus * sizeof(LeakyStreamInternals*));
+  current_streams = std::make_unique<StreamId[]>(num_gpus);
   for (const auto i : c10::irange(num_gpus)) {
-    current_streams[i] = &default_streams[i];
+    current_streams[i] = makeStreamId(StreamIdType::DEFAULT, 0);
   }
 }
 
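The hunks above replace the thread-local table of LeakyStreamInternals pointers with a thread-local array of plain StreamIds, lazily allocated per thread and initialized to the default stream's id; this per-thread staging is what ExternalMultiThreadTest in the test file relies on. A small self-contained sketch of the same pattern (the names and the demo namespace are illustrative, not the real c10 internals):

    #include <cstdint>
    #include <memory>

    namespace demo {

    using StreamId = int64_t;

    constexpr int kMaxDevices = 16;  // stand-in for C10_COMPILE_TIME_MAX_GPUS
    int num_devices = kMaxDevices;   // c10 obtains this from device_count()

    // One lazily allocated table per thread; one "current stream" slot per device.
    thread_local std::unique_ptr<StreamId[]> current_streams;

    void initThreadLocalStreams() {
      if (current_streams) {
        return;
      }
      current_streams = std::make_unique<StreamId[]>(num_devices);
      for (int i = 0; i < num_devices; ++i) {
        current_streams[i] = 0;  // id 0 plays the role of the default stream here
      }
    }

    StreamId getCurrent(int device) {
      initThreadLocalStreams();
      return current_streams[device];
    }

    void setCurrent(int device, StreamId id) {
      initThreadLocalStreams();
      current_streams[device] = id;  // only the calling thread observes the change
    }

    } // namespace demo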
@@ -279,62 +198,52 @@ static uint32_t get_idx(std::atomic<uint32_t>& counter) {
   return raw_idx % kStreamsPerPool;
 }
 
+CUDAStream CUDAStreamForId(DeviceIndex device_index, StreamId stream_id) {
+  return CUDAStream(
+      CUDAStream::UNCHECKED,
+      Stream(
+          Stream::UNSAFE,
+          c10::Device(DeviceType::CUDA, device_index),
+          stream_id));
+}
+
+} // anonymous namespace
+
 // See Note [StreamId assignment]
-LeakyStreamInternals* CUDAStream_internals(CUDAStream s) {
-  c10::DeviceIndex device_index = s.device_index();
-  StreamIdType st = streamIdType(s.unwrap().id());
-  size_t si = streamIdIndex(s.unwrap().id());
+cudaStream_t CUDAStream::stream() const {
+  c10::DeviceIndex device_index = stream_.device_index();
+  StreamId stream_id = stream_.id();
+  StreamIdType st = streamIdType(stream_id);
+  size_t si = streamIdIndex(stream_id);
   switch (st) {
     case StreamIdType::DEFAULT:
       TORCH_INTERNAL_ASSERT(
           si == 0,
           "Unrecognized stream ",
-          s.unwrap(),
+          stream_,
           " (I think this should be the default stream, but I got a non-zero index ",
           si,
           ").",
           " Did you manufacture the StreamId yourself? Don't do that; use the",
           " official API like c10::cuda::getStreamFromPool() to get a new stream.");
-      return &default_streams[device_index];
+      return nullptr;
     case StreamIdType::LOW:
-      return &low_priority_streams[device_index][si];
+      return low_priority_streams[device_index][si];
     case StreamIdType::HIGH:
-      return &high_priority_streams[device_index][si];
+      return high_priority_streams[device_index][si];
+    case StreamIdType::EXT:
+      return reinterpret_cast<cudaStream_t>(stream_id);
     default:
       TORCH_INTERNAL_ASSERT(
          0,
          "Unrecognized stream ",
-          s.unwrap(),
+          stream_,
          " (I didn't recognize the stream type, ",
          st,
          ")");
   }
 }
 
-CUDAStream CUDAStream_fromInternals(const LeakyStreamInternals* ptr) {
-  return CUDAStream(
-      CUDAStream::UNCHECKED,
-      Stream(
-          Stream::UNSAFE,
-          c10::Device(DeviceType::CUDA, ptr->device_index),
-          CUDAStream_getStreamId(ptr)));
-}
-
-} // anonymous namespace
-
-cudaStream_t CUDAStream::stream() const {
-  int64_t stream_id = unwrap().id();
-  if (streamIdType(stream_id) == StreamIdType::EXT) {
-    // In this case this is a externally allocated stream
-    // we don't need to manage its life cycle
-    return reinterpret_cast<cudaStream_t>(stream_id);
-  } else {
-    auto ptr = CUDAStream_internals(*this);
-    TORCH_INTERNAL_ASSERT(ptr);
-    return ptr->stream;
-  }
-}
-
 // Returns a stream from the requested pool
 // Note: when called the first time on a device, this will create the
 // stream pools for that device.
@@ -352,23 +261,18 @@ CUDAStream getStreamFromPool(
 
   if (isHighPriority) {
     const auto idx = get_idx(high_priority_counters[device_index]);
-    return CUDAStream_fromInternals(&high_priority_streams[device_index][idx]);
+    return CUDAStreamForId(device_index, makeStreamId(StreamIdType::HIGH, idx));
   }
 
   const auto idx = get_idx(low_priority_counters[device_index]);
-  return CUDAStream_fromInternals(&low_priority_streams[device_index][idx]);
+  return CUDAStreamForId(device_index, makeStreamId(StreamIdType::LOW, idx));
 }
 
 CUDAStream getStreamFromExternal(
     cudaStream_t ext_stream,
     DeviceIndex device_index) {
-  return CUDAStream(
-      CUDAStream::UNCHECKED,
-      // The stream pointer will be the actual id
-      Stream(
-          Stream::UNSAFE,
-          c10::Device(DeviceType::CUDA, device_index),
-          reinterpret_cast<int64_t>(ext_stream)));
+  // The stream pointer will be the actual id
+  return CUDAStreamForId(device_index, reinterpret_cast<int64_t>(ext_stream));
 }
 
 CUDAStream getDefaultCUDAStream(DeviceIndex device_index) {
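As the two hunks above show, an external stream's StreamId is simply the raw cudaStream_t pointer value: getStreamFromExternal stores it via reinterpret_cast<int64_t>, and CUDAStream::stream() casts it back for StreamIdType::EXT. A tiny standalone illustration of that round trip (not the real c10 code; fake_stream_t is a stand-in for the opaque CUDA handle):

    #include <cassert>
    #include <cstdint>

    // Stand-in for cudaStream_t, which is itself an opaque pointer type.
    struct fake_stream_st;
    using fake_stream_t = fake_stream_st*;
    using StreamId = int64_t;

    int main() {
      // Pretend this handle came from the CUDA runtime.
      fake_stream_t ext_stream = reinterpret_cast<fake_stream_t>(0x1000);

      // Encode: the pointer value becomes the id (what getStreamFromExternal does).
      StreamId id = reinterpret_cast<StreamId>(ext_stream);

      // Decode: cast the id back to the stream handle (what CUDAStream::stream()
      // does for StreamIdType::EXT).
      fake_stream_t decoded = reinterpret_cast<fake_stream_t>(id);

      assert(decoded == ext_stream);
      (void)decoded;
      return 0;
    }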
@@ -377,22 +281,21 @@ CUDAStream getDefaultCUDAStream(DeviceIndex device_index) {
     device_index = current_device();
   }
   check_gpu(device_index);
-  return CUDAStream_fromInternals(&default_streams[device_index]);
+  return CUDAStreamForId(device_index, makeStreamId(StreamIdType::DEFAULT, 0));
 }
 
 CUDAStream getCurrentCUDAStream(DeviceIndex device_index) {
   initCUDAStreamsOnce();
   if (device_index == -1) {
     device_index = current_device();
   }
   check_gpu(device_index);
-  return CUDAStream_fromInternals(current_streams[device_index]);
+  return CUDAStreamForId(device_index, current_streams[device_index]);
 }
 
 void setCurrentCUDAStream(CUDAStream stream) {
   initCUDAStreamsOnce();
-  auto ptr = CUDAStream_internals(stream);
-  TORCH_INTERNAL_ASSERT(ptr);
-  current_streams[ptr->device_index] = ptr;
+  current_streams[stream.device_index()] = stream.id();
 }
 
 std::ostream& operator<<(std::ostream& stream, const CUDAStream& s) {