mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 00:20:18 +01:00
[reland] Warn if AccumulateGrad stream does not match producer node stream (#166136)
Some checks failed
quantization-periodic / get-default-label-prefix (push) Has been cancelled
quantization-periodic / periodic-quantization-build (push) Has been cancelled
quantization-periodic / periodic-test-quantization (push) Has been cancelled
weekly / update-commit-hash (push) Has been cancelled
weekly / update-slow-tests (push) Has been cancelled
docker-builds / get-label-type (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-aarch64-py3.10-gcc11, linux.arm64.m7g.4xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks, linux.arm64.m7g.4xlarge, 600) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-clang12, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-linter, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-linter, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3-clang12-executorch, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3-clang12-onnx, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3-clang18-asan, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3-gcc11-inductor-benchmarks, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3.10-clang12, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3.10-gcc11, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3.12-halide, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3.12-triton-cpu, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3.13-clang12, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3.14-clang12, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-rocm-n-py3, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-rocm-n-py3-benchmarks, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-xpu-n-1-py3, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-xpu-n-py3, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-noble-riscv64-py3.12-gcc14, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-noble-rocm-n-py3, linux.12xlarge) (push) Has been cancelled
ossf-scorecard / Scorecards analysis (push) Has been cancelled
Close nonexistent disable issues / close-nonexistent-disable-issues (push) Has been cancelled
Index PyTorch Tests for Target Determination / get-label-type (push) Has been cancelled
nightly / get-label-type (push) Has been cancelled
nightly / update-commit-hashes (main, .ci/docker/ci_commit_pins, triton, triton-lang) (push) Has been cancelled
nightly / update-commit-hashes (main, .github/ci_commit_pins, audio, pytorch) (push) Has been cancelled
nightly / update-commit-hashes (main, .github/ci_commit_pins, vision, pytorch) (push) Has been cancelled
nightly / update-commit-hashes (main, .github/ci_commit_pins, vllm, vllm-project) (push) Has been cancelled
Index PyTorch Tests for Target Determination / index (push) Has been cancelled
nightly / Link checks (push) Has been cancelled
nightly / docs build (push) Has been cancelled
nightly / docs push (push) Has been cancelled
ghstack-source-id: 59641aa32dc6fd027abf3276017432b693aa71f8
Pull-Request-resolved: https://github.com/pytorch/pytorch/pull/165065

Fixes #ISSUE_NUMBER

Opening a new PR for codev

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166136
Approved by: https://github.com/ngimel
This commit is contained in:
parent 4cc64d6234
commit b3861ac8e7
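For context, a minimal sketch of the situation the new warning targets and of the global toggle it mentions, adapted from the test added in this PR (the explicit "cuda" device string, tensor shapes, and stream indices are illustrative assumptions, not part of the change):

import torch

device = "cuda"  # assumption: an accelerator backend is available

with torch.Stream(0) as s0:
    a = torch.ones(8, 8, device=device, requires_grad=True)
    # b keeps the graph, and therefore a's AccumulateGrad node recorded under s0, alive.
    b = a.clone()

with torch.Stream(0) as s1:
    s1.wait_stream(s0)
    c = a.sum()

# The gradient for a is produced under s1 but accumulated by a node recorded on s0,
# so this backward() emits the new warning (TORCH_WARN_ONCE, i.e. at most once per process).
c.backward()

# If the mismatch is intentional, the warning can be disabled (and re-enabled) globally:
torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False)
torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(True)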
@@ -825,6 +825,14 @@ void Context::setDisplayVmapFallbackWarnings(bool enabled) {
   display_vmap_fallback_warnings_ = enabled;
 }
 
+bool Context::warnOnAccumulateGradStreamMismatch() const {
+  return warn_on_accumulate_grad_stream_mismatch_;
+}
+
+void Context::setWarnOnAccumulateGradStreamMismatch(bool enabled) {
+  warn_on_accumulate_grad_stream_mismatch_ = enabled;
+}
+
 bool Context::isDefaultMobileCPUAllocatorSet() {
   return prev_allocator_ptr_ != nullptr;
 }

@@ -404,6 +404,9 @@ class TORCH_API Context {
   void setDisplayVmapFallbackWarnings(bool enabled);
   bool areVmapFallbackWarningsEnabled() const;
 
+  void setWarnOnAccumulateGradStreamMismatch(bool enabled);
+  bool warnOnAccumulateGradStreamMismatch() const;
+
   bool isDefaultMobileCPUAllocatorSet();
   void setDefaultMobileCPUAllocator();
   void unsetDefaultMobileCPUAllocator();

@@ -494,6 +497,7 @@ class TORCH_API Context {
   bool release_original_weights = false;
 #endif
   bool display_vmap_fallback_warnings_ = false;
+  bool warn_on_accumulate_grad_stream_mismatch_ = true;
   std::atomic<at::QEngine> quantized_engine = at::QEngine::NoQEngine;
   bool enable_sparse_tensor_invariant_checks = false;
   bool allow_fp16_reduction_cpu = false;

@@ -423,8 +423,10 @@ Also see {ref}`saved-tensors-hooks-doc`.
 
 ```{eval-rst}
 .. autofunction:: torch.autograd.graph.get_gradient_edge
+```
+
+```{eval-rst}
+.. autofunction:: torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch
 ```
 
 % This module needs to be documented. Adding here in the meantime

@@ -54,6 +54,7 @@ from torch.testing._internal.common_device_type import (
     dtypes,
     dtypesIfCUDA,
     dtypesIfMPS,
+    expectedFailureMPS,
     instantiate_device_type_tests,
     onlyCPU,
     onlyCUDA,

@@ -72,6 +73,7 @@ from torch.testing._internal.common_utils import (
     run_tests,
     scoped_load_inline,
     set_warn_always_context,
+    skipCUDANonDefaultStreamIf,
     skipIfMPS,
     skipIfNoLapack,
     skipIfTorchDynamo,

@@ -13325,9 +13327,12 @@ class TestAutogradStreamSynchronization(TestCase):
         )
 
     # AttributeError: module 'torch.mps' has no attribute 'default_stream'
-    @skipIfMPS
-    @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
-    def test_consumer_to_single_producer_case_2_correctness(self):
+    @expectedFailureMPS
+    @skipCUDANonDefaultStreamIf(True)
+    def test_consumer_to_single_producer_case_2_correctness(self, device):
+        if device == "cpu":
+            self.skipTest("requires accelerator")
+
         # Device Stream
         # Consumer (MulBackward): cuda:0 s0
         # Producer : cuda:0 s1

@@ -13430,36 +13435,43 @@ class TestAutogradStreamSynchronization(TestCase):
             test()
 
     # AttributeError: module 'torch.mps' has no attribute 'default_stream'
-    @skipIfMPS
-    @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
+    @expectedFailureMPS
+    @skipCUDANonDefaultStreamIf(True)
     @unittest.skipIf(
         torch.accelerator.device_count() < 2, "accelerator count is less than 2"
     )
     def test_consumer_to_single_producer_case_3_correctness_non_default_ambient_stream(
-        self,
+        self, device
     ):
+        if device == "cpu":
+            self.skipTest("requires accelerator")
         self._test_consumer_to_single_producer_case_3_correctness(
             non_default_ambient_stream=True
         )
 
     # AttributeError: module 'torch.mps' has no attribute 'default_stream'
-    @skipIfMPS
-    @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
+    @expectedFailureMPS
+    @skipCUDANonDefaultStreamIf(True)
     @unittest.skipIf(
         torch.accelerator.device_count() < 2, "accelerator count is less than 2"
     )
-    def test_consumer_to_single_producer_case_3_correctness(self):
+    def test_consumer_to_single_producer_case_3_correctness(self, device):
+        if device == "cpu":
+            self.skipTest("requires accelerator")
         self._test_consumer_to_single_producer_case_3_correctness(
             non_default_ambient_stream=False
         )
 
     # AttributeError: module 'torch.mps' has no attribute 'default_stream'
-    @skipIfMPS
-    @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
+    @expectedFailureMPS
+    @skipCUDANonDefaultStreamIf(True)
     @unittest.skipIf(
         torch.accelerator.device_count() < 2, "accelerator count is less than 2"
     )
-    def test_consumer_to_single_producer_case_4_correctness(self):
+    def test_consumer_to_single_producer_case_4_correctness(self, device):
+        if device == "cpu":
+            self.skipTest("requires accelerator")
+
         # Device Stream
         # Consumer: cuda:0 cuda:0 default
         # Producer: cuda:1 s1

@@ -13516,12 +13528,15 @@ class TestAutogradStreamSynchronization(TestCase):
             test()
 
     # AttributeError: module 'torch.mps' has no attribute 'default_stream'
-    @skipIfMPS
-    @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
+    @expectedFailureMPS
+    @skipCUDANonDefaultStreamIf(True)
     @unittest.skipIf(
         torch.accelerator.device_count() < 2, "accelerator count is less than 2"
     )
-    def test_consumer_to_multi_producer_case_4_correctness(self):
+    def test_consumer_to_multi_producer_case_4_correctness(self, device):
+        if device == "cpu":
+            self.skipTest("requires accelerator")
+
         # Device Stream
         # Consumer : cuda:0 cuda:0 default
         #

@@ -13603,12 +13618,11 @@ class TestAutogradStreamSynchronization(TestCase):
         for _ in range(2):
             test()
 
-    # AttributeError: module 'torch.mps' has no attribute 'default_stream'
-    @skipIfMPS
     # This test may spuriously fail on non-cuda accelerators (since we won't
     # be calling sleep)
-    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
-    def test_side_stream_backward_overlap(self):
+    @onlyCUDA
+    @skipCUDANonDefaultStreamIf(True)
+    def test_side_stream_backward_overlap(self, device):
         # In case 2/3, we would designate the consumer as the accumulation
         # stream and naively, one might have the consumer wait for the producer
         # as soon as we've added to the InputBuffer the first time.

@@ -13709,6 +13723,54 @@ class TestAutogradStreamSynchronization(TestCase):
         populate_events()
         check_ordering()
 
+    @expectedFailureMPS
+    def test_warn_on_accumulate_grad_stream_mismatch_flag(self, device):
+        if device == "cpu":
+            self.skipTest("requires accelerator")
+
+        def do_test(suppress_warn, keep_grad_acc):
+            def _test():
+                with warnings.catch_warnings(record=True) as warns:
+                    warnings.simplefilter("always")
+
+                    with torch.Stream(0) as s0:
+                        a = torch.ones(8, 8, device=device, requires_grad=True)
+                        if keep_grad_acc:
+                            # create grad_acc under s1 and keep alive with b
+                            b = a.clone()
+
+                    with torch.Stream(0) as s1:
+                        s1.wait_stream(s0)
+                        c = a.sum()
+
+                    c.backward()
+
+                filter_str = "set_warn_on_accumulate_grad_stream_mismatch"
+                return sum([filter_str in str(w.message) for w in warns]) > 0
+
+            if suppress_warn:
+                try:
+                    torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(
+                        False
+                    )
+                    actual_warn = _test()
+                finally:
+                    torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(
+                        True
+                    )
+            else:
+                actual_warn = _test()
+
+            expect_warn = not suppress_warn and keep_grad_acc
+            self.assertEqual(actual_warn, expect_warn)
+
+        # Warn by default
+        self.assertTrue(torch._C._warn_on_accumulate_grad_stream_mismatch())
+
+        for suppress_warn in (True, False):
+            for keep_grad_acc in (True, False):
+                do_test(suppress_warn=suppress_warn, keep_grad_acc=keep_grad_acc)
+
 
 class TestMultithreadAutograd(TestCase):
     def _run_py_multithread_fn(

@@ -15196,6 +15258,9 @@ instantiate_device_type_tests(TestAutogradDeviceType, globals(), except_for=None
 instantiate_device_type_tests(
     TestAutogradMultipleDispatch, globals(), only_for=("cpu", "cuda")
 )
+instantiate_device_type_tests(
+    TestAutogradStreamSynchronization, globals(), except_for=None
+)
 
 instantiate_parametrized_tests(TestAutograd)
 instantiate_parametrized_tests(TestNestedCheckpoint)

@@ -1308,6 +1308,7 @@ def _group_tensors_by_device_and_dtype(
     tuple[list[list[Tensor | None]], list[_int]],
 ]: ...
 def _initCrashHandler() -> None: ...
+def _set_warn_on_accumulate_grad_stream_mismatch(enabled: _bool) -> None: ...
 
 # NB: There is no Capsule type in typing, see
 # https://github.com/python/cpython/issues/109562

@@ -44,6 +44,7 @@ __all__ = [
     "GradientEdge",
     "get_gradient_edge",
     "increment_version",
+    "set_warn_on_accumulate_grad_stream_mismatch",
 ]

@@ -438,6 +439,13 @@ def disable_saved_tensors_hooks(error_message: str) -> Generator[None, None, Non
         torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)
 
 
+def set_warn_on_accumulate_grad_stream_mismatch(enabled: bool) -> None:
+    """Whether to warn when the AccumulateGrad node's stream does not match the stream
+    of the node that produced the incoming gradient.
+    """
+    return torch._C._set_warn_on_accumulate_grad_stream_mismatch(enabled)
+
+
 class _MultiHandle(RemovableHandle):
     handles: tuple[RemovableHandle, ...]

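Because the flag added above is a process-global setting, code that only wants to silence the warning around a specific backward will likely want to restore it afterwards. A sketch of that pattern, mirroring the try/finally used by the new test (the helper name is a hypothetical placeholder):

import torch

def backward_with_intentional_mismatch(loss):  # hypothetical helper name
    torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False)
    try:
        loss.backward()
    finally:
        # Restore the default so unrelated code still gets the warning.
        torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(True)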
@@ -1605,6 +1605,32 @@ static PyObject* THPModule_are_vmap_fallback_warnings_enabled(
   END_HANDLE_TH_ERRORS
 }
 
+static PyObject* THPModule_set_warn_on_accumulate_grad_stream_mismatch(
+    PyObject* _unused,
+    PyObject* arg) {
+  HANDLE_TH_ERRORS
+  TORCH_CHECK(
+      PyBool_Check(arg),
+      "enabled must be a bool, "
+      "but got ",
+      THPUtils_typename(arg));
+  at::globalContext().setWarnOnAccumulateGradStreamMismatch(arg == Py_True);
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPModule_warn_on_accumulate_grad_stream_mismatch(
+    PyObject* _unused,
+    PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  if (at::globalContext().warnOnAccumulateGradStreamMismatch()) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+  END_HANDLE_TH_ERRORS
+}
+
 static PyObject* THCPModule_ensureCUDADeviceGuardSet(
     PyObject* self,
     PyObject* noargs) {

@@ -1822,6 +1848,14 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
      THPModule_are_vmap_fallback_warnings_enabled,
      METH_NOARGS,
      nullptr},
+    {"_set_warn_on_accumulate_grad_stream_mismatch",
+     THPModule_set_warn_on_accumulate_grad_stream_mismatch,
+     METH_O,
+     nullptr},
+    {"_warn_on_accumulate_grad_stream_mismatch",
+     THPModule_warn_on_accumulate_grad_stream_mismatch,
+     METH_NOARGS,
+     nullptr},
     {"_to_dlpack",
      castPyCFunctionWithKeywords(THPModule_toDLPack),
      METH_VARARGS | METH_KEYWORDS,

@@ -1199,7 +1199,11 @@ void Engine::evaluate_function(
       // Accumulates into buffer
       auto opt_next_stream = next.function->stream();
       input_buffer.add(
-          next.input_nr, std::move(output), opt_parent_stream, opt_next_stream);
+          next.input_nr,
+          std::move(output),
+          opt_parent_stream,
+          opt_next_stream,
+          next.function.get());
 
       if (is_ready) {
         auto queue = ready_queue(cpu_ready_queue, next.function->device());

@@ -1215,7 +1219,11 @@ void Engine::evaluate_function(
       // Accumulates into buffer
       auto opt_next_stream = next.function->stream();
       input_buffer.add(
-          next.input_nr, std::move(output), opt_parent_stream, opt_next_stream);
+          next.input_nr,
+          std::move(output),
+          opt_parent_stream,
+          opt_next_stream,
+          next.function.get());
       if (is_ready) {
         auto queue = ready_queue(cpu_ready_queue, next.function->device());
         queue->push(

@@ -1368,7 +1376,8 @@ auto Engine::execute(
       root_edges.at(0).input_nr,
       std::move(input),
       input_stream,
-      opt_next_stream);
+      opt_next_stream,
+      root_edges.at(0).function.get());
 
   execute_with_graph_task(
       graph_task, std::move(graph_root), std::move(input_buffer));

@@ -1,3 +1,4 @@
+#include <torch/csrc/autograd/functions/accumulate_grad.h>
 #include <torch/csrc/autograd/input_buffer.h>
 
 #include <ATen/CachedTensorUtils.h>

@@ -11,6 +12,7 @@
 #include <c10/core/DeviceGuard.h>
 #include <c10/core/Event.h>
 #include <c10/core/StreamGuard.h>
+#include <c10/util/Logging.h>
 #include <optional>
 
 #include <cstddef>

@@ -191,7 +193,8 @@ void InputBuffer::add(
     size_t pos,
     Variable&& var,
     const std::optional<c10::Stream>& opt_producer_stream_,
-    const std::optional<c10::Stream>& opt_consumer_stream_) {
+    const std::optional<c10::Stream>& opt_consumer_stream_,
+    Node* fn) {
   TORCH_INTERNAL_ASSERT(pos < buffer.size());
 
   if (!var.defined()) {

@@ -231,6 +234,21 @@
 
   TORCH_INTERNAL_ASSERT(opt_consumer_stream && opt_producer_stream);
 
+  if (*opt_consumer_stream != *opt_producer_stream &&
+      dynamic_cast<AccumulateGrad*>(fn) &&
+      at::globalContext().warnOnAccumulateGradStreamMismatch()) {
+    TORCH_WARN_ONCE(
+        "The AccumulateGrad node's stream does not match the stream of the node that produced "
+        "the incoming gradient. This may incur unnecessary synchronization and break CUDA graph "
+        "capture if the AccumulateGrad node's stream is the default stream. This mismatch is "
+        "caused by an AccumulateGrad node created prior to the current iteration being kept alive. "
+        "This can happen if the autograd graph is still being kept alive by tensors such as the "
+        "loss, or if you are using DDP, which will stash a reference to the node. To resolve the "
+        "mismatch, delete all references to the autograd graph or ensure that DDP initialization is "
+        "performed under the same stream as subsequent forwards. If the mismatch is intentional, "
+        "you can use torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False) to suppress this "
+        "warning.");
+  }
   // See Note: [Autograd Producer-Consumer Stream Syncs]
   if (!opt_accum_streams[pos].has_value()) {
     // [ First producer ]

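A hedged sketch of the remediation the warning text above suggests for the non-DDP case: do not let a previous iteration's graph (and the AccumulateGrad node it pins) outlive a change of stream. The loop, shapes, and stream usage below are illustrative assumptions only:

import torch

device = "cuda"  # assumption: an accelerator backend is available
w = torch.ones(8, 8, device=device, requires_grad=True)

for step in range(2):
    # Each iteration may run under a different stream than the previous one.
    with torch.Stream(0) as s:
        loss = (w * 2).sum()
        loss.backward()
    # Releasing the graph here lets the next iteration rebuild w's AccumulateGrad
    # node under whatever stream it actually uses, avoiding the mismatch warning.
    del loss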
@@ -32,7 +32,8 @@ struct InputBuffer {
       size_t pos,
       Variable&& var,
       const std::optional<c10::Stream>& opt_producer_stream,
-      const std::optional<c10::Stream>& opt_consumer_stream);
+      const std::optional<c10::Stream>& opt_consumer_stream,
+      Node* fn);
 
   Variable operator[](size_t pos) {
     return buffer[pos];

@@ -98,7 +98,12 @@ void DistEngine::globalCpuThread(
           InputBuffer::variables(std::move(task.inputs_))]() mutable {
           InputBuffer inputs(variables.size());
           for (const auto i : c10::irange(variables.size())) {
-            inputs.add(i, std::move(variables[i]), std::nullopt, std::nullopt);
+            inputs.add(
+                i,
+                std::move(variables[i]),
+                std::nullopt,
+                std::nullopt,
+                graphRoot.get());
           }
           execute_graph_task_until_ready_queue_empty(
               /*node_task*/ NodeTask(graphTask, graphRoot, std::move(inputs)),