[reland] Warn if AccumulateGrad stream does not match producer node stream (#166136)
Some checks failed
quantization-periodic / get-default-label-prefix (push) Has been cancelled
quantization-periodic / periodic-quantization-build (push) Has been cancelled
quantization-periodic / periodic-test-quantization (push) Has been cancelled
weekly / update-commit-hash (push) Has been cancelled
weekly / update-slow-tests (push) Has been cancelled
docker-builds / get-label-type (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-aarch64-py3.10-gcc11, linux.arm64.m7g.4xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks, linux.arm64.m7g.4xlarge, 600) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-clang12, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-linter, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-linter, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3-clang12-executorch, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3-clang12-onnx, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3-clang18-asan, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3-gcc11-inductor-benchmarks, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3.10-clang12, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3.10-gcc11, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3.12-halide, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3.12-triton-cpu, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3.13-clang12, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3.14-clang12, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-rocm-n-py3, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-rocm-n-py3-benchmarks, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-xpu-n-1-py3, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-xpu-n-py3, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-noble-riscv64-py3.12-gcc14, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-noble-rocm-n-py3, linux.12xlarge) (push) Has been cancelled
ossf-scorecard / Scorecards analysis (push) Has been cancelled
Close nonexistent disable issues / close-nonexistent-disable-issues (push) Has been cancelled
Index PyTorch Tests for Target Determination / get-label-type (push) Has been cancelled
nightly / get-label-type (push) Has been cancelled
nightly / update-commit-hashes (main, .ci/docker/ci_commit_pins, triton, triton-lang) (push) Has been cancelled
nightly / update-commit-hashes (main, .github/ci_commit_pins, audio, pytorch) (push) Has been cancelled
nightly / update-commit-hashes (main, .github/ci_commit_pins, vision, pytorch) (push) Has been cancelled
nightly / update-commit-hashes (main, .github/ci_commit_pins, vllm, vllm-project) (push) Has been cancelled
Index PyTorch Tests for Target Determination / index (push) Has been cancelled
nightly / Link checks (push) Has been cancelled
nightly / docs build (push) Has been cancelled
nightly / docs push (push) Has been cancelled

ghstack-source-id: 59641aa32dc6fd027abf3276017432b693aa71f8
Pull-Request-resolved: https://github.com/pytorch/pytorch/pull/165065

Fixes #ISSUE_NUMBER

Opening a new PR for codev

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166136
Approved by: https://github.com/ngimel
This commit is contained in:
soulitzer 2025-11-01 12:33:48 +00:00 committed by PyTorch MergeBot
parent 4cc64d6234
commit b3861ac8e7
11 changed files with 181 additions and 26 deletions

View File

@ -825,6 +825,14 @@ void Context::setDisplayVmapFallbackWarnings(bool enabled) {
display_vmap_fallback_warnings_ = enabled; display_vmap_fallback_warnings_ = enabled;
} }
bool Context::warnOnAccumulateGradStreamMismatch() const {
return warn_on_accumulate_grad_stream_mismatch_;
}
void Context::setWarnOnAccumulateGradStreamMismatch(bool enabled) {
warn_on_accumulate_grad_stream_mismatch_ = enabled;
}
bool Context::isDefaultMobileCPUAllocatorSet() { bool Context::isDefaultMobileCPUAllocatorSet() {
return prev_allocator_ptr_ != nullptr; return prev_allocator_ptr_ != nullptr;
} }

View File

@ -404,6 +404,9 @@ class TORCH_API Context {
void setDisplayVmapFallbackWarnings(bool enabled); void setDisplayVmapFallbackWarnings(bool enabled);
bool areVmapFallbackWarningsEnabled() const; bool areVmapFallbackWarningsEnabled() const;
void setWarnOnAccumulateGradStreamMismatch(bool enabled);
bool warnOnAccumulateGradStreamMismatch() const;
bool isDefaultMobileCPUAllocatorSet(); bool isDefaultMobileCPUAllocatorSet();
void setDefaultMobileCPUAllocator(); void setDefaultMobileCPUAllocator();
void unsetDefaultMobileCPUAllocator(); void unsetDefaultMobileCPUAllocator();
@ -494,6 +497,7 @@ class TORCH_API Context {
bool release_original_weights = false; bool release_original_weights = false;
#endif #endif
bool display_vmap_fallback_warnings_ = false; bool display_vmap_fallback_warnings_ = false;
bool warn_on_accumulate_grad_stream_mismatch_ = true;
std::atomic<at::QEngine> quantized_engine = at::QEngine::NoQEngine; std::atomic<at::QEngine> quantized_engine = at::QEngine::NoQEngine;
bool enable_sparse_tensor_invariant_checks = false; bool enable_sparse_tensor_invariant_checks = false;
bool allow_fp16_reduction_cpu = false; bool allow_fp16_reduction_cpu = false;

View File

@ -423,8 +423,10 @@ Also see {ref}`saved-tensors-hooks-doc`.
```{eval-rst} ```{eval-rst}
.. autofunction:: torch.autograd.graph.get_gradient_edge .. autofunction:: torch.autograd.graph.get_gradient_edge
```
```{eval-rst}
.. autofunction:: torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch
``` ```
% This module needs to be documented. Adding here in the meantime % This module needs to be documented. Adding here in the meantime

View File

@ -54,6 +54,7 @@ from torch.testing._internal.common_device_type import (
dtypes, dtypes,
dtypesIfCUDA, dtypesIfCUDA,
dtypesIfMPS, dtypesIfMPS,
expectedFailureMPS,
instantiate_device_type_tests, instantiate_device_type_tests,
onlyCPU, onlyCPU,
onlyCUDA, onlyCUDA,
@ -72,6 +73,7 @@ from torch.testing._internal.common_utils import (
run_tests, run_tests,
scoped_load_inline, scoped_load_inline,
set_warn_always_context, set_warn_always_context,
skipCUDANonDefaultStreamIf,
skipIfMPS, skipIfMPS,
skipIfNoLapack, skipIfNoLapack,
skipIfTorchDynamo, skipIfTorchDynamo,
@ -13325,9 +13327,12 @@ class TestAutogradStreamSynchronization(TestCase):
) )
# AttributeError: module 'torch.mps' has no attribute 'default_stream' # AttributeError: module 'torch.mps' has no attribute 'default_stream'
@skipIfMPS @expectedFailureMPS
@unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator") @skipCUDANonDefaultStreamIf(True)
def test_consumer_to_single_producer_case_2_correctness(self): def test_consumer_to_single_producer_case_2_correctness(self, device):
if device == "cpu":
self.skipTest("requires accelerator")
# Device Stream # Device Stream
# Consumer (MulBackward): cuda:0 s0 # Consumer (MulBackward): cuda:0 s0
# Producer : cuda:0 s1 # Producer : cuda:0 s1
@ -13430,36 +13435,43 @@ class TestAutogradStreamSynchronization(TestCase):
test() test()
# AttributeError: module 'torch.mps' has no attribute 'default_stream' # AttributeError: module 'torch.mps' has no attribute 'default_stream'
@skipIfMPS @expectedFailureMPS
@unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator") @skipCUDANonDefaultStreamIf(True)
@unittest.skipIf( @unittest.skipIf(
torch.accelerator.device_count() < 2, "accelerator count is less than 2" torch.accelerator.device_count() < 2, "accelerator count is less than 2"
) )
def test_consumer_to_single_producer_case_3_correctness_non_default_ambient_stream( def test_consumer_to_single_producer_case_3_correctness_non_default_ambient_stream(
self, self, device
): ):
if device == "cpu":
self.skipTest("requires accelerator")
self._test_consumer_to_single_producer_case_3_correctness( self._test_consumer_to_single_producer_case_3_correctness(
non_default_ambient_stream=True non_default_ambient_stream=True
) )
# AttributeError: module 'torch.mps' has no attribute 'default_stream' # AttributeError: module 'torch.mps' has no attribute 'default_stream'
@skipIfMPS @expectedFailureMPS
@unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator") @skipCUDANonDefaultStreamIf(True)
@unittest.skipIf( @unittest.skipIf(
torch.accelerator.device_count() < 2, "accelerator count is less than 2" torch.accelerator.device_count() < 2, "accelerator count is less than 2"
) )
def test_consumer_to_single_producer_case_3_correctness(self): def test_consumer_to_single_producer_case_3_correctness(self, device):
if device == "cpu":
self.skipTest("requires accelerator")
self._test_consumer_to_single_producer_case_3_correctness( self._test_consumer_to_single_producer_case_3_correctness(
non_default_ambient_stream=False non_default_ambient_stream=False
) )
# AttributeError: module 'torch.mps' has no attribute 'default_stream' # AttributeError: module 'torch.mps' has no attribute 'default_stream'
@skipIfMPS @expectedFailureMPS
@unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator") @skipCUDANonDefaultStreamIf(True)
@unittest.skipIf( @unittest.skipIf(
torch.accelerator.device_count() < 2, "accelerator count is less than 2" torch.accelerator.device_count() < 2, "accelerator count is less than 2"
) )
def test_consumer_to_single_producer_case_4_correctness(self): def test_consumer_to_single_producer_case_4_correctness(self, device):
if device == "cpu":
self.skipTest("requires accelerator")
# Device Stream # Device Stream
# Consumer: cuda:0 cuda:0 default # Consumer: cuda:0 cuda:0 default
# Producer: cuda:1 s1 # Producer: cuda:1 s1
@ -13516,12 +13528,15 @@ class TestAutogradStreamSynchronization(TestCase):
test() test()
# AttributeError: module 'torch.mps' has no attribute 'default_stream' # AttributeError: module 'torch.mps' has no attribute 'default_stream'
@skipIfMPS @expectedFailureMPS
@unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator") @skipCUDANonDefaultStreamIf(True)
@unittest.skipIf( @unittest.skipIf(
torch.accelerator.device_count() < 2, "accelerator count is less than 2" torch.accelerator.device_count() < 2, "accelerator count is less than 2"
) )
def test_consumer_to_multi_producer_case_4_correctness(self): def test_consumer_to_multi_producer_case_4_correctness(self, device):
if device == "cpu":
self.skipTest("requires accelerator")
# Device Stream # Device Stream
# Consumer : cuda:0 cuda:0 default # Consumer : cuda:0 cuda:0 default
# #
@ -13603,12 +13618,11 @@ class TestAutogradStreamSynchronization(TestCase):
for _ in range(2): for _ in range(2):
test() test()
# AttributeError: module 'torch.mps' has no attribute 'default_stream'
@skipIfMPS
# This test may spuriously fail on non-cuda accelerators (since we won't # This test may spuriously fail on non-cuda accelerators (since we won't
# be calling sleep) # be calling sleep)
@unittest.skipIf(not TEST_CUDA, "requires CUDA") @onlyCUDA
def test_side_stream_backward_overlap(self): @skipCUDANonDefaultStreamIf(True)
def test_side_stream_backward_overlap(self, device):
# In case 2/3, we would designate the consumer as the accumulation # In case 2/3, we would designate the consumer as the accumulation
# stream and naively, one might have the consumer wait for the producer # stream and naively, one might have the consumer wait for the producer
# as soon as we've added to the InputBuffer the first time. # as soon as we've added to the InputBuffer the first time.
@ -13709,6 +13723,54 @@ class TestAutogradStreamSynchronization(TestCase):
populate_events() populate_events()
check_ordering() check_ordering()
@expectedFailureMPS
def test_warn_on_accumulate_grad_stream_mismatch_flag(self, device):
if device == "cpu":
self.skipTest("requires accelerator")
def do_test(suppress_warn, keep_grad_acc):
def _test():
with warnings.catch_warnings(record=True) as warns:
warnings.simplefilter("always")
with torch.Stream(0) as s0:
a = torch.ones(8, 8, device=device, requires_grad=True)
if keep_grad_acc:
# create grad_acc under s1 and keep alive with b
b = a.clone()
with torch.Stream(0) as s1:
s1.wait_stream(s0)
c = a.sum()
c.backward()
filter_str = "set_warn_on_accumulate_grad_stream_mismatch"
return sum([filter_str in str(w.message) for w in warns]) > 0
if suppress_warn:
try:
torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(
False
)
actual_warn = _test()
finally:
torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(
True
)
else:
actual_warn = _test()
expect_warn = not suppress_warn and keep_grad_acc
self.assertEqual(actual_warn, expect_warn)
# Warn by default
self.assertTrue(torch._C._warn_on_accumulate_grad_stream_mismatch())
for suppress_warn in (True, False):
for keep_grad_acc in (True, False):
do_test(suppress_warn=suppress_warn, keep_grad_acc=keep_grad_acc)
class TestMultithreadAutograd(TestCase): class TestMultithreadAutograd(TestCase):
def _run_py_multithread_fn( def _run_py_multithread_fn(
@ -15196,6 +15258,9 @@ instantiate_device_type_tests(TestAutogradDeviceType, globals(), except_for=None
instantiate_device_type_tests( instantiate_device_type_tests(
TestAutogradMultipleDispatch, globals(), only_for=("cpu", "cuda") TestAutogradMultipleDispatch, globals(), only_for=("cpu", "cuda")
) )
instantiate_device_type_tests(
TestAutogradStreamSynchronization, globals(), except_for=None
)
instantiate_parametrized_tests(TestAutograd) instantiate_parametrized_tests(TestAutograd)
instantiate_parametrized_tests(TestNestedCheckpoint) instantiate_parametrized_tests(TestNestedCheckpoint)

View File

@ -1308,6 +1308,7 @@ def _group_tensors_by_device_and_dtype(
tuple[list[list[Tensor | None]], list[_int]], tuple[list[list[Tensor | None]], list[_int]],
]: ... ]: ...
def _initCrashHandler() -> None: ... def _initCrashHandler() -> None: ...
def _set_warn_on_accumulate_grad_stream_mismatch(enabled: _bool) -> None: ...
# NB: There is no Capsule type in typing, see # NB: There is no Capsule type in typing, see
# https://github.com/python/cpython/issues/109562 # https://github.com/python/cpython/issues/109562

View File

@ -44,6 +44,7 @@ __all__ = [
"GradientEdge", "GradientEdge",
"get_gradient_edge", "get_gradient_edge",
"increment_version", "increment_version",
"set_warn_on_accumulate_grad_stream_mismatch",
] ]
@ -438,6 +439,13 @@ def disable_saved_tensors_hooks(error_message: str) -> Generator[None, None, Non
torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message) torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)
def set_warn_on_accumulate_grad_stream_mismatch(enabled: bool) -> None:
"""Whether to warn when the AccumulateGrad node's stream does not match the stream
of the node that produced the incoming gradient.
"""
return torch._C._set_warn_on_accumulate_grad_stream_mismatch(enabled)
class _MultiHandle(RemovableHandle): class _MultiHandle(RemovableHandle):
handles: tuple[RemovableHandle, ...] handles: tuple[RemovableHandle, ...]

View File

@ -1605,6 +1605,32 @@ static PyObject* THPModule_are_vmap_fallback_warnings_enabled(
END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS
} }
static PyObject* THPModule_set_warn_on_accumulate_grad_stream_mismatch(
PyObject* _unused,
PyObject* arg) {
HANDLE_TH_ERRORS
TORCH_CHECK(
PyBool_Check(arg),
"enabled must be a bool, "
"but got ",
THPUtils_typename(arg));
at::globalContext().setWarnOnAccumulateGradStreamMismatch(arg == Py_True);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject* THPModule_warn_on_accumulate_grad_stream_mismatch(
PyObject* _unused,
PyObject* noargs) {
HANDLE_TH_ERRORS
if (at::globalContext().warnOnAccumulateGradStreamMismatch()) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
END_HANDLE_TH_ERRORS
}
static PyObject* THCPModule_ensureCUDADeviceGuardSet( static PyObject* THCPModule_ensureCUDADeviceGuardSet(
PyObject* self, PyObject* self,
PyObject* noargs) { PyObject* noargs) {
@ -1822,6 +1848,14 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
THPModule_are_vmap_fallback_warnings_enabled, THPModule_are_vmap_fallback_warnings_enabled,
METH_NOARGS, METH_NOARGS,
nullptr}, nullptr},
{"_set_warn_on_accumulate_grad_stream_mismatch",
THPModule_set_warn_on_accumulate_grad_stream_mismatch,
METH_O,
nullptr},
{"_warn_on_accumulate_grad_stream_mismatch",
THPModule_warn_on_accumulate_grad_stream_mismatch,
METH_NOARGS,
nullptr},
{"_to_dlpack", {"_to_dlpack",
castPyCFunctionWithKeywords(THPModule_toDLPack), castPyCFunctionWithKeywords(THPModule_toDLPack),
METH_VARARGS | METH_KEYWORDS, METH_VARARGS | METH_KEYWORDS,

View File

@ -1199,7 +1199,11 @@ void Engine::evaluate_function(
// Accumulates into buffer // Accumulates into buffer
auto opt_next_stream = next.function->stream(); auto opt_next_stream = next.function->stream();
input_buffer.add( input_buffer.add(
next.input_nr, std::move(output), opt_parent_stream, opt_next_stream); next.input_nr,
std::move(output),
opt_parent_stream,
opt_next_stream,
next.function.get());
if (is_ready) { if (is_ready) {
auto queue = ready_queue(cpu_ready_queue, next.function->device()); auto queue = ready_queue(cpu_ready_queue, next.function->device());
@ -1215,7 +1219,11 @@ void Engine::evaluate_function(
// Accumulates into buffer // Accumulates into buffer
auto opt_next_stream = next.function->stream(); auto opt_next_stream = next.function->stream();
input_buffer.add( input_buffer.add(
next.input_nr, std::move(output), opt_parent_stream, opt_next_stream); next.input_nr,
std::move(output),
opt_parent_stream,
opt_next_stream,
next.function.get());
if (is_ready) { if (is_ready) {
auto queue = ready_queue(cpu_ready_queue, next.function->device()); auto queue = ready_queue(cpu_ready_queue, next.function->device());
queue->push( queue->push(
@ -1368,7 +1376,8 @@ auto Engine::execute(
root_edges.at(0).input_nr, root_edges.at(0).input_nr,
std::move(input), std::move(input),
input_stream, input_stream,
opt_next_stream); opt_next_stream,
root_edges.at(0).function.get());
execute_with_graph_task( execute_with_graph_task(
graph_task, std::move(graph_root), std::move(input_buffer)); graph_task, std::move(graph_root), std::move(input_buffer));

View File

@ -1,3 +1,4 @@
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/input_buffer.h> #include <torch/csrc/autograd/input_buffer.h>
#include <ATen/CachedTensorUtils.h> #include <ATen/CachedTensorUtils.h>
@ -11,6 +12,7 @@
#include <c10/core/DeviceGuard.h> #include <c10/core/DeviceGuard.h>
#include <c10/core/Event.h> #include <c10/core/Event.h>
#include <c10/core/StreamGuard.h> #include <c10/core/StreamGuard.h>
#include <c10/util/Logging.h>
#include <optional> #include <optional>
#include <cstddef> #include <cstddef>
@ -191,7 +193,8 @@ void InputBuffer::add(
size_t pos, size_t pos,
Variable&& var, Variable&& var,
const std::optional<c10::Stream>& opt_producer_stream_, const std::optional<c10::Stream>& opt_producer_stream_,
const std::optional<c10::Stream>& opt_consumer_stream_) { const std::optional<c10::Stream>& opt_consumer_stream_,
Node* fn) {
TORCH_INTERNAL_ASSERT(pos < buffer.size()); TORCH_INTERNAL_ASSERT(pos < buffer.size());
if (!var.defined()) { if (!var.defined()) {
@ -231,6 +234,21 @@ void InputBuffer::add(
TORCH_INTERNAL_ASSERT(opt_consumer_stream && opt_producer_stream); TORCH_INTERNAL_ASSERT(opt_consumer_stream && opt_producer_stream);
if (*opt_consumer_stream != *opt_producer_stream &&
dynamic_cast<AccumulateGrad*>(fn) &&
at::globalContext().warnOnAccumulateGradStreamMismatch()) {
TORCH_WARN_ONCE(
"The AccumulateGrad node's stream does not match the stream of the node that produced "
"the incoming gradient. This may incur unnecessary synchronization and break CUDA graph "
"capture if the AccumulateGrad node's stream is the default stream. This mismatch is "
"caused by an AccumulateGrad node created prior to the current iteration being kept alive. "
"This can happen if the autograd graph is still being kept alive by tensors such as the "
"loss, or if you are using DDP, which will stash a reference to the node. To resolve the "
"mismatch, delete all references to the autograd graph or ensure that DDP initialization is "
"performed under the same stream as subsequent forwards. If the mismatch is intentional, "
"you can use torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False) to suppress this "
"warning.");
}
// See Note: [Autograd Producer-Consumer Stream Syncs] // See Note: [Autograd Producer-Consumer Stream Syncs]
if (!opt_accum_streams[pos].has_value()) { if (!opt_accum_streams[pos].has_value()) {
// [ First producer ] // [ First producer ]

View File

@ -32,7 +32,8 @@ struct InputBuffer {
size_t pos, size_t pos,
Variable&& var, Variable&& var,
const std::optional<c10::Stream>& opt_producer_stream, const std::optional<c10::Stream>& opt_producer_stream,
const std::optional<c10::Stream>& opt_consumer_stream); const std::optional<c10::Stream>& opt_consumer_stream,
Node* fn);
Variable operator[](size_t pos) { Variable operator[](size_t pos) {
return buffer[pos]; return buffer[pos];

View File

@ -98,7 +98,12 @@ void DistEngine::globalCpuThread(
InputBuffer::variables(std::move(task.inputs_))]() mutable { InputBuffer::variables(std::move(task.inputs_))]() mutable {
InputBuffer inputs(variables.size()); InputBuffer inputs(variables.size());
for (const auto i : c10::irange(variables.size())) { for (const auto i : c10::irange(variables.size())) {
inputs.add(i, std::move(variables[i]), std::nullopt, std::nullopt); inputs.add(
i,
std::move(variables[i]),
std::nullopt,
std::nullopt,
graphRoot.get());
} }
execute_graph_task_until_ready_queue_empty( execute_graph_task_until_ready_queue_empty(
/*node_task*/ NodeTask(graphTask, graphRoot, std::move(inputs)), /*node_task*/ NodeTask(graphTask, graphRoot, std::move(inputs)),