From a3b7dd7b78debd9ef8d74f87b3b296424ce608d4 Mon Sep 17 00:00:00 2001
From: Victor Quach
Date: Tue, 11 Jan 2022 15:02:13 -0800
Subject: [PATCH] Enable nested default hooks (#70932)

Summary:
When default hooks are set, they are now pushed onto a thread-local
stack. When nesting context-managers, only the inner-most pair of hooks
is applied.

Special care is needed to update the TLS code: ThreadLocalState now
saves and restores the whole stack of hooks. See also
https://github.com/pytorch/pytorch/issues/70940 (i.e. do we need to be
storing the enabled flag as well?)
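A minimal sketch of the new nesting semantics (adapted from
`test_setting_default_saved_variable_hooks_twice_should_use_inner`,
added in this PR):

```python
import torch

# Entering the context-manager pushes a (pack, unpack) pair onto the
# thread-local stack; exiting pops it. A tensor saved for backward is
# packed by whichever pair is on top of the stack at save time.
with torch.autograd.graph.saved_tensors_hooks(lambda x: 3 * x, lambda x: 3 * x):
    b = torch.randn(5, requires_grad=True)
    with torch.autograd.graph.saved_tensors_hooks(lambda x: 5 * x, lambda x: 5 * x):
        a = torch.randn(5, requires_grad=True)
        y = a * a  # saves `a` through the inner (5x) hooks
    z = b * b      # saves `b` through the outer (3x) hooks

# Each saved tensor is unpacked with the hooks that were active when it
# was saved, even though both context-managers have exited by now.
y.sum().backward()
z.sum().backward()
assert torch.allclose(a.grad, 2 * 5 * 5 * a.detach())  # d(a*a)/da = 2a, scaled 5x at pack and 5x at unpack
assert torch.allclose(b.grad, 2 * 3 * 3 * b.detach())
```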
Fixes https://github.com/pytorch/pytorch/issues/70134

Pull Request resolved: https://github.com/pytorch/pytorch/pull/70932

Reviewed By: mruberry

Differential Revision: D33530370

Pulled By: albanD

fbshipit-source-id: 3197d585d77563f36c175d3949115a0776b309f4
---
 aten/src/ATen/SavedTensorHooks.cpp           | 35 +++++++++-----
 aten/src/ATen/SavedTensorHooks.h             |  6 ++-
 aten/src/ATen/ThreadLocalState.cpp           |  7 ++-
 aten/src/ATen/ThreadLocalState.h             |  4 +-
 test/test_autograd.py                        | 48 +++++++++++++++++--
 torch/_C/_autograd.pyi                       |  4 +-
 torch/autograd/__init__.py                   |  2 +-
 torch/autograd/graph.py                      | 19 +++-----
 torch/csrc/autograd/init.cpp                 |  8 ++--
 .../autograd/python_saved_variable_hooks.cpp | 14 ++----
 .../autograd/python_saved_variable_hooks.h   |  4 +-
 11 files changed, 98 insertions(+), 53 deletions(-)

diff --git a/aten/src/ATen/SavedTensorHooks.cpp b/aten/src/ATen/SavedTensorHooks.cpp
index 62d762bc9d1..aff6ddd1b06 100644
--- a/aten/src/ATen/SavedTensorHooks.cpp
+++ b/aten/src/ATen/SavedTensorHooks.cpp
@@ -1,13 +1,12 @@
 #include <ATen/SavedTensorHooks.h>
 #include <c10/util/Exception.h>
+#include <stack>
 
 namespace at {
 
 namespace {
   // PyObject is defined in c10/util/python_stub.h
-  // Reference counting is handled by the caller of `set_hooks`.
-  thread_local PyObject* pack_hook_(nullptr);
-  thread_local PyObject* unpack_hook_(nullptr);
+  thread_local std::stack<std::pair<PyObject*, PyObject*>> stack;
 
   // This flag is set to true the first time default hooks are registered
   // and left at true for the rest of the execution.
@@ -20,20 +19,32 @@ void SavedTensorDefaultHooks::enable() {
   is_enabled = true;
 }
 
-void SavedTensorDefaultHooks::set_hooks(PyObject* pack_hook, PyObject* unpack_hook) {
-  if (!is_enabled) {
-    TORCH_INTERNAL_ASSERT(pack_hook == nullptr && unpack_hook == nullptr);
-    return;
-  }
-  pack_hook_ = pack_hook;
-  unpack_hook_ = unpack_hook;
+void SavedTensorDefaultHooks::push_hooks(PyObject* pack_hook, PyObject* unpack_hook) {
+  // Reference counting is handled by the caller of `push_hooks`
+  TORCH_INTERNAL_ASSERT(is_enabled);
+  TORCH_INTERNAL_ASSERT(pack_hook != nullptr && unpack_hook != nullptr);
+  stack.push(std::make_pair(pack_hook, unpack_hook));
+}
+
+void SavedTensorDefaultHooks::pop_hooks() {
+  // Reference counting is handled by the caller of `pop_hooks`
+  TORCH_INTERNAL_ASSERT(is_enabled && !stack.empty());
+  stack.pop();
 }
 
 std::pair<PyObject*, PyObject*> SavedTensorDefaultHooks::get_hooks() {
-  if (!is_enabled) {
+  if (!is_enabled || stack.empty()) {
     return std::make_pair(nullptr, nullptr);
   }
-  return std::make_pair(pack_hook_, unpack_hook_);
+  return stack.top();
+}
+
+std::stack<std::pair<PyObject*, PyObject*>> SavedTensorDefaultHooks::get_stack() {
+  return stack;
+}
+
+void SavedTensorDefaultHooks::set_stack(std::stack<std::pair<PyObject*, PyObject*>> stack_) {
+  stack = stack_;
 }
 
 }
diff --git a/aten/src/ATen/SavedTensorHooks.h b/aten/src/ATen/SavedTensorHooks.h
index 0f3be924b6d..0cdfa3c9ecc 100644
--- a/aten/src/ATen/SavedTensorHooks.h
+++ b/aten/src/ATen/SavedTensorHooks.h
@@ -2,15 +2,19 @@
 
 #include <c10/macros/Export.h>
 #include <c10/util/python_stub.h>
+#include <stack>
 #include <utility>
 
 namespace at {
 
 struct TORCH_API SavedTensorDefaultHooks {
-  static void set_hooks(PyObject* pack_hook, PyObject* unpack_hook);
+  static void push_hooks(PyObject* pack_hook, PyObject* unpack_hook);
+  static void pop_hooks();
   static std::pair<PyObject*, PyObject*> get_hooks();
   static void enable();
+
+  static std::stack<std::pair<PyObject*, PyObject*>> get_stack();
+  static void set_stack(std::stack<std::pair<PyObject*, PyObject*>>);
 };
 
 } // namespace at
diff --git a/aten/src/ATen/ThreadLocalState.cpp b/aten/src/ATen/ThreadLocalState.cpp
index 42969642526..3e3d4d6a957 100644
--- a/aten/src/ATen/ThreadLocalState.cpp
+++ b/aten/src/ATen/ThreadLocalState.cpp
@@ -15,7 +15,8 @@ ThreadLocalState::ThreadLocalState()
       functorch_tls_(functorch::getCopyOfFuncTorchTLS()),
       autograd_tls_(c10::AutogradState::get_tls_state()) {
   rf_tls_ = at::get_record_function_tls_();
-  saved_tensors_default_hooks_ = SavedTensorDefaultHooks::get_hooks();
+
+  saved_tensors_default_hooks_ = at::SavedTensorDefaultHooks::get_stack();
 
   bumped_record_all_functions_ = at::checkRecordAllFunctions();
   python_mode_state_ = at::impl::PythonModeTLS::get_state();
@@ -36,9 +37,7 @@ void ThreadLocalState::setThreadLocalState(
 
   at::set_record_function_tls_(state.rf_tls_);
 
-  SavedTensorDefaultHooks::set_hooks(
-      state.saved_tensors_default_hooks_.first,
-      state.saved_tensors_default_hooks_.second);
+  at::SavedTensorDefaultHooks::set_stack(state.saved_tensors_default_hooks_);
 
   c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_);
 
diff --git a/aten/src/ATen/ThreadLocalState.h b/aten/src/ATen/ThreadLocalState.h
index c0e53798722..c5f14518f42 100644
--- a/aten/src/ATen/ThreadLocalState.h
+++ b/aten/src/ATen/ThreadLocalState.h
@@ -1,5 +1,7 @@
 #pragma once
 
+#include <stack>
+
 #include <c10/core/InferenceMode.h>
 #include <c10/core/impl/LocalDispatchKeySet.h>
 #include <c10/util/Exception.h>
@@ -54,7 +56,7 @@ class TORCH_API ThreadLocalState {
   std::shared_ptr<TorchDispatchTypeObject> python_mode_state_;
 
   // TLS for saved tensors default hooks
-  std::pair<PyObject*, PyObject*> saved_tensors_default_hooks_;
+  std::stack<std::pair<PyObject*, PyObject*>> saved_tensors_default_hooks_;
 
   // Whether pre-sampling RecordFunction optimization was enabled
   bool bumped_record_all_functions_ = false;
diff --git a/test/test_autograd.py b/test/test_autograd.py
index b921d40ed0d..684eddf3064 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -5906,11 +5906,51 @@ for shape in [(1,), ()]:
         y.sum().backward()
         self.assertEqual(a.grad, y)
 
-    def test_setting_default_saved_variable_hooks_twice_should_fail(self):
-        with self.assertRaisesRegex(RuntimeError, "Setting default hooks but they have already been set. "):
+    def test_setting_default_saved_variable_hooks_twice_should_not_fail(self):
+        with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
             with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
-                with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
-                    pass
+                pass
+
+    def test_setting_default_saved_variable_hooks_twice_should_use_inner(self):
+        with torch.autograd.graph.saved_tensors_hooks(lambda x: 3 * x, lambda x: 3 * x):
+            b = torch.randn(5, requires_grad=True)
+            with torch.autograd.graph.saved_tensors_hooks(lambda x: 5 * x, lambda x: 5 * x):
+                a = torch.randn(5, requires_grad=True)
+                y = a * a
+            z = b * b
+        y.sum().backward()
+        z.sum().backward()
+        self.assertEqual(2 * 5 * 5 * a, a.grad)
+        self.assertEqual(2 * 3 * 3 * b, b.grad)
+
+    def test_save_on_cpu_and_checkpoint(self):
+        a = torch.randn(2, 2, requires_grad=True)
+
+        b = a.pow(2).pow(2).pow(2).pow(2)
+        b.sum().backward()
+        b_grad = a.grad.clone()
+        a.grad.zero_()
+
+        with torch.autograd.graph.save_on_cpu():
+            h = a.pow(2)
+            h = checkpoint(lambda x: x.pow(2).pow(2), h, use_reentrant=False)
+            c = h.pow(2)
+        c.sum().backward()
+        c_grad = a.grad.clone()
+        a.grad.zero_()
+
+        def f(a):
+            h = a.pow(2)
+            with torch.autograd.graph.save_on_cpu():
+                h = h.pow(2).pow(2)
+            return h.pow(2)
+
+        d = checkpoint(f, a, use_reentrant=False)
+        d.sum().backward()
+        d_grad = a.grad.clone()
+
+        self.assertEqual(b_grad, c_grad)
+        self.assertEqual(b_grad, d_grad)
 
     def test_pack_hook_with_inplace_modification_should_fail(self):
         a = torch.randn(5, requires_grad=True)
diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi
index e2a6039b281..38ac7ccaea0 100644
--- a/torch/_C/_autograd.pyi
+++ b/torch/_C/_autograd.pyi
@@ -93,8 +93,8 @@ def _record_function_with_args_exit(handle: torch.Tensor) -> None: ...
 def _supported_activities() -> Set[ProfilerActivity]: ...
 def _enable_record_function(enable: bool) -> None: ...
 def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...
-def _register_saved_tensors_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
-def _reset_saved_tensors_default_hooks() -> None: ...
+def _push_saved_tensors_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
+def _pop_saved_tensors_default_hooks() -> None: ...
 def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
 def _disable_profiler_legacy() -> List[List[ProfilerEvent]]: ...
diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py
index 18e808db967..28eb729ffcb 100644
--- a/torch/autograd/__init__.py
+++ b/torch/autograd/__init__.py
@@ -307,7 +307,7 @@ from torch._C._autograd import (DeviceType, ProfilerActivity, ProfilerState, Pro
                                 _enable_record_function, _set_empty_test_observer, kineto_available,
                                 _record_function_with_args_enter, _record_function_with_args_exit,
                                 _supported_activities, _add_metadata_json, SavedTensor,
-                                _register_saved_tensors_default_hooks, _reset_saved_tensors_default_hooks)
+                                _push_saved_tensors_default_hooks, _pop_saved_tensors_default_hooks)
 
 from torch._C._autograd import (_ProfilerResult, _KinetoEvent,
                                 _prepare_profiler, _enable_profiler, _disable_profiler)
diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py
index b0accf011db..f81a42285e0 100644
--- a/torch/autograd/graph.py
+++ b/torch/autograd/graph.py
@@ -58,21 +58,21 @@ class saved_tensors_hooks():
         to undefined behavior.
 
     .. warning ::
-        Only one pair of hooks is allowed at a time. Recursively nesting this
-        context-manager is not yet supported.
+        Only one pair of hooks is allowed at a time. When recursively nesting this
+        context-manager, only the inner-most pair of hooks will be applied.
     """
     def __init__(self, pack_hook: Callable[[torch.Tensor], Any], unpack_hook: Callable[[Any], torch.Tensor]):
         self.pack_hook = pack_hook
         self.unpack_hook = unpack_hook
 
     def __enter__(self):
-        torch._C._autograd._register_saved_tensors_default_hooks(self.pack_hook, self.unpack_hook)
+        torch._C._autograd._push_saved_tensors_default_hooks(self.pack_hook, self.unpack_hook)
 
     def __exit__(self, *args: Any):
-        torch._C._autograd._reset_saved_tensors_default_hooks()
+        torch._C._autograd._pop_saved_tensors_default_hooks()
 
 
-class save_on_cpu():
+class save_on_cpu(saved_tensors_hooks):
     """Context-manager under which tensors saved by the forward pass will be
     stored on cpu, then retrieved for backward.
@@ -129,11 +129,4 @@ class save_on_cpu():
             device, tensor = packed
             return tensor.to(device, non_blocking=pin_memory)
 
-        self.pack_hook = pack_to_cpu
-        self.unpack_hook = unpack_from_cpu
-
-    def __enter__(self):
-        torch._C._autograd._register_saved_tensors_default_hooks(self.pack_hook, self.unpack_hook)
-
-    def __exit__(self, *args: Any):
-        torch._C._autograd._reset_saved_tensors_default_hooks()
+        super().__init__(pack_to_cpu, unpack_from_cpu)
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index 29bd54c5f1d..a17602a1503 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -303,11 +303,11 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
 
   m.def("_clear_callbacks", []() { at::clearCallbacks(); });
 
-  m.def("_register_saved_tensors_default_hooks", [](py::function &pack_hook, py::function &unpack_hook) {
-    torch::autograd::PyDefaultSavedVariableHooks::set_hooks(pack_hook, unpack_hook);
+  m.def("_push_saved_tensors_default_hooks", [](py::function &pack_hook, py::function &unpack_hook) {
+    torch::autograd::PyDefaultSavedVariableHooks::push_hooks(pack_hook, unpack_hook);
   });
-  m.def("_reset_saved_tensors_default_hooks", []() {
-    torch::autograd::PyDefaultSavedVariableHooks::reset_hooks();
+  m.def("_pop_saved_tensors_default_hooks", []() {
+    torch::autograd::PyDefaultSavedVariableHooks::pop_hooks();
   });
 
   _C_m.def("_register_py_class_for_device", [](const std::string& device, py::object python_type_class) {
diff --git a/torch/csrc/autograd/python_saved_variable_hooks.cpp b/torch/csrc/autograd/python_saved_variable_hooks.cpp
index c81a07d1cc5..4d224f2982b 100644
--- a/torch/csrc/autograd/python_saved_variable_hooks.cpp
+++ b/torch/csrc/autograd/python_saved_variable_hooks.cpp
@@ -46,25 +46,21 @@ namespace torch { namespace autograd {
     }
   }
 
-  void PyDefaultSavedVariableHooks::set_hooks(py::function &pack_hook, py::function &unpack_hook) {
-    PyObject *pack_hook_(nullptr), *unpack_hook_(nullptr);
-    std::tie(pack_hook_, unpack_hook_) = at::SavedTensorDefaultHooks::get_hooks();
-    TORCH_CHECK(!pack_hook_ && !unpack_hook_,
-        "Setting default hooks but they have already been set. "
-        "Hint: only one pair of hooks is allowed at a time.");
+  void PyDefaultSavedVariableHooks::push_hooks(py::function &pack_hook, py::function &unpack_hook) {
     at::SavedTensorDefaultHooks::enable();
-    at::SavedTensorDefaultHooks::set_hooks(pack_hook.release().ptr(), unpack_hook.release().ptr());
+    at::SavedTensorDefaultHooks::push_hooks(pack_hook.release().ptr(), unpack_hook.release().ptr());
   }
 
-  void PyDefaultSavedVariableHooks::reset_hooks() {
+  void PyDefaultSavedVariableHooks::pop_hooks() {
     PyObject *pack_hook(nullptr), *unpack_hook(nullptr);
     std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks();
+    TORCH_INTERNAL_ASSERT(pack_hook != nullptr && unpack_hook != nullptr);
     if (Py_IsInitialized()) {
       py::gil_scoped_acquire gil;
       Py_XDECREF(pack_hook);
       Py_XDECREF(unpack_hook);
     }
-    at::SavedTensorDefaultHooks::set_hooks(nullptr, nullptr);
+    at::SavedTensorDefaultHooks::pop_hooks();
   }
 
   std::unique_ptr<SavedVariableHooks> PyDefaultSavedVariableHooks::get_hooks() {
diff --git a/torch/csrc/autograd/python_saved_variable_hooks.h b/torch/csrc/autograd/python_saved_variable_hooks.h
index f8c215f555e..4500bf6b19f 100644
--- a/torch/csrc/autograd/python_saved_variable_hooks.h
+++ b/torch/csrc/autograd/python_saved_variable_hooks.h
@@ -24,8 +24,8 @@ private:
 };
 
 struct PyDefaultSavedVariableHooks {
-  static void set_hooks(py::function &pack_hook, py::function &unpack_hook);
-  static void reset_hooks();
+  static void push_hooks(py::function &pack_hook, py::function &unpack_hook);
+  static void pop_hooks();
   static std::unique_ptr<SavedVariableHooks> get_hooks();
 };