Add Context Manager for Disabling Multithreading in Backwards, use in aot autograd (#86245)
We were running into a few issues with multithreaded backwards in aot_autograd, such as https://github.com/pytorch/pytorch/issues/86136, and with `FakeTensorMode` getting into a weird state because functions were not executed completely sequentially. The multithreaded backwards is lost in translation when we trace out the backwards anyway, and it adds a lot of additional complexity.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86245
Approved by: https://github.com/albanD, https://github.com/yf225
This commit is contained in:
parent
237316aa1d
commit
d04889323e
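In user-facing terms, the change adds a `torch.autograd.set_multithreading_enabled` context manager that forces the autograd engine to run a backward pass on the calling thread instead of fanning work out to per-device worker threads. A minimal usage sketch (the model and shapes are illustrative, not taken from the PR):

import torch

model = torch.nn.Linear(4, 2)               # illustrative model
inp = torch.rand(8, 4, requires_grad=True)

# Inside the context manager, the engine keeps all graph tasks on the
# calling thread, which is what aot_autograd relies on while tracing.
with torch.autograd.set_multithreading_enabled(False):
    model(inp).sum().backward()

# Outside it, multithreaded backwards is enabled again.
model(inp).sum().backward()
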
@@ -27,6 +27,10 @@ void ThreadLocalState::set_grad_mode(bool enabled) {
   autograd_tls_.set_grad_mode(enabled);
 }

+void ThreadLocalState::set_multithreading_enabled(bool enabled) {
+  autograd_tls_.set_multithreading_enabled(enabled);
+}
+
 /* static */
 void ThreadLocalState::setThreadLocalState(
     const ThreadLocalState& state) {

@@ -30,6 +30,12 @@ class TORCH_API ThreadLocalState {
   // autograd engine.
   void set_grad_mode(bool enabled);

+  // set_multithreading_enabled - force the value of the multithreading
+  // enabled TLS in
+  // the current state object. This is used for example in the
+  // autograd engine.
+  void set_multithreading_enabled(bool enabled);
+
   // Sets thread local variables in the current thread,
   // according to the thread boundary specified
   static void setThreadLocalState(const ThreadLocalState& state);

@@ -3,11 +3,13 @@
 namespace c10 {

 namespace {
-// By default, grad mode is enabled and inference mode is disabled
+// By default, grad mode and multithreading are enabled, inference mode is
+// disabled
 thread_local AutogradState autograd_state_tls = AutogradState(
     /* grad_mode */ true,
     /* inference_mode */ false,
-    /* fw_grad_mode */ true);
+    /* fw_grad_mode */ true,
+    /* multithreading_enabled */ true);
 } // namespace

 AutogradState& AutogradState::get_tls_state() {

@@ -12,10 +12,15 @@ struct C10_API AutogradState {
   static AutogradState& get_tls_state();
   static void set_tls_state(AutogradState state);

-  AutogradState(bool grad_mode, bool inference_mode, bool fw_grad_mode)
+  AutogradState(
+      bool grad_mode,
+      bool inference_mode,
+      bool fw_grad_mode,
+      bool multithreading_enabled)
       : grad_mode_(grad_mode),
         inference_mode_(inference_mode),
-        fw_grad_mode_(fw_grad_mode) {}
+        fw_grad_mode_(fw_grad_mode),
+        multithreading_enabled_(multithreading_enabled) {}

   void set_grad_mode(bool enabled) {
     grad_mode_ = enabled;

@@ -29,6 +34,10 @@ struct C10_API AutogradState {
     inference_mode_ = enabled;
   }

+  void set_multithreading_enabled(bool multithreading_enabled) {
+    multithreading_enabled_ = multithreading_enabled;
+  }
+
   bool get_grad_mode() const {
     return grad_mode_;
   }

@@ -41,10 +50,15 @@ struct C10_API AutogradState {
     return inference_mode_;
   }

+  bool get_multithreading_enabled() const {
+    return multithreading_enabled_;
+  }
+
  private:
   bool grad_mode_ : 1;
   bool inference_mode_ : 1;
   bool fw_grad_mode_ : 1;
+  bool multithreading_enabled_ : 1;
 };

 } // namespace c10

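Since `autograd_state_tls` is `thread_local`, flipping `multithreading_enabled_` in one thread leaves every other thread untouched, which is why the Python docstring further down can promise thread locality. A rough Python analogue of this storage pattern (all names illustrative, not PyTorch API):

import threading

class AutogradStateTLS(threading.local):
    # threading.local runs __init__ once per thread on first access,
    # so every thread starts from the same defaults as the C++ code.
    def __init__(self):
        self.grad_mode = True
        self.inference_mode = False
        self.fw_grad_mode = True
        self.multithreading_enabled = True

state = AutogradStateTLS()
state.multithreading_enabled = False      # affects this thread only

def worker():
    assert state.multithreading_enabled   # other threads keep the default

t = threading.Thread(target=worker)
t.start()
t.join()
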
@@ -58,7 +58,8 @@ struct TORCH_API InferenceMode {
     AutogradState::set_tls_state(AutogradState(
         /* grad_mode */ !enabled,
         /* inference_mode */ enabled,
-        /* fw_grad_mode */ !enabled));
+        /* fw_grad_mode */ !enabled,
+        /* multithreading_enabled */ !enabled));
     DispatchKeySet included = enabled
         ? prev_keyset.included_.remove(c10::DispatchKey::ADInplaceOrView)
         : prev_keyset.included_.add(c10::DispatchKey::ADInplaceOrView);

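A side effect worth flagging in the hunk above: `InferenceMode` now also sets `multithreading_enabled` to `!enabled`, so entering inference mode switches the flag off together with grad and forward-grad modes, and restores all of them on exit. Roughly, in Python terms (illustrative):

import torch

# Inference mode flips all four TLS flags as one state object;
# the previous state is restored when the guard is destroyed.
with torch.inference_mode():
    y = torch.ones(2, 2) * 3   # runs with multithreaded backwards off too
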
@@ -630,3 +630,14 @@ Operator Tags
 .. This module needs to be documented. Adding here in the meantime
 .. for tracking purposes
 .. py:module:: torch.utils.model_dump
+
+.. automodule:: torch.autograd
+.. currentmodule:: torch.autograd
+
+Engine Configuration
+--------------------
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    set_multithreading_enabled

@@ -479,7 +479,7 @@ def create_aot_dispatcher_function(
     python_dispatcher_mode = enable_python_dispatcher() if config.use_dynamic_shapes else nullcontext()
     shape_env = ShapeEnv() if config.use_dynamic_shapes else None

-    with preserve_rng_state(), cross_ref, fake_mode, python_dispatcher_mode:
+    with torch.autograd.set_multithreading_enabled(False), preserve_rng_state(), cross_ref, fake_mode, python_dispatcher_mode:

         def process_inputs(flat_args):
             if config.use_fake_tensor:

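This one-line change is the actual aot_autograd fix: the whole dispatch function is now traced with multithreaded backwards disabled, stacked in front of the existing guards. The context line also shows the usual `nullcontext()` idiom for a guard that only sometimes applies; a standalone sketch with illustrative stand-in names:

from contextlib import contextmanager, nullcontext

@contextmanager
def enable_python_dispatcher_stub():
    # stand-in for the real enable_python_dispatcher() guard
    yield

use_dynamic_shapes = False  # stand-in for config.use_dynamic_shapes

# Choose a real guard or a no-op guard so the `with` stays unconditional.
mode = enable_python_dispatcher_stub() if use_dynamic_shapes else nullcontext()
with mode:
    pass  # body runs here; the guard applies only when configured
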
@@ -199,6 +199,7 @@
         "no_grad",
         "set_detect_anomaly",
         "set_grad_enabled",
+        "set_multithreading_enabled",
         "variable"
     ],
     "torch.autograd.function": [

@@ -9304,6 +9304,33 @@ class TestAutogradMultipleDispatch(TestCase):
         with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
             foo(nt).backward(torch.nested.nested_tensor([torch.rand(1), torch.rand(1)], device=device))

+    @onlyCUDA
+    def test_backward_single_threaded(self):
+
+        threads_eq = None
+
+        class TestFn(Function):
+            @staticmethod
+            def forward(ctx, x, self):
+                ctx.self = self
+                ctx.tid = threading.get_ident()
+                return x.clone()
+
+            @staticmethod
+            def backward(ctx, gO):
+                nonlocal threads_eq
+                threads_eq = ctx.tid == threading.get_ident()
+                return gO, None
+
+        inp = torch.rand(10, device="cuda", requires_grad=True)
+
+        with torch.autograd.set_multithreading_enabled(False):
+            TestFn.apply(inp, None).sum().backward()
+        self.assertTrue(threads_eq)
+
+        TestFn.apply(inp, None).sum().backward()
+        self.assertFalse(threads_eq)
+
 # Import test cases from below autograd/ here. These are found
 # implicitly by the loader, so Flake8 thinks they are unused, hence
 # the suppressions.

@@ -948,6 +948,9 @@ class _DisableFuncTorch:
 class _EnableTorchFunction:
     def __init__(self) -> None: ...

+class _MultithreadingEnabled:
+    def __init__(self, mode: _bool) -> None: ...
+
 # Defined in torch/csrc/jit/python/script_init.cpp
 class LoggerBase(object):
     ...

@@ -15,7 +15,7 @@ from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast
 from .variable import Variable
 from .function import Function, NestedIOFunction
 from .gradcheck import gradcheck, gradgradcheck
-from .grad_mode import no_grad, enable_grad, set_grad_enabled, inference_mode
+from .grad_mode import no_grad, enable_grad, set_grad_enabled, inference_mode, set_multithreading_enabled
 from .anomaly_mode import detect_anomaly, set_detect_anomaly
 from ..overrides import has_torch_function, handle_torch_function, is_tensor_like
 from . import functional

@@ -5,7 +5,7 @@ import inspect
 from typing import Any, Callable, TypeVar, cast

 __all__ = ['no_grad', 'enable_grad', 'set_grad_enabled',
-           'inference_mode']
+           'inference_mode', 'set_multithreading_enabled']


 # Used for annotating the decorator usage of 'no_grad' and 'enable_grad'.

@@ -184,7 +184,7 @@ class enable_grad(_DecoratorContextManager):


 class set_grad_enabled(_DecoratorContextManager):
-    r"""Context-manager that sets gradient calculation to on or off.
+    r"""Context-manager that sets gradient calculation on or off.

     ``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`.
     It can be used as a context-manager or as a function.

@@ -298,3 +298,35 @@ class inference_mode(_DecoratorContextManager):

     def clone(self):
         return self.__class__(self.mode)
+
+
+class set_multithreading_enabled(_DecoratorContextManager):
+    r"""Context-manager that sets multithreaded backwards on or off.
+
+    ``set_multithreading_enabled`` will enable or disable multithreaded backwards based on its argument :attr:`mode`.
+    It can be used as a context-manager or as a function.
+
+    This context manager is thread local; it will not affect computation
+    in other threads.
+
+    Args:
+        mode (bool): Flag whether to enable multithreaded backwards (``True``), or disable
+                     (``False``).
+
+    .. note::
+        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
+
+    """
+
+    def __init__(self, mode: bool) -> None:
+        self.mode = mode
+        self.multithreading_enabled_guard = torch._C._MultithreadingEnabled(mode)
+
+    def __enter__(self) -> None:
+        pass
+
+    def __exit__(self, *args) -> None:
+        del self.multithreading_enabled_guard
+
+    def clone(self):
+        return self.__class__(self.mode)

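One quirk visible above: the `torch._C._MultithreadingEnabled` guard is created in `__init__` and only dropped in `__exit__`, while `__enter__` is a no-op, so the TLS flag flips as soon as the object is constructed. A short sketch (effects noted in comments, since this diff exposes no Python getter for the flag):

import torch
from torch.autograd import set_multithreading_enabled

cm = set_multithreading_enabled(False)  # guard built: flag is off already
with cm:
    pass                                # still off inside the block
# __exit__ deleted the guard, restoring the previous value

Because the class derives from `_DecoratorContextManager` and implements `clone()`, it can also be applied as a decorator, with the same construction-time caveat.
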
@@ -1255,7 +1255,9 @@ void Engine::init_local_ready_queue(std::shared_ptr<ReadyQueue> ready_queue) {
 auto Engine::ready_queue(
     std::shared_ptr<ReadyQueue> cpu_ready_queue,
     at::Device device) -> std::shared_ptr<ReadyQueue> {
-  if (should_run_in_cpu_ready_queue(device.type())) {
+  bool multithreading_disabled =
+      !c10::AutogradState::get_tls_state().get_multithreading_enabled();
+  if (multithreading_disabled || should_run_in_cpu_ready_queue(device.type())) {
     // return the cpu ready queue passed in
     TORCH_INTERNAL_ASSERT(cpu_ready_queue);
     return cpu_ready_queue;

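This hunk is the entire engine-side mechanism: when the TLS flag is off, work for every device is routed to the caller's CPU ready queue, so all graph tasks run sequentially on the calling thread. A rough Python paraphrase of the routing decision (illustrative names, not the real API):

def ready_queue(cpu_ready_queue, device, multithreading_enabled, device_queues):
    # Paraphrase of Engine::ready_queue above: with multithreading off,
    # every device funnels into the caller's queue -> sequential backward.
    if not multithreading_enabled or device.type == "cpu":
        assert cpu_ready_queue is not None
        return cpu_ready_queue
    # Otherwise each device's dedicated worker-thread queue is used.
    return device_queues[device.index]
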
@@ -43,6 +43,17 @@ struct DisableFuncTorch {
   c10::impl::ExcludeDispatchKeyGuard back_guard_;
 };

+struct MultithreadingEnabled {
+  MultithreadingEnabled(bool enabled)
+      : old_(c10::AutogradState::get_tls_state().get_multithreading_enabled()) {
+    c10::AutogradState::get_tls_state().set_multithreading_enabled(enabled);
+  }
+  ~MultithreadingEnabled() {
+    c10::AutogradState::get_tls_state().set_multithreading_enabled(old_);
+  }
+  bool old_;
+};
+
 struct EnableTorchFunction {
   EnableTorchFunction()
       : old_(at::impl::PythonTorchFunctionTLS::is_disabled()) {

@@ -354,6 +365,8 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
       _C_m, "_DisablePythonDispatcher")
       .def(py::init<>());
   py::class_<DisableFuncTorch>(_C_m, "_DisableFuncTorch").def(py::init<>());
+  py::class_<MultithreadingEnabled>(_C_m, "_MultithreadingEnabled")
+      .def(py::init<bool>());

   py::class_<torch::autograd::SavedVariable>(m, "SavedTensor")
       .def(py::init([]() -> torch::autograd::SavedVariable {

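The pybind11 registration above is what `grad_mode.py` instantiates. Using the private binding directly (not something user code should do; shown only to connect the pieces) would look like:

import torch

# RAII-style: flips the TLS flag on construction; the C++ destructor
# restores the old value when the Python object is deleted.
guard = torch._C._MultithreadingEnabled(False)
try:
    pass  # a backward() here would run single-threaded
finally:
    del guard  # restore the previous multithreading state
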