From b3861ac8e70a2334278c17f834880e2d45bbe7a9 Mon Sep 17 00:00:00 2001
From: soulitzer
Date: Sat, 1 Nov 2025 12:33:48 +0000
Subject: [PATCH] [reland] Warn if AccumulateGrad stream does not match producer node stream (#166136)

ghstack-source-id: 59641aa32dc6fd027abf3276017432b693aa71f8
Pull-Request-resolved: https://github.com/pytorch/pytorch/pull/165065

Fixes #ISSUE_NUMBER

Opening a new PR for codev

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166136
Approved by: https://github.com/ngimel
---
 aten/src/ATen/Context.cpp            |   8 ++
 aten/src/ATen/Context.h              |   4 +
 docs/source/autograd.md              |   4 +-
 test/test_autograd.py                | 103 ++++++++++++++----
 torch/_C/__init__.pyi.in             |   1 +
 torch/autograd/graph.py              |   8 ++
 torch/csrc/Module.cpp                |  34 ++++++
 torch/csrc/autograd/engine.cpp       |  15 ++-
 torch/csrc/autograd/input_buffer.cpp |  20 +++-
 torch/csrc/autograd/input_buffer.h   |   3 +-
 .../autograd/engine/dist_engine.cpp  |   7 +-
 11 files changed, 181 insertions(+), 26 deletions(-)

diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp
index 3310abfb41d..facb88c47bd 100644
--- a/aten/src/ATen/Context.cpp
+++ b/aten/src/ATen/Context.cpp
@@ -825,6 +825,14 @@ void Context::setDisplayVmapFallbackWarnings(bool enabled) {
   display_vmap_fallback_warnings_ = enabled;
 }
 
+bool Context::warnOnAccumulateGradStreamMismatch() const {
+  return warn_on_accumulate_grad_stream_mismatch_;
+}
+
+void Context::setWarnOnAccumulateGradStreamMismatch(bool enabled) {
+  warn_on_accumulate_grad_stream_mismatch_ = enabled;
+}
+
 bool Context::isDefaultMobileCPUAllocatorSet() {
   return prev_allocator_ptr_ != nullptr;
 }
diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h
index a4a26b5671e..6807e527eb7 100644
--- a/aten/src/ATen/Context.h
+++ b/aten/src/ATen/Context.h
@@ -404,6 +404,9 @@ class TORCH_API Context {
   void setDisplayVmapFallbackWarnings(bool enabled);
   bool areVmapFallbackWarningsEnabled() const;
 
+  void setWarnOnAccumulateGradStreamMismatch(bool enabled);
+  bool warnOnAccumulateGradStreamMismatch() const;
+
   bool isDefaultMobileCPUAllocatorSet();
   void setDefaultMobileCPUAllocator();
   void unsetDefaultMobileCPUAllocator();
@@ -494,6 +497,7 @@ class TORCH_API Context {
   bool release_original_weights = false;
 #endif
   bool display_vmap_fallback_warnings_ = false;
+  bool warn_on_accumulate_grad_stream_mismatch_ = true;
   std::atomic<at::QEngine> quantized_engine = at::QEngine::NoQEngine;
   bool enable_sparse_tensor_invariant_checks = false;
   bool allow_fp16_reduction_cpu = false;
diff --git a/docs/source/autograd.md b/docs/source/autograd.md
index 4218eac05d7..e78b77e4eb4 100644
--- a/docs/source/autograd.md
+++ b/docs/source/autograd.md
@@ -423,8 +423,10 @@ Also see {ref}`saved-tensors-hooks-doc`.
 
 ```{eval-rst}
 .. autofunction:: torch.autograd.graph.get_gradient_edge
+```
 
-
+```{eval-rst}
+.. autofunction:: torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch
 ```
 
 % This module needs to be documented. Adding here in the meantime
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 2acdd491e70..6c3e250df7c 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -54,6 +54,7 @@ from torch.testing._internal.common_device_type import (
     dtypes,
     dtypesIfCUDA,
     dtypesIfMPS,
+    expectedFailureMPS,
     instantiate_device_type_tests,
     onlyCPU,
     onlyCUDA,
@@ -72,6 +73,7 @@ from torch.testing._internal.common_utils import (
     run_tests,
     scoped_load_inline,
     set_warn_always_context,
+    skipCUDANonDefaultStreamIf,
     skipIfMPS,
     skipIfNoLapack,
     skipIfTorchDynamo,
@@ -13325,9 +13327,12 @@ class TestAutogradStreamSynchronization(TestCase):
         )
 
     # AttributeError: module 'torch.mps' has no attribute 'default_stream'
-    @skipIfMPS
-    @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
-    def test_consumer_to_single_producer_case_2_correctness(self):
+    @expectedFailureMPS
+    @skipCUDANonDefaultStreamIf(True)
+    def test_consumer_to_single_producer_case_2_correctness(self, device):
+        if device == "cpu":
+            self.skipTest("requires accelerator")
+
         # Device  Stream
         # Consumer (MulBackward): cuda:0  s0
         # Producer              : cuda:0  s1
@@ -13430,36 +13435,43 @@ class TestAutogradStreamSynchronization(TestCase):
             test()
 
     # AttributeError: module 'torch.mps' has no attribute 'default_stream'
-    @skipIfMPS
-    @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
+    @expectedFailureMPS
+    @skipCUDANonDefaultStreamIf(True)
     @unittest.skipIf(
         torch.accelerator.device_count() < 2, "accelerator count is less than 2"
     )
     def test_consumer_to_single_producer_case_3_correctness_non_default_ambient_stream(
-        self,
+        self, device
     ):
+        if device == "cpu":
+            self.skipTest("requires accelerator")
         self._test_consumer_to_single_producer_case_3_correctness(
             non_default_ambient_stream=True
         )
 
     # AttributeError: module 'torch.mps' has no attribute 'default_stream'
-    @skipIfMPS
-    @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
+    @expectedFailureMPS
+    @skipCUDANonDefaultStreamIf(True)
     @unittest.skipIf(
         torch.accelerator.device_count() < 2, "accelerator count is less than 2"
    )
-    def test_consumer_to_single_producer_case_3_correctness(self):
+    def test_consumer_to_single_producer_case_3_correctness(self, device):
+        if device == "cpu":
+            self.skipTest("requires accelerator")
         self._test_consumer_to_single_producer_case_3_correctness(
             non_default_ambient_stream=False
         )
 
     # AttributeError: module 'torch.mps' has no attribute 'default_stream'
-    @skipIfMPS
-    @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
+    @expectedFailureMPS
+    @skipCUDANonDefaultStreamIf(True)
     @unittest.skipIf(
         torch.accelerator.device_count() < 2, "accelerator count is less than 2"
     )
-    def test_consumer_to_single_producer_case_4_correctness(self):
+    def test_consumer_to_single_producer_case_4_correctness(self, device):
+        if device == "cpu":
+            self.skipTest("requires accelerator")
+
         # Device  Stream
         # Consumer: cuda:0  cuda:0 default
         # Producer: cuda:1  s1
@@ -13516,12 +13528,15 @@ class TestAutogradStreamSynchronization(TestCase):
             test()
 
     # AttributeError: module 'torch.mps' has no attribute 'default_stream'
-    @skipIfMPS
-    @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
+    @expectedFailureMPS
+    @skipCUDANonDefaultStreamIf(True)
     @unittest.skipIf(
         torch.accelerator.device_count() < 2, "accelerator count is less than 2"
     )
-    def test_consumer_to_multi_producer_case_4_correctness(self):
+    def test_consumer_to_multi_producer_case_4_correctness(self, device):
+        if device == "cpu":
+            self.skipTest("requires accelerator")
+
         # Device  Stream
         # Consumer : cuda:0  cuda:0 default
         #
@@ -13603,12 +13618,11 @@ class TestAutogradStreamSynchronization(TestCase):
         for _ in range(2):
             test()
 
-    # AttributeError: module 'torch.mps' has no attribute 'default_stream'
-    @skipIfMPS
     # This test may spuriously fail on non-cuda accelerators (since we won't
     # be calling sleep)
-    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
-    def test_side_stream_backward_overlap(self):
+    @onlyCUDA
+    @skipCUDANonDefaultStreamIf(True)
+    def test_side_stream_backward_overlap(self, device):
         # In case 2/3, we would designate the consumer as the accumulation
         # stream and naively, one might have the consumer wait for the producer
         # as soon as we've added to the InputBuffer the first time.
@@ -13709,6 +13723,54 @@ class TestAutogradStreamSynchronization(TestCase):
         populate_events()
         check_ordering()
 
+    @expectedFailureMPS
+    def test_warn_on_accumulate_grad_stream_mismatch_flag(self, device):
+        if device == "cpu":
+            self.skipTest("requires accelerator")
+
+        def do_test(suppress_warn, keep_grad_acc):
+            def _test():
+                with warnings.catch_warnings(record=True) as warns:
+                    warnings.simplefilter("always")
+
+                    with torch.Stream(0) as s0:
+                        a = torch.ones(8, 8, device=device, requires_grad=True)
+                        if keep_grad_acc:
+                            # create grad_acc under s1 and keep alive with b
+                            b = a.clone()
+
+                    with torch.Stream(0) as s1:
+                        s1.wait_stream(s0)
+                        c = a.sum()
+
+                    c.backward()
+
+                    filter_str = "set_warn_on_accumulate_grad_stream_mismatch"
+                    return sum([filter_str in str(w.message) for w in warns]) > 0
+
+            if suppress_warn:
+                try:
+                    torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(
+                        False
+                    )
+                    actual_warn = _test()
+                finally:
+                    torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(
+                        True
+                    )
+            else:
+                actual_warn = _test()
+
+            expect_warn = not suppress_warn and keep_grad_acc
+            self.assertEqual(actual_warn, expect_warn)
+
+        # Warn by default
+        self.assertTrue(torch._C._warn_on_accumulate_grad_stream_mismatch())
+
+        for suppress_warn in (True, False):
+            for keep_grad_acc in (True, False):
+                do_test(suppress_warn=suppress_warn, keep_grad_acc=keep_grad_acc)
+
 
 class TestMultithreadAutograd(TestCase):
     def _run_py_multithread_fn(
@@ -15196,6 +15258,9 @@ instantiate_device_type_tests(TestAutogradDeviceType, globals(), except_for=None
 instantiate_device_type_tests(
     TestAutogradMultipleDispatch, globals(), only_for=("cpu", "cuda")
 )
+instantiate_device_type_tests(
+    TestAutogradStreamSynchronization, globals(), except_for=None
+)
 
 instantiate_parametrized_tests(TestAutograd)
 instantiate_parametrized_tests(TestNestedCheckpoint)
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 10178c9fbf4..4acffdb1997 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1308,6 +1308,7 @@ def _group_tensors_by_device_and_dtype(
     tuple[list[list[Tensor | None]], list[_int]],
 ]: ...
 def _initCrashHandler() -> None: ...
+def _set_warn_on_accumulate_grad_stream_mismatch(enabled: _bool) -> None: ...
 
 # NB: There is no Capsule type in typing, see
 # https://github.com/python/cpython/issues/109562
diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py
index 2ade6485fff..f7c7150aa7e 100644
--- a/torch/autograd/graph.py
+++ b/torch/autograd/graph.py
@@ -44,6 +44,7 @@ __all__ = [
     "GradientEdge",
     "get_gradient_edge",
     "increment_version",
+    "set_warn_on_accumulate_grad_stream_mismatch",
 ]
 
 
@@ -438,6 +439,13 @@ def disable_saved_tensors_hooks(error_message: str) -> Generator[None, None, Non
         torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)
 
 
+def set_warn_on_accumulate_grad_stream_mismatch(enabled: bool) -> None:
+    """Whether to warn when the AccumulateGrad node's stream does not match the stream
+    of the node that produced the incoming gradient.
+    """
+    return torch._C._set_warn_on_accumulate_grad_stream_mismatch(enabled)
+
+
 class _MultiHandle(RemovableHandle):
     handles: tuple[RemovableHandle, ...]
 
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index 8c5f8e59183..ad37abe3b56 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -1605,6 +1605,32 @@ static PyObject* THPModule_are_vmap_fallback_warnings_enabled(
   END_HANDLE_TH_ERRORS
 }
 
+static PyObject* THPModule_set_warn_on_accumulate_grad_stream_mismatch(
+    PyObject* _unused,
+    PyObject* arg) {
+  HANDLE_TH_ERRORS
+  TORCH_CHECK(
+      PyBool_Check(arg),
+      "enabled must be a bool, "
+      "but got ",
+      THPUtils_typename(arg));
+  at::globalContext().setWarnOnAccumulateGradStreamMismatch(arg == Py_True);
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPModule_warn_on_accumulate_grad_stream_mismatch(
+    PyObject* _unused,
+    PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  if (at::globalContext().warnOnAccumulateGradStreamMismatch()) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+  END_HANDLE_TH_ERRORS
+}
+
 static PyObject* THCPModule_ensureCUDADeviceGuardSet(
     PyObject* self,
     PyObject* noargs) {
@@ -1822,6 +1848,14 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
      THPModule_are_vmap_fallback_warnings_enabled,
      METH_NOARGS,
      nullptr},
+    {"_set_warn_on_accumulate_grad_stream_mismatch",
+     THPModule_set_warn_on_accumulate_grad_stream_mismatch,
+     METH_O,
+     nullptr},
+    {"_warn_on_accumulate_grad_stream_mismatch",
+     THPModule_warn_on_accumulate_grad_stream_mismatch,
+     METH_NOARGS,
+     nullptr},
    {"_to_dlpack",
      castPyCFunctionWithKeywords(THPModule_toDLPack),
      METH_VARARGS | METH_KEYWORDS,
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp
index f92af4994fd..0b70aae489e 100644
--- a/torch/csrc/autograd/engine.cpp
+++ b/torch/csrc/autograd/engine.cpp
@@ -1199,7 +1199,11 @@ void Engine::evaluate_function(
       // Accumulates into buffer
       auto opt_next_stream = next.function->stream();
       input_buffer.add(
-          next.input_nr, std::move(output), opt_parent_stream, opt_next_stream);
+          next.input_nr,
+          std::move(output),
+          opt_parent_stream,
+          opt_next_stream,
+          next.function.get());
 
       if (is_ready) {
         auto queue = ready_queue(cpu_ready_queue, next.function->device());
@@ -1215,7 +1219,11 @@ void Engine::evaluate_function(
       // Accumulates into buffer
       auto opt_next_stream = next.function->stream();
       input_buffer.add(
-          next.input_nr, std::move(output), opt_parent_stream, opt_next_stream);
+          next.input_nr,
+          std::move(output),
+          opt_parent_stream,
+          opt_next_stream,
+          next.function.get());
       if (is_ready) {
         auto queue = ready_queue(cpu_ready_queue, next.function->device());
         queue->push(
@@ -1368,7 +1376,8 @@ auto Engine::execute(
       root_edges.at(0).input_nr,
       std::move(input),
      input_stream,
-      opt_next_stream);
+      opt_next_stream,
+      root_edges.at(0).function.get());
 
   execute_with_graph_task(
       graph_task, std::move(graph_root), std::move(input_buffer));
diff --git a/torch/csrc/autograd/input_buffer.cpp b/torch/csrc/autograd/input_buffer.cpp
index 63ca5daedd2..62770ef9465 100644
--- a/torch/csrc/autograd/input_buffer.cpp
+++ b/torch/csrc/autograd/input_buffer.cpp
@@ -1,3 +1,4 @@
+#include <ATen/Context.h>
 #include <torch/csrc/autograd/input_buffer.h>
 
 #include
@@ -11,6 +12,7 @@
 #include
 #include
 #include
+#include <torch/csrc/autograd/functions/accumulate_grad.h>
 #include
 #include
@@ -191,7 +193,8 @@ void InputBuffer::add(
     size_t pos,
     Variable&& var,
     const std::optional<c10::Stream>& opt_producer_stream_,
-    const std::optional<c10::Stream>& opt_consumer_stream_) {
+    const std::optional<c10::Stream>& opt_consumer_stream_,
+    Node* fn) {
   TORCH_INTERNAL_ASSERT(pos < buffer.size());
 
   if (!var.defined()) {
@@ -231,6 +234,21 @@ void InputBuffer::add(
 
   TORCH_INTERNAL_ASSERT(opt_consumer_stream && opt_producer_stream);
 
+  if (*opt_consumer_stream != *opt_producer_stream &&
+      dynamic_cast<AccumulateGrad*>(fn) &&
+      at::globalContext().warnOnAccumulateGradStreamMismatch()) {
+    TORCH_WARN_ONCE(
+        "The AccumulateGrad node's stream does not match the stream of the node that produced "
+        "the incoming gradient. This may incur unnecessary synchronization and break CUDA graph "
+        "capture if the AccumulateGrad node's stream is the default stream. This mismatch is "
+        "caused by an AccumulateGrad node created prior to the current iteration being kept alive. "
+        "This can happen if the autograd graph is still being kept alive by tensors such as the "
+        "loss, or if you are using DDP, which will stash a reference to the node. To resolve the "
+        "mismatch, delete all references to the autograd graph or ensure that DDP initialization is "
+        "performed under the same stream as subsequent forwards. If the mismatch is intentional, "
+        "you can use torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False) to suppress this "
+        "warning.");
+  }
   // See Note: [Autograd Producer-Consumer Stream Syncs]
   if (!opt_accum_streams[pos].has_value()) {
     // [ First producer ]
diff --git a/torch/csrc/autograd/input_buffer.h b/torch/csrc/autograd/input_buffer.h
index 89abd91f491..791710d2952 100644
--- a/torch/csrc/autograd/input_buffer.h
+++ b/torch/csrc/autograd/input_buffer.h
@@ -32,7 +32,8 @@ struct InputBuffer {
       size_t pos,
       Variable&& var,
       const std::optional<c10::Stream>& opt_producer_stream,
-      const std::optional<c10::Stream>& opt_consumer_stream);
+      const std::optional<c10::Stream>& opt_consumer_stream,
+      Node* fn);
 
   Variable operator[](size_t pos) {
     return buffer[pos];
diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.cpp b/torch/csrc/distributed/autograd/engine/dist_engine.cpp
index 3743476c7a5..156c9efd5ca 100644
--- a/torch/csrc/distributed/autograd/engine/dist_engine.cpp
+++ b/torch/csrc/distributed/autograd/engine/dist_engine.cpp
@@ -98,7 +98,12 @@ void DistEngine::globalCpuThread(
                      InputBuffer::variables(std::move(task.inputs_))]() mutable {
       InputBuffer inputs(variables.size());
      for (const auto i : c10::irange(variables.size())) {
-        inputs.add(i, std::move(variables[i]), std::nullopt, std::nullopt);
+        inputs.add(
+            i,
+            std::move(variables[i]),
+            std::nullopt,
+            std::nullopt,
+            graphRoot.get());
       }
       execute_graph_task_until_ready_queue_empty(
          /*node_task*/ NodeTask(graphTask, graphRoot, std::move(inputs)),
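
A minimal usage sketch of the new flag (illustrative only; it exercises the Python toggle added in torch/autograd/graph.py and the query binding added in torch/csrc/Module.cpp, without the multi-stream setup that would actually trigger the warning):

    import torch

    # The warning is enabled by default.
    assert torch._C._warn_on_accumulate_grad_stream_mismatch()

    x = torch.ones(4, requires_grad=True)
    loss = x.sum()

    # If an AccumulateGrad/producer stream mismatch is intentional, the warning
    # can be suppressed around the backward call and restored afterwards.
    torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False)
    try:
        loss.backward()
    finally:
        torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(True)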