[RELAND] [CUDA graphs] Avoid sync errors when graph capturing cudnn rnn calls that use cudnn dropout (#57373)
Summary: https://github.com/pytorch/pytorch/pull/56433 was reverted because the memory leak check perceived the test's one-time initialization of DropoutState's long-lived internal buffer as a leak. This PR resubmits with the leak check skipped for the new test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/57373
Reviewed By: anjali411
Differential Revision: D28152186
Pulled By: ezyang
fbshipit-source-id: 9a593fcdbbabbb09dc4e4221191663e94b697503
commit e841f335aa (parent 1b745efbe8)
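For context, the failure mode this change addresses can be condensed as follows. This is a sketch of the new test further down, using the private, experimental torch.cuda._Graph capture API that the test itself exercises; it is illustrative, not an official usage recipe:

import torch

# A cudnn LSTM with dropout > 0 lazily creates a shared DropoutState
# (noise buffer + usage event) the first time it runs.
model = torch.nn.LSTM(512, 512, 2, dropout=0.5).cuda()
x = torch.ones(100, 192, 512, device="cuda")

y = model(x)  # eager call: DropoutState's event is recorded on a non-capturing stream

g = torch.cuda._Graph()  # private capture API, as used in the test below
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    g.capture_begin()
    y = model(x)  # before this fix: the capturing stream would wait on that
                  # uncaptured event, a capture-time error
    g.capture_end()
torch.cuda.current_stream().wait_stream(s)

y = model(x)  # before this fix: a non-capturing stream would wait on an event
              # recorded during capture, also an error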
ATen/native/cudnn/RNN.cpp
@@ -2,6 +2,7 @@
 #include <ATen/Config.h>
 #include <ATen/cuda/CUDAConfig.h>
 #include <ATen/cuda/CUDAEvent.h>
+#include <ATen/cuda/CUDAGraphsUtils.cuh>
 #include <ATen/cuda/Exceptions.h>
 #include <ATen/InitialTensorOptions.h>
 #include <ATen/MatrixRef.h>
@@ -1373,6 +1374,30 @@ std::tuple<Tensor, Tensor> pack_hidden<std::tuple<Tensor, Tensor>>(const Tensor&
   return std::make_tuple(hx, cx);
 }
 
+/**
+ * Note [DropoutState and CUDA graph capture]
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ * (1) Telling a capturing stream to wait on an event recorded in a non-capturing stream is an error.
+ * (2) Telling a non-capturing stream to wait on an event recorded during capture is also an error.
+ *
+ * So DropoutState's usage syncs could error if an RNN with dropout is called in an uncaptured region
+ * then called in a captured region (triggering 1), or called in a captured region then called
+ * in an uncaptured region (triggering 2).
+ *
+ * To prevent 1 and 2, lock() only syncs on the last usage event if it was recorded in the same
+ * capture state as the current state (which also means the same graph, if capture is in progress).
+ *
+ * The solution should be safe as long as capture obeys the following restrictions:
+ *  - Only one capture may be underway at a time in a given process.
+ *  - While a capture is underway, no calls to eager ops on noncapturing streams (on any thread)
+ *    may interleave with the captured ops.
+ *
+ * TODO: As people experiment with capture, keep an eye out for use cases that might need to
+ * relax those restrictions.
+ *
+ * See https://github.com/pytorch/pytorch/pull/56433 for more discussion.
+ */
+
 struct DropoutState {
   // Both buffer and event are lazily instantiated when a dropout state is needed
   // for the first time. Note that in this case needed != used, as we don't need
@@ -1380,6 +1405,12 @@ struct DropoutState {
   at::Tensor buffer;
   c10::optional<cuda::CUDAEvent> event;
   std::mutex mutex;
+#if CUDA_VERSION >= 11000
+  // cudaStreamGetCaptureInfo will never give back a capture id of 0, so 0 can serve
+  // as a sentinel value that capture was not underway.
+  cuda::CaptureId_t capture_id_last_lock = 0;
+  cuda::CaptureId_t capture_id_last_unlock = 0;
+#endif
 
   // Every time we use a dropout state, we need to synchronize with its event,
   // to make sure all previous uses finish running before this one starts. Once
@@ -1392,13 +1423,38 @@
     // could then define it before we get to unlock().
     mutex.lock();
     if (event) {
+#if CUDA_VERSION >= 11000
+      // See Note [DropoutState and CUDA graph capture]
+      cudaStreamCaptureStatus status;
+      AT_CUDA_CHECK(cudaStreamGetCaptureInfo(cuda::getCurrentCUDAStream(),
+                                             &status,
+                                             &capture_id_last_lock));
+      if (status == cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) {
+        capture_id_last_lock = 0;
+      }
+      if (capture_id_last_lock == capture_id_last_unlock) {
+        event->block(cuda::getCurrentCUDAStream());
+      }
+#else
       event->block(cuda::getCurrentCUDAStream());
+#endif
     }
   }
 
   void unlock() {
     if (event) {
       event->record();
+#if CUDA_VERSION >= 11000
+      // See Note [DropoutState and CUDA graph capture]
+      cudaStreamCaptureStatus status;
+      AT_CUDA_CHECK(cudaStreamGetCaptureInfo(cuda::getCurrentCUDAStream(),
+                                             &status,
+                                             &capture_id_last_unlock));
+      if (status == cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) {
+        capture_id_last_unlock = 0;
+      }
+      TORCH_INTERNAL_ASSERT(capture_id_last_unlock == capture_id_last_lock);
+#endif
     }
     mutex.unlock();
  }
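To make the branching in lock()/unlock() above concrete: lock() compares the capture id it observes now against the capture id recorded at the previous unlock(), and blocks on the usage event only when they match. A pure-Python mirror of that decision (a schematic of mine, not PyTorch code; 0 is the "not capturing" sentinel from the struct above):

def sync_on_last_usage_event(capture_id_at_lock, capture_id_at_prev_unlock):
    # Sync only if this lock() and the previous unlock() happened in the same
    # capture state (and hence, if capturing, in the same graph).
    return capture_id_at_lock == capture_id_at_prev_unlock

assert sync_on_last_usage_event(0, 0)       # eager -> eager: sync as before
assert not sync_on_last_usage_event(42, 0)  # captured after eager: skip sync, avoids error (1)
assert not sync_on_last_usage_event(0, 42)  # eager after captured: skip sync, avoids error (2)
assert sync_on_last_usage_event(42, 42)     # same capture: sync is legal and required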
test/test_cuda.py
@@ -25,7 +25,7 @@ from torch.testing._internal.common_methods_invocations import tri_tests_args, t
     _compare_trilu_indices, _compare_large_trilu_indices
 from torch.testing._internal.common_utils import TestCase, freeze_rng_state, run_tests, \
     NO_MULTIPROCESSING_SPAWN, skipIfRocm, load_tests, IS_REMOTE_GPU, IS_SANDCASTLE, IS_WINDOWS, \
-    slowTest, skipCUDANonDefaultStreamIf, TEST_WITH_ROCM, TEST_NUMPY
+    slowTest, skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, TEST_WITH_ROCM, TEST_NUMPY
 from torch.testing._internal.autocast_test_lists import AutocastTestLists
 
 # load_tests from common_utils is used to automatically filter tests for
@@ -3457,6 +3457,33 @@ torch.cuda.synchronize()
         # dummy allocation triggers process_events, Hopefully successfully processes b's end-of-life event.
         c = torch.zeros((3,), device="cuda")
 
+    @unittest.skipIf((not TEST_CUDA) or
+                     TEST_WITH_ROCM or
+                     int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
+    # If this test is the first in the process to try cudnn rnns with dropout, it'll initialize
+    # DropoutState's long-lived internal buffer. Calling code perceives this (correct) behavior
+    # as a memory leak unless we skip the leak check.
+    @skipCUDAMemoryLeakCheckIf(True)
+    def test_graph_cudnn_dropout(self):
+        # Tests the interaction of cuda graph capture with DropoutState's syncs in ATen/native/cudnn/RNN.cpp.
+        # In particular, if user runs a sequence of captured and noncaptured cudnn rnns, DropoutState should
+        # avoid syncing noncapturing streams with captured events or vice versa.
+        model = torch.nn.LSTM(512, 512, 2, dropout=0.5).cuda()
+        x = torch.ones(100, 192, 512, device="cuda")
+
+        y = model(x)
+
+        g = torch.cuda._Graph()
+        s = torch.cuda.Stream()
+        s.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(s):
+            g.capture_begin()
+            y = model(x)
+            g.capture_end()
+        torch.cuda.current_stream().wait_stream(s)
+
+        y = model(x)
+
     @unittest.skipIf((not TEST_CUDA) or
                      TEST_WITH_ROCM or
                      int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")