[RELAND] [CUDA graphs] Avoid sync errors when graph capturing cudnn rnn calls that use cudnn dropout (#57373)
Summary: https://github.com/pytorch/pytorch/pull/56433 was reverted because the test perceived internal dropout state creation as a memory leak. This PR resubmits with the leak check skipped.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/57373
Reviewed By: anjali411
Differential Revision: D28152186
Pulled By: ezyang
fbshipit-source-id: 9a593fcdbbabbb09dc4e4221191663e94b697503
This commit is contained in: parent 1b745efbe8, commit e841f335aa
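For orientation before the diff (this block is not part of the commit): a condensed sketch of the pattern the title describes, running a cudnn RNN with dropout eagerly, then under CUDA graph capture, then eagerly again. It is distilled from the test_graph_cudnn_dropout test added at the bottom of this diff and uses the private torch.cuda._Graph API of this era; CUDA >= 11.0 and a cuDNN-enabled build are assumed.

import torch

model = torch.nn.LSTM(512, 512, 2, dropout=0.5).cuda()   # cudnn LSTM owning a DropoutState
x = torch.ones(100, 192, 512, device="cuda")

y = model(x)                      # uncaptured (eager) call

g = torch.cuda._Graph()           # private graph object at the time of this PR
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    g.capture_begin()
    y = model(x)                  # captured call
    g.capture_end()
torch.cuda.current_stream().wait_stream(s)

y = model(x)                      # eager again; this mix previously risked stream-sync errors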
aten/src/ATen/native/cudnn/RNN.cpp
@@ -2,6 +2,7 @@
 #include <ATen/Config.h>
 #include <ATen/cuda/CUDAConfig.h>
 #include <ATen/cuda/CUDAEvent.h>
+#include <ATen/cuda/CUDAGraphsUtils.cuh>
 #include <ATen/cuda/Exceptions.h>
 #include <ATen/InitialTensorOptions.h>
 #include <ATen/MatrixRef.h>
@@ -1373,6 +1374,30 @@ std::tuple<Tensor, Tensor> pack_hidden<std::tuple<Tensor, Tensor>>(const Tensor&
   return std::make_tuple(hx, cx);
 }
 
+/**
+ * Note [DropoutState and CUDA graph capture]
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ * (1) Telling a capturing stream to wait on an event recorded in a non-capturing stream is an error.
+ * (2) Telling a non-capturing stream to wait on an event recorded during capture is also an error.
+ *
+ * So DropoutState's usage syncs could error if an RNN with dropout is called in an uncaptured region
+ * then called in a captured region (triggering 1), or called in a captured region then called
+ * in an uncaptured region (triggering 2).
+ *
+ * To prevent 1 and 2, lock() only syncs on the last usage event if it was recorded in the same
+ * capture state as the current state (which also means the same graph, if capture is in progress).
+ *
+ * The solution should be safe as long as capture obeys the following restrictions:
+ *  - Only one capture may be underway at a time in a given process.
+ *  - While a capture is underway, no calls to eager ops on noncapturing streams (on any thread)
+ *    may interleave with the captured ops.
+ *
+ * TODO: As people experiment with capture, keep an eye out for use cases that might need to
+ * relax those restrictions.
+ *
+ * See https://github.com/pytorch/pytorch/pull/56433 for more discussion.
+ */
+
 struct DropoutState {
   // Both buffer and event are lazily instantiated when a dropout state is needed
   // for the first time. Note that in this case needed != used, as we don't need
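To make the rule in the Note concrete (not part of the diff; an illustrative Python model with made-up names, not PyTorch code): lock() compares the capture id observed for the current call against the capture id recorded by the previous unlock(), and waits on the last usage event only when the two match. Capture ids are plain integers here, and 0 means no capture was underway, matching the sentinel convention introduced below.

def should_sync(current_capture_id, last_unlock_capture_id):
    # Wait on the previous usage event only if this call and the previous one
    # happened in the same capture state (same graph, or both uncaptured).
    return current_capture_id == last_unlock_capture_id

assert should_sync(0, 0)          # eager call after an eager call: normal sync
assert not should_sync(42, 0)     # captured call after an eager call: skip the sync, avoiding error (1)
assert not should_sync(0, 42)     # eager call after a captured call: skip the sync, avoiding error (2)
assert should_sync(42, 42)        # both calls captured into the same graph: sync is legal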
@@ -1380,6 +1405,12 @@ struct DropoutState {
   at::Tensor buffer;
   c10::optional<cuda::CUDAEvent> event;
   std::mutex mutex;
+#if CUDA_VERSION >= 11000
+  // cudaStreamGetCaptureInfo will never give back a capture id of 0, so 0 can serve
+  // as a sentinel value that capture was not underway.
+  cuda::CaptureId_t capture_id_last_lock = 0;
+  cuda::CaptureId_t capture_id_last_unlock = 0;
+#endif
 
   // Every time we use a dropout state, we need to synchronize with its event,
   // to make sure all previous uses finish running before this one starts. Once
@@ -1392,13 +1423,38 @@ struct DropoutState {
     // could then define it before we get to unlock().
     mutex.lock();
     if (event) {
+#if CUDA_VERSION >= 11000
+      // See Note [DropoutState and CUDA graph capture]
+      cudaStreamCaptureStatus status;
+      AT_CUDA_CHECK(cudaStreamGetCaptureInfo(cuda::getCurrentCUDAStream(),
+                                             &status,
+                                             &capture_id_last_lock));
+      if (status == cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) {
+        capture_id_last_lock = 0;
+      }
+      if (capture_id_last_lock == capture_id_last_unlock) {
+        event->block(cuda::getCurrentCUDAStream());
+      }
+#else
       event->block(cuda::getCurrentCUDAStream());
+#endif
     }
   }
 
   void unlock() {
     if (event) {
       event->record();
+#if CUDA_VERSION >= 11000
+      // See Note [DropoutState and CUDA graph capture]
+      cudaStreamCaptureStatus status;
+      AT_CUDA_CHECK(cudaStreamGetCaptureInfo(cuda::getCurrentCUDAStream(),
+                                             &status,
+                                             &capture_id_last_unlock));
+      if (status == cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) {
+        capture_id_last_unlock = 0;
+      }
+      TORCH_INTERNAL_ASSERT(capture_id_last_unlock == capture_id_last_lock);
+#endif
     }
     mutex.unlock();
   }
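As a worked trace of the bookkeeping above (illustration only, plain Python, not PyTorch code): here is how capture_id_last_lock and capture_id_last_unlock evolve over the eager, then captured, then eager sequence exercised by the test added below, assuming cudaStreamGetCaptureInfo reports a nonzero id (say 7) during capture.

last_lock, last_unlock = 0, 0

# 1) eager call: lock() sees no capture (id 0); the ids match, so it blocks on the last event
last_lock = 0
assert last_lock == last_unlock   # sync happens
last_unlock = 0                   # unlock() records the event and re-queries: still 0

# 2) captured call: lock() sees capture id 7; mismatch, so the sync is skipped (avoids error 1)
last_lock = 7
assert last_lock != last_unlock
last_unlock = 7                   # unlock() runs inside the same capture

# 3) eager call after capture: lock() sees id 0 again; mismatch, sync skipped (avoids error 2)
last_lock = 0
assert last_lock != last_unlock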
test/test_cuda.py
@@ -25,7 +25,7 @@ from torch.testing._internal.common_methods_invocations import tri_tests_args, t
     _compare_trilu_indices, _compare_large_trilu_indices
 from torch.testing._internal.common_utils import TestCase, freeze_rng_state, run_tests, \
     NO_MULTIPROCESSING_SPAWN, skipIfRocm, load_tests, IS_REMOTE_GPU, IS_SANDCASTLE, IS_WINDOWS, \
-    slowTest, skipCUDANonDefaultStreamIf, TEST_WITH_ROCM, TEST_NUMPY
+    slowTest, skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, TEST_WITH_ROCM, TEST_NUMPY
 from torch.testing._internal.autocast_test_lists import AutocastTestLists
 
 # load_tests from common_utils is used to automatically filter tests for
@@ -3457,6 +3457,33 @@ torch.cuda.synchronize()
         # dummy allocation triggers process_events, Hopefully successfully processes b's end-of-life event.
         c = torch.zeros((3,), device="cuda")
 
+    @unittest.skipIf((not TEST_CUDA) or
+                     TEST_WITH_ROCM or
+                     int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
+    # If this test is the first in the process to try cudnn rnns with dropout, it'll initialize
+    # DropoutState's long-lived internal buffer. Calling code perceives this (correct) behavior
+    # as a memory leak unless we skip the leak check.
+    @skipCUDAMemoryLeakCheckIf(True)
+    def test_graph_cudnn_dropout(self):
+        # Tests the interaction of cuda graph capture with DropoutState's syncs in ATen/native/cudnn/RNN.cpp.
+        # In particular, if user runs a sequence of captured and noncaptured cudnn rnns, DropoutState should
+        # avoid syncing noncapturing streams with captured events or vice versa.
+        model = torch.nn.LSTM(512, 512, 2, dropout=0.5).cuda()
+        x = torch.ones(100, 192, 512, device="cuda")
+
+        y = model(x)
+
+        g = torch.cuda._Graph()
+        s = torch.cuda.Stream()
+        s.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(s):
+            g.capture_begin()
+            y = model(x)
+            g.capture_end()
+        torch.cuda.current_stream().wait_stream(s)
+
+        y = model(x)
+
     @unittest.skipIf((not TEST_CUDA) or
                      TEST_WITH_ROCM or
                      int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
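A brief follow-on (not part of the diff): the test above only verifies that capture itself no longer raises sync errors; it never replays the graph. Assuming the private _Graph object of this era exposes the same replay() method as the later public torch.cuda.CUDAGraph API (an assumption worth checking), replaying the captured LSTM call would continue that snippet roughly like this:

x.copy_(torch.rand_like(x))   # refresh the capture-time input buffer in place
g.replay()                    # re-launch the captured cudnn RNN work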