Add Context Manager for Disabling Multithreading in Backwards, use in aot autograd (#86245)

We were running into a few issues with running the multithreaded backwards in aot_autograd, such as https://github.com/pytorch/pytorch/issues/86136, and `FakeTensorMode` getting into an inconsistent state because functions were not executed strictly sequentially. The multithreading is lost anyway once we trace out the backwards, and it adds a lot of additional complexity.
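
As an illustrative sketch (not part of this diff), this is roughly how the new context manager is meant to be used to force the backward pass onto the calling thread; the toy tensor and computation below are made up for the example:

```python
import torch

# Hypothetical usage sketch: disable the multithreaded autograd engine so the
# backward pass runs sequentially on the calling thread (the tensor and the
# computation here are illustrative only).
x = torch.randn(4, requires_grad=True)
loss = (x * x).sum()

with torch.autograd.set_multithreading_enabled(False):
    loss.backward()  # executed single-threaded, on this thread

# Once the context manager exits, the default multithreaded engine is restored.
```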

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86245
Approved by: https://github.com/albanD, https://github.com/yf225
Elias Ellison 2022-10-05 21:25:25 +00:00 committed by PyTorch MergeBot
parent 237316aa1d
commit d04889323e
14 changed files with 126 additions and 10 deletions

View File

@@ -27,6 +27,10 @@ void ThreadLocalState::set_grad_mode(bool enabled) {
autograd_tls_.set_grad_mode(enabled);
}
void ThreadLocalState::set_multithreading_enabled(bool enabled) {
autograd_tls_.set_multithreading_enabled(enabled);
}
/* static */
void ThreadLocalState::setThreadLocalState(
const ThreadLocalState& state) {

View File

@@ -30,6 +30,12 @@ class TORCH_API ThreadLocalState {
// autograd engine.
void set_grad_mode(bool enabled);
// set_multithreading_enabled - force the value of the multithreading TLS in
// the current state object. This is used for example in the
// autograd engine.
void set_multithreading_enabled(bool enabled);
// Sets thread local variables in the current thread,
// according to the thread boundary specified
static void setThreadLocalState(const ThreadLocalState& state);

View File

@@ -3,11 +3,13 @@
namespace c10 {
namespace {
// By default, grad mode is enabled and inference mode is disabled
// By default, grad mode and multithreading are enabled, inference mode is
// disabled,
thread_local AutogradState autograd_state_tls = AutogradState(
/* grad_mode */ true,
/* inference_mode */ false,
/* fw_grad_mode */ true);
/* fw_grad_mode */ true,
/* multithreading_enabled */ true);
} // namespace
AutogradState& AutogradState::get_tls_state() {

View File

@@ -12,10 +12,15 @@ struct C10_API AutogradState {
static AutogradState& get_tls_state();
static void set_tls_state(AutogradState state);
AutogradState(bool grad_mode, bool inference_mode, bool fw_grad_mode)
AutogradState(
bool grad_mode,
bool inference_mode,
bool fw_grad_mode,
bool multithreading_enabled)
: grad_mode_(grad_mode),
inference_mode_(inference_mode),
fw_grad_mode_(fw_grad_mode) {}
fw_grad_mode_(fw_grad_mode),
multithreading_enabled_(multithreading_enabled) {}
void set_grad_mode(bool enabled) {
grad_mode_ = enabled;
@@ -29,6 +34,10 @@ struct C10_API AutogradState {
inference_mode_ = enabled;
}
void set_multithreading_enabled(bool multithreading_enabled) {
multithreading_enabled_ = multithreading_enabled;
}
bool get_grad_mode() const {
return grad_mode_;
}
@@ -41,10 +50,15 @@ struct C10_API AutogradState {
return inference_mode_;
}
bool get_multithreading_enabled() const {
return multithreading_enabled_;
}
private:
bool grad_mode_ : 1;
bool inference_mode_ : 1;
bool fw_grad_mode_ : 1;
bool multithreading_enabled_ : 1;
};
} // namespace c10

View File

@@ -58,7 +58,8 @@ struct TORCH_API InferenceMode {
AutogradState::set_tls_state(AutogradState(
/* grad_mode */ !enabled,
/* inference_mode */ enabled,
/* fw_grad_mode */ !enabled));
/* fw_grad_mode */ !enabled,
/* multithreading_enabled */ !enabled));
DispatchKeySet included = enabled
? prev_keyset.included_.remove(c10::DispatchKey::ADInplaceOrView)
: prev_keyset.included_.add(c10::DispatchKey::ADInplaceOrView);

View File

@@ -630,3 +630,14 @@ Operator Tags
.. This module needs to be documented. Adding here in the meantime
.. for tracking purposes
.. py:module:: torch.utils.model_dump
.. automodule:: torch.autograd
.. currentmodule:: torch.autograd
Engine Configuration
----------------------------------
.. autosummary::
:toctree: generated
:nosignatures:
set_multithreading_enabled

View File

@@ -479,7 +479,7 @@ def create_aot_dispatcher_function(
python_dispatcher_mode = enable_python_dispatcher() if config.use_dynamic_shapes else nullcontext()
shape_env = ShapeEnv() if config.use_dynamic_shapes else None
with preserve_rng_state(), cross_ref, fake_mode, python_dispatcher_mode:
with torch.autograd.set_multithreading_enabled(False), preserve_rng_state(), cross_ref, fake_mode, python_dispatcher_mode:
def process_inputs(flat_args):
if config.use_fake_tensor:

View File

@@ -199,6 +199,7 @@
"no_grad",
"set_detect_anomaly",
"set_grad_enabled",
"set_multithreading_enabled",
"variable"
],
"torch.autograd.function": [

View File

@@ -9304,6 +9304,33 @@ class TestAutogradMultipleDispatch(TestCase):
with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
foo(nt).backward(torch.nested.nested_tensor([torch.rand(1), torch.rand(1)], device=device))
@onlyCUDA
def test_backward_single_threaded(self):
threads_eq = None
class TestFn(Function):
@staticmethod
def forward(ctx, x, self):
ctx.self = self
ctx.tid = threading.get_ident()
return x.clone()
@staticmethod
def backward(ctx, gO):
nonlocal threads_eq
threads_eq = ctx.tid == threading.get_ident()
return gO, None
inp = torch.rand(10, device="cuda", requires_grad=True)
with torch.autograd.set_multithreading_enabled(False):
TestFn.apply(inp, None).sum().backward()
self.assertTrue(threads_eq)
TestFn.apply(inp, None).sum().backward()
self.assertFalse(threads_eq)
# Import test cases from below autograd/ here. These are found
# implicitly by the loader, so Flake8 thinks they are unused, hence
# the suppressions.

View File

@@ -948,6 +948,9 @@ class _DisableFuncTorch:
class _EnableTorchFunction:
def __init__(self) -> None: ...
class _MultithreadingEnabled:
def __init__(self, mode: _bool) -> None: ...
# Defined in torch/csrc/jit/python/script_init.cpp
class LoggerBase(object):
...

View File

@@ -15,7 +15,7 @@ from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast
from .variable import Variable
from .function import Function, NestedIOFunction
from .gradcheck import gradcheck, gradgradcheck
from .grad_mode import no_grad, enable_grad, set_grad_enabled, inference_mode
from .grad_mode import no_grad, enable_grad, set_grad_enabled, inference_mode, set_multithreading_enabled
from .anomaly_mode import detect_anomaly, set_detect_anomaly
from ..overrides import has_torch_function, handle_torch_function, is_tensor_like
from . import functional

View File

@@ -5,7 +5,7 @@ import inspect
from typing import Any, Callable, TypeVar, cast
__all__ = ['no_grad', 'enable_grad', 'set_grad_enabled',
'inference_mode']
'inference_mode', 'set_multithreading_enabled']
# Used for annotating the decorator usage of 'no_grad' and 'enable_grad'.
@@ -184,7 +184,7 @@ class enable_grad(_DecoratorContextManager):
class set_grad_enabled(_DecoratorContextManager):
r"""Context-manager that sets gradient calculation to on or off.
r"""Context-manager that sets gradient calculation on or off.
``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`.
It can be used as a context-manager or as a function.
@@ -298,3 +298,35 @@ class inference_mode(_DecoratorContextManager):
def clone(self):
return self.__class__(self.mode)
class set_multithreading_enabled(_DecoratorContextManager):
r"""Context-manager that sets multithreaded backwards on or off.
``set_multithreading_enabled`` will enable or disable multithreaded backwards based on its argument :attr:`mode`.
It can be used as a context-manager or as a function.
This context manager is thread local; it will not affect computation
in other threads.
Args:
mode (bool): Flag whether to enable multithreaded backwards (``True``), or disable
(``False``).
.. note::
This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
"""
def __init__(self, mode: bool) -> None:
self.mode = mode
self.multithreading_enabled_guard = torch._C._MultithreadingEnabled(mode)
def __enter__(self) -> None:
pass
def __exit__(self, *args) -> None:
del self.multithreading_enabled_guard
def clone(self):
return self.__class__(self.mode)

View File

@@ -1255,7 +1255,9 @@ void Engine::init_local_ready_queue(std::shared_ptr<ReadyQueue> ready_queue) {
auto Engine::ready_queue(
std::shared_ptr<ReadyQueue> cpu_ready_queue,
at::Device device) -> std::shared_ptr<ReadyQueue> {
if (should_run_in_cpu_ready_queue(device.type())) {
bool multithreading_disabled =
!c10::AutogradState::get_tls_state().get_multithreading_enabled();
if (multithreading_disabled || should_run_in_cpu_ready_queue(device.type())) {
// return the cpu ready queue passed in
TORCH_INTERNAL_ASSERT(cpu_ready_queue);
return cpu_ready_queue;

View File

@@ -43,6 +43,17 @@ struct DisableFuncTorch {
c10::impl::ExcludeDispatchKeyGuard back_guard_;
};
struct MultithreadingEnabled {
MultithreadingEnabled(bool enabled)
: old_(c10::AutogradState::get_tls_state().get_multithreading_enabled()) {
c10::AutogradState::get_tls_state().set_multithreading_enabled(enabled);
}
~MultithreadingEnabled() {
c10::AutogradState::get_tls_state().set_multithreading_enabled(old_);
}
bool old_;
};
struct EnableTorchFunction {
EnableTorchFunction()
: old_(at::impl::PythonTorchFunctionTLS::is_disabled()) {
@@ -354,6 +365,8 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
_C_m, "_DisablePythonDispatcher")
.def(py::init<>());
py::class_<DisableFuncTorch>(_C_m, "_DisableFuncTorch").def(py::init<>());
py::class_<MultithreadingEnabled>(_C_m, "_MultithreadingEnabled")
.def(py::init<bool>());
py::class_<torch::autograd::SavedVariable>(m, "SavedTensor")
.def(py::init([]() -> torch::autograd::SavedVariable {