[RELAND] [CUDA graphs] Avoid sync errors when graph capturing cudnn rnn calls that use cudnn dropout (#57373)

Summary:
https://github.com/pytorch/pytorch/pull/56433 was reverted because the new test's memory-leak check flagged cudnn's lazy creation of its internal dropout state as a leak. This PR resubmits with the leak check skipped for that test.
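
For context, the failure mode this PR addresses comes from running a cudnn RNN with dropout both outside and inside a graph capture. A minimal sketch of that usage pattern, mirroring the test added below (torch.cuda._Graph is the private capture API available at the time of this PR):

import torch

# A cudnn LSTM with dropout lazily creates a long-lived internal DropoutState buffer and event.
model = torch.nn.LSTM(512, 512, 2, dropout=0.5).cuda()
x = torch.ones(100, 192, 512, device="cuda")
y = model(x)  # eager call records DropoutState's usage event

g = torch.cuda._Graph()
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    g.capture_begin()
    y = model(x)  # captured call must not wait on the eagerly recorded event
    g.capture_end()
torch.cuda.current_stream().wait_stream(s)

y = model(x)  # eager call must not wait on the event recorded during capture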

Pull Request resolved: https://github.com/pytorch/pytorch/pull/57373

Reviewed By: anjali411

Differential Revision: D28152186

Pulled By: ezyang

fbshipit-source-id: 9a593fcdbbabbb09dc4e4221191663e94b697503
Authored by: Michael Carilli (2021-05-03 11:40:44 -07:00), committed by Facebook GitHub Bot
Parent: 1b745efbe8
Commit: e841f335aa
2 changed files with 84 additions and 1 deletion

aten/src/ATen/native/cudnn/RNN.cpp

@@ -2,6 +2,7 @@
#include <ATen/Config.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/cuda/CUDAEvent.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <ATen/cuda/Exceptions.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/MatrixRef.h>
@@ -1373,6 +1374,30 @@ std::tuple<Tensor, Tensor> pack_hidden<std::tuple<Tensor, Tensor>>(const Tensor&
return std::make_tuple(hx, cx);
}
/**
* Note [DropoutState and CUDA graph capture]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* (1) Telling a capturing stream to wait on an event recorded in a non-capturing stream is an error.
* (2) Telling a non-capturing stream to wait on an event recorded during capture is also an error.
*
* So DropoutState's usage syncs could error if an RNN with dropout is called in an uncaptured region
* then called in a captured region (triggering 1), or called in a captured region then called
* in an uncaptured region (triggering 2).
*
* To prevent 1 and 2, lock() only syncs on the last usage event if it was recorded in the same
* capture state as the current state (which also means the same graph, if capture is in progress).
*
* The solution should be safe as long as capture obeys the following restrictions:
* - Only one capture may be underway at a time in a given process.
* - While a capture is underway, no calls to eager ops on noncapturing streams (on any thread)
* may interleave with the captured ops.
*
* TODO: As people experiment with capture, keep an eye out for use cases that might need to
* relax those restrictions.
*
* See https://github.com/pytorch/pytorch/pull/56433 for more discussion.
*/
struct DropoutState {
// Both buffer and event are lazily instantiated when a dropout state is needed
// for the first time. Note that in this case needed != used, as we don't need
@@ -1380,6 +1405,12 @@ struct DropoutState {
at::Tensor buffer;
c10::optional<cuda::CUDAEvent> event;
std::mutex mutex;
#if CUDA_VERSION >= 11000
// cudaStreamGetCaptureInfo will never give back a capture id of 0, so 0 can serve
// as a sentinel value that capture was not underway.
cuda::CaptureId_t capture_id_last_lock = 0;
cuda::CaptureId_t capture_id_last_unlock = 0;
#endif
// Every time we use a dropout state, we need to synchronize with its event,
// to make sure all previous uses finish running before this one starts. Once
@@ -1392,13 +1423,38 @@
// could then define it before we get to unlock().
mutex.lock();
if (event) {
#if CUDA_VERSION >= 11000
// See Note [DropoutState and CUDA graph capture]
cudaStreamCaptureStatus status;
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(cuda::getCurrentCUDAStream(),
&status,
&capture_id_last_lock));
if (status == cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) {
capture_id_last_lock = 0;
}
if (capture_id_last_lock == capture_id_last_unlock) {
event->block(cuda::getCurrentCUDAStream());
}
#else
event->block(cuda::getCurrentCUDAStream());
#endif
}
}
void unlock() {
if (event) {
event->record();
#if CUDA_VERSION >= 11000
// See Note [DropoutState and CUDA graph capture]
cudaStreamCaptureStatus status;
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(cuda::getCurrentCUDAStream(),
&status,
&capture_id_last_unlock));
if (status == cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) {
capture_id_last_unlock = 0;
}
TORCH_INTERNAL_ASSERT(capture_id_last_unlock == capture_id_last_lock);
#endif
}
mutex.unlock();
}
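
For readers who want the capture-state query above in isolation: cudaStreamGetCaptureInfo reports whether a stream is currently capturing and, if so, a nonzero id for that capture, which is why 0 can serve as the "not capturing" sentinel in lock()/unlock(). A standalone sketch (a hypothetical helper, not part of this diff, assuming the same CUDA >= 11.0 guard used above):

#include <cuda_runtime.h>

// Returns the capture id of the capture underway on `stream`, or 0 if the
// stream is not capturing. cudaStreamGetCaptureInfo never reports 0 as a real
// capture id, so 0 is a safe "not capturing" sentinel.
unsigned long long currentCaptureId(cudaStream_t stream) {
  cudaStreamCaptureStatus status;
  unsigned long long id = 0;
  if (cudaStreamGetCaptureInfo(stream, &status, &id) != cudaSuccess ||
      status == cudaStreamCaptureStatusNone) {
    return 0;
  }
  return id;
}

DropoutState's lock() compares this id against the one seen at the previous unlock() to decide whether the last usage event was recorded under the same capture state, and only then blocks on it.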

test/test_cuda.py

@@ -25,7 +25,7 @@ from torch.testing._internal.common_methods_invocations import tri_tests_args, t
_compare_trilu_indices, _compare_large_trilu_indices
from torch.testing._internal.common_utils import TestCase, freeze_rng_state, run_tests, \
NO_MULTIPROCESSING_SPAWN, skipIfRocm, load_tests, IS_REMOTE_GPU, IS_SANDCASTLE, IS_WINDOWS, \
slowTest, skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, TEST_WITH_ROCM, TEST_NUMPY
from torch.testing._internal.autocast_test_lists import AutocastTestLists
# load_tests from common_utils is used to automatically filter tests for
@@ -3457,6 +3457,33 @@ torch.cuda.synchronize()
# dummy allocation triggers process_events, Hopefully successfully processes b's end-of-life event.
c = torch.zeros((3,), device="cuda")
@unittest.skipIf((not TEST_CUDA) or
TEST_WITH_ROCM or
int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
# If this test is the first in the process to try cudnn rnns with dropout, it'll initialize
# DropoutState's long-lived internal buffer. Calling code perceives this (correct) behavior
# as a memory leak unless we skip the leak check.
@skipCUDAMemoryLeakCheckIf(True)
def test_graph_cudnn_dropout(self):
    # Tests the interaction of cuda graph capture with DropoutState's syncs in ATen/native/cudnn/RNN.cpp.
    # In particular, if user runs a sequence of captured and noncaptured cudnn rnns, DropoutState should
    # avoid syncing noncapturing streams with captured events or vice versa.
    model = torch.nn.LSTM(512, 512, 2, dropout=0.5).cuda()
    x = torch.ones(100, 192, 512, device="cuda")
    y = model(x)
    g = torch.cuda._Graph()
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        g.capture_begin()
        y = model(x)
        g.capture_end()
    torch.cuda.current_stream().wait_stream(s)
    y = model(x)
@unittest.skipIf((not TEST_CUDA) or
TEST_WITH_ROCM or
int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")