Enable nested default hooks (#70932)

Summary:
When default hooks are set, they are pushed onto a stack.
When nesting the context manager, only the innermost pair of
hooks is applied.
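
For illustration (not part of the diff), a minimal usage sketch of the new
behavior, mirroring the test added in this PR; the tensor names and the
x3/x5 scaling hooks are arbitrary placeholders:

    import torch

    # Hooks are pushed on enter and popped on exit; whichever pair is on
    # top of the stack when a tensor is saved is the pair that applies.
    with torch.autograd.graph.saved_tensors_hooks(lambda x: 3 * x, lambda x: 3 * x):
        b = torch.randn(5, requires_grad=True)
        with torch.autograd.graph.saved_tensors_hooks(lambda x: 5 * x, lambda x: 5 * x):
            a = torch.randn(5, requires_grad=True)
            y = a * a  # saved tensors go through the inner (x5) hooks
        z = b * b      # saved tensors go through the outer (x3) hooks

    y.sum().backward()
    z.sum().backward()
    assert torch.allclose(a.grad, 2 * 5 * 5 * a)
    assert torch.allclose(b.grad, 2 * 3 * 3 * b)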

Special care is needed to update the TLS code: ThreadLocalState now captures and restores the whole hooks stack. See also https://github.com/pytorch/pytorch/issues/70940 (i.e., do we also need to store the enabled flag?)

Fixes https://github.com/pytorch/pytorch/issues/70134

Pull Request resolved: https://github.com/pytorch/pytorch/pull/70932

Reviewed By: mruberry

Differential Revision: D33530370

Pulled By: albanD

fbshipit-source-id: 3197d585d77563f36c175d3949115a0776b309f4
Victor Quach 2022-01-11 15:02:13 -08:00 committed by Facebook GitHub Bot
parent 433cf44b79
commit a3b7dd7b78
11 changed files with 98 additions and 53 deletions

View File

@ -1,13 +1,12 @@
#include <ATen/SavedTensorHooks.h>
#include <c10/util/Exception.h>
#include <stack>
namespace at {
namespace {
// PyObject is defined in c10/util/python_stub.h
// Reference counting is handled by the caller of `set_hooks`.
thread_local PyObject* pack_hook_(nullptr);
thread_local PyObject* unpack_hook_(nullptr);
thread_local std::stack<std::pair<PyObject*, PyObject*>> stack;
// This flag is set to true the first time default hooks are registered
// and left at true for the rest of the execution.
@ -20,20 +19,32 @@ void SavedTensorDefaultHooks::enable() {
is_enabled = true;
}
void SavedTensorDefaultHooks::set_hooks(PyObject* pack_hook, PyObject* unpack_hook) {
if (!is_enabled) {
TORCH_INTERNAL_ASSERT(pack_hook == nullptr && unpack_hook == nullptr);
return;
}
pack_hook_ = pack_hook;
unpack_hook_ = unpack_hook;
void SavedTensorDefaultHooks::push_hooks(PyObject* pack_hook, PyObject* unpack_hook) {
// Reference counting is handled by the caller of `push_hooks`
TORCH_INTERNAL_ASSERT(is_enabled);
TORCH_INTERNAL_ASSERT(pack_hook != nullptr && unpack_hook != nullptr);
stack.push(std::make_pair(pack_hook, unpack_hook));
}
void SavedTensorDefaultHooks::pop_hooks() {
// Reference counting is handled by the caller of `pop_hooks`
TORCH_INTERNAL_ASSERT(is_enabled && !stack.empty());
stack.pop();
}
std::pair<PyObject*, PyObject*> SavedTensorDefaultHooks::get_hooks() {
if (!is_enabled) {
if (!is_enabled || stack.empty()) {
return std::make_pair(nullptr, nullptr);
}
return std::make_pair(pack_hook_, unpack_hook_);
return stack.top();
}
std::stack<std::pair<PyObject*, PyObject*>> SavedTensorDefaultHooks::get_stack() {
return stack;
}
void SavedTensorDefaultHooks::set_stack(std::stack<std::pair<PyObject*, PyObject*>> stack_) {
stack = stack_;
}
}

View File

@ -2,15 +2,19 @@
#include <c10/macros/Export.h>
#include <c10/util/python_stub.h>
#include <stack>
#include <utility>
namespace at {
struct TORCH_API SavedTensorDefaultHooks {
static void set_hooks(PyObject* pack_hook, PyObject* unpack_hook);
static void push_hooks(PyObject* pack_hook, PyObject* unpack_hook);
static void pop_hooks();
static std::pair<PyObject*, PyObject*> get_hooks();
static void enable();
static std::stack<std::pair<PyObject*, PyObject*>> get_stack();
static void set_stack(std::stack<std::pair<PyObject*, PyObject*>>);
};
} // namespace at

View File

@ -15,7 +15,8 @@ ThreadLocalState::ThreadLocalState()
functorch_tls_(functorch::getCopyOfFuncTorchTLS()),
autograd_tls_(c10::AutogradState::get_tls_state()) {
rf_tls_ = at::get_record_function_tls_();
saved_tensors_default_hooks_ = SavedTensorDefaultHooks::get_hooks();
saved_tensors_default_hooks_ = at::SavedTensorDefaultHooks::get_stack();
bumped_record_all_functions_ = at::checkRecordAllFunctions();
python_mode_state_ = at::impl::PythonModeTLS::get_state();
@ -36,9 +37,7 @@ void ThreadLocalState::setThreadLocalState(
at::set_record_function_tls_(state.rf_tls_);
SavedTensorDefaultHooks::set_hooks(
state.saved_tensors_default_hooks_.first,
state.saved_tensors_default_hooks_.second);
at::SavedTensorDefaultHooks::set_stack(state.saved_tensors_default_hooks_);
c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_);

View File

@ -1,5 +1,7 @@
#pragma once
#include <stack>
#include <c10/core/InferenceMode.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/util/Exception.h>
@ -54,7 +56,7 @@ class TORCH_API ThreadLocalState {
std::shared_ptr<TorchDispatchTypeObject> python_mode_state_;
// TLS for saved tensors default hooks
std::pair<PyObject*, PyObject*> saved_tensors_default_hooks_;
std::stack<std::pair<PyObject*, PyObject*>> saved_tensors_default_hooks_;
// Whether pre-sampling RecordFunction optimization was enabled
bool bumped_record_all_functions_ = false;

View File

@ -5906,11 +5906,51 @@ for shape in [(1,), ()]:
y.sum().backward()
self.assertEqual(a.grad, y)
def test_setting_default_saved_variable_hooks_twice_should_fail(self):
with self.assertRaisesRegex(RuntimeError, "Setting default hooks but they have already been set. "):
def test_setting_default_saved_variable_hooks_twice_should_not_fail(self):
with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
pass
pass
def test_setting_default_saved_variable_hooks_twice_should_use_inner(self):
with torch.autograd.graph.saved_tensors_hooks(lambda x: 3 * x, lambda x: 3 * x):
b = torch.randn(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(lambda x: 5 * x, lambda x: 5 * x):
a = torch.randn(5, requires_grad=True)
y = a * a
z = b * b
y.sum().backward()
z.sum().backward()
self.assertEqual(2 * 5 * 5 * a, a.grad)
self.assertEqual(2 * 3 * 3 * b, b.grad)
def test_save_on_cpu_and_checkpoint(self):
a = torch.randn(2, 2, requires_grad=True)
b = a.pow(2).pow(2).pow(2).pow(2)
b.sum().backward()
b_grad = a.grad.clone()
a.grad.zero_()
with torch.autograd.graph.save_on_cpu():
h = a.pow(2)
h = checkpoint(lambda x: x.pow(2).pow(2), h, use_reentrant=False)
c = h.pow(2)
c.sum().backward()
c_grad = a.grad.clone()
a.grad.zero_()
def f(a):
h = a.pow(2)
with torch.autograd.graph.save_on_cpu():
h = h.pow(2).pow(2)
return h.pow(2)
d = checkpoint(f, a, use_reentrant=False)
d.sum().backward()
d_grad = a.grad.clone()
self.assertEqual(b_grad, c_grad)
self.assertEqual(b_grad, d_grad)
def test_pack_hook_with_inplace_modification_should_fail(self):
a = torch.randn(5, requires_grad=True)

View File

@ -93,8 +93,8 @@ def _record_function_with_args_exit(handle: torch.Tensor) -> None: ...
def _supported_activities() -> Set[ProfilerActivity]: ...
def _enable_record_function(enable: bool) -> None: ...
def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...
def _register_saved_tensors_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
def _reset_saved_tensors_default_hooks() -> None: ...
def _push_saved_tensors_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
def _pop_saved_tensors_default_hooks() -> None: ...
def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
def _disable_profiler_legacy() -> List[List[ProfilerEvent]]: ...

View File

@ -307,7 +307,7 @@ from torch._C._autograd import (DeviceType, ProfilerActivity, ProfilerState, Pro
_enable_record_function, _set_empty_test_observer, kineto_available,
_record_function_with_args_enter, _record_function_with_args_exit,
_supported_activities, _add_metadata_json, SavedTensor,
_register_saved_tensors_default_hooks, _reset_saved_tensors_default_hooks)
_push_saved_tensors_default_hooks, _pop_saved_tensors_default_hooks)
from torch._C._autograd import (_ProfilerResult, _KinetoEvent,
_prepare_profiler, _enable_profiler, _disable_profiler)

View File

@ -58,21 +58,21 @@ class saved_tensors_hooks():
to undefined behavior.
.. warning ::
Only one pair of hooks is allowed at a time. Recursively nesting this
context-manager is not yet supported.
Only one pair of hooks is allowed at a time. When recursively nesting this
context-manager, only the inner-most pair of hooks will be applied.
"""
def __init__(self, pack_hook: Callable[[torch.Tensor], Any], unpack_hook: Callable[[Any], torch.Tensor]):
self.pack_hook = pack_hook
self.unpack_hook = unpack_hook
def __enter__(self):
torch._C._autograd._register_saved_tensors_default_hooks(self.pack_hook, self.unpack_hook)
torch._C._autograd._push_saved_tensors_default_hooks(self.pack_hook, self.unpack_hook)
def __exit__(self, *args: Any):
torch._C._autograd._reset_saved_tensors_default_hooks()
torch._C._autograd._pop_saved_tensors_default_hooks()
class save_on_cpu():
class save_on_cpu(saved_tensors_hooks):
"""Context-manager under which tensors saved by the forward pass will be
stored on cpu, then retrieved for backward.
@ -129,11 +129,4 @@ class save_on_cpu():
device, tensor = packed
return tensor.to(device, non_blocking=pin_memory)
self.pack_hook = pack_to_cpu
self.unpack_hook = unpack_from_cpu
def __enter__(self):
torch._C._autograd._register_saved_tensors_default_hooks(self.pack_hook, self.unpack_hook)
def __exit__(self, *args: Any):
torch._C._autograd._reset_saved_tensors_default_hooks()
super().__init__(pack_to_cpu, unpack_from_cpu)
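
For context (not part of the diff): because save_on_cpu is now a subclass of
saved_tensors_hooks, user code can define its own default-hooks
context-manager the same way. A minimal, hypothetical sketch (the
scale_saved_tensors class and its factor argument are made up for
illustration):

    import torch

    class scale_saved_tensors(torch.autograd.graph.saved_tensors_hooks):
        # Scale tensors when they are saved for backward and undo the
        # scaling when they are unpacked, so gradients are unaffected.
        def __init__(self, factor):
            super().__init__(lambda x: x * factor, lambda x: x / factor)

    with scale_saved_tensors(2.0):
        x = torch.randn(3, requires_grad=True)
        (x * x).sum().backward()  # x.grad == 2 * x, as without hooks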

View File

@ -303,11 +303,11 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
m.def("_clear_callbacks", []() {
at::clearCallbacks();
});
m.def("_register_saved_tensors_default_hooks", [](py::function &pack_hook, py::function &unpack_hook) {
torch::autograd::PyDefaultSavedVariableHooks::set_hooks(pack_hook, unpack_hook);
m.def("_push_saved_tensors_default_hooks", [](py::function &pack_hook, py::function &unpack_hook) {
torch::autograd::PyDefaultSavedVariableHooks::push_hooks(pack_hook, unpack_hook);
});
m.def("_reset_saved_tensors_default_hooks", []() {
torch::autograd::PyDefaultSavedVariableHooks::reset_hooks();
m.def("_pop_saved_tensors_default_hooks", []() {
torch::autograd::PyDefaultSavedVariableHooks::pop_hooks();
});
_C_m.def("_register_py_class_for_device", [](const std::string& device, py::object python_type_class) {

View File

@ -46,25 +46,21 @@ namespace torch { namespace autograd {
}
}
void PyDefaultSavedVariableHooks::set_hooks(py::function &pack_hook, py::function &unpack_hook) {
PyObject *pack_hook_(nullptr), *unpack_hook_(nullptr);
std::tie(pack_hook_, unpack_hook_) = at::SavedTensorDefaultHooks::get_hooks();
TORCH_CHECK(!pack_hook_ && !unpack_hook_,
"Setting default hooks but they have already been set. "
"Hint: only one pair of hooks is allowed at a time.");
void PyDefaultSavedVariableHooks::push_hooks(py::function &pack_hook, py::function &unpack_hook) {
at::SavedTensorDefaultHooks::enable();
at::SavedTensorDefaultHooks::set_hooks(pack_hook.release().ptr(), unpack_hook.release().ptr());
at::SavedTensorDefaultHooks::push_hooks(pack_hook.release().ptr(), unpack_hook.release().ptr());
}
void PyDefaultSavedVariableHooks::reset_hooks() {
void PyDefaultSavedVariableHooks::pop_hooks() {
PyObject *pack_hook(nullptr), *unpack_hook(nullptr);
std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks();
TORCH_INTERNAL_ASSERT(pack_hook != nullptr && unpack_hook != nullptr);
if (Py_IsInitialized()) {
py::gil_scoped_acquire gil;
Py_XDECREF(pack_hook);
Py_XDECREF(unpack_hook);
}
at::SavedTensorDefaultHooks::set_hooks(nullptr, nullptr);
at::SavedTensorDefaultHooks::pop_hooks();
}
std::unique_ptr<SavedVariableHooks> PyDefaultSavedVariableHooks::get_hooks() {

View File

@ -24,8 +24,8 @@ private:
};
struct PyDefaultSavedVariableHooks {
static void set_hooks(py::function &pack_hook, py::function &unpack_hook);
static void reset_hooks();
static void push_hooks(py::function &pack_hook, py::function &unpack_hook);
static void pop_hooks();
static std::unique_ptr<SavedVariableHooks> get_hooks();
};