Enable nested default hooks (#70932)
Summary:
When default hooks are set, they are pushed onto a stack. When nesting the context manager, only the inner-most pair of hooks is applied. Special care is needed to update the TLS code. See also https://github.com/pytorch/pytorch/issues/70940 (i.e. do we need to be storing the enabled flag as well?)

Fixes https://github.com/pytorch/pytorch/issues/70134

Pull Request resolved: https://github.com/pytorch/pytorch/pull/70932

Reviewed By: mruberry

Differential Revision: D33530370

Pulled By: albanD

fbshipit-source-id: 3197d585d77563f36c175d3949115a0776b309f4
parent 433cf44b79
commit a3b7dd7b78
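A rough usage sketch of the behavior this enables (closely mirroring the new test_setting_default_saved_variable_hooks_twice_should_use_inner added below; the multipliers are arbitrary): tensors saved while the inner context manager is active go through the inner-most pair of hooks, and the outer pair is restored once the inner block exits.

    import torch

    with torch.autograd.graph.saved_tensors_hooks(lambda x: 3 * x, lambda x: 3 * x):
        b = torch.randn(5, requires_grad=True)
        with torch.autograd.graph.saved_tensors_hooks(lambda x: 5 * x, lambda x: 5 * x):
            a = torch.randn(5, requires_grad=True)
            y = a * a  # `a` is saved through the inner (5x) hooks
        z = b * b      # `b` is saved through the outer (3x) hooks

    y.sum().backward()
    z.sum().backward()
    # d(sum(a*a))/da is 2*a, but each saved `a` was packed as 5*a and unpacked as 5*(5*a)
    assert torch.allclose(a.grad, 2 * 5 * 5 * a)
    assert torch.allclose(b.grad, 2 * 3 * 3 * b)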
@@ -1,13 +1,12 @@
#include <ATen/SavedTensorHooks.h>
#include <c10/util/Exception.h>
#include <stack>

namespace at {

namespace {
// PyObject is defined in c10/util/python_stub.h
// Reference counting is handled by the caller of `set_hooks`.
thread_local PyObject* pack_hook_(nullptr);
thread_local PyObject* unpack_hook_(nullptr);
thread_local std::stack<std::pair<PyObject*, PyObject*>> stack;

// This flag is set to true the first time default hooks are registered
// and left at true for the rest of the execution.
@@ -20,20 +19,32 @@ void SavedTensorDefaultHooks::enable() {
is_enabled = true;
}

void SavedTensorDefaultHooks::set_hooks(PyObject* pack_hook, PyObject* unpack_hook) {
if (!is_enabled) {
TORCH_INTERNAL_ASSERT(pack_hook == nullptr && unpack_hook == nullptr);
return;
void SavedTensorDefaultHooks::push_hooks(PyObject* pack_hook, PyObject* unpack_hook) {
// Reference counting is handled by the caller of `push_hooks`
TORCH_INTERNAL_ASSERT(is_enabled);
TORCH_INTERNAL_ASSERT(pack_hook != nullptr && unpack_hook != nullptr);
stack.push(std::make_pair(pack_hook, unpack_hook));
}
pack_hook_ = pack_hook;
unpack_hook_ = unpack_hook;

void SavedTensorDefaultHooks::pop_hooks() {
// Reference counting is handled by the caller of `pop_hooks`
TORCH_INTERNAL_ASSERT(is_enabled && !stack.empty());
stack.pop();
}

std::pair<PyObject*, PyObject*> SavedTensorDefaultHooks::get_hooks() {
if (!is_enabled) {
if (!is_enabled || stack.empty()) {
return std::make_pair(nullptr, nullptr);
}
return std::make_pair(pack_hook_, unpack_hook_);
return stack.top();
}

std::stack<std::pair<PyObject*, PyObject*>> SavedTensorDefaultHooks::get_stack() {
return stack;
}

void SavedTensorDefaultHooks::set_stack(std::stack<std::pair<PyObject*, PyObject*>> stack_) {
stack = stack_;
}

}
@@ -2,15 +2,19 @@
#include <c10/macros/Export.h>
#include <c10/util/python_stub.h>
#include <stack>

#include <utility>

namespace at {

struct TORCH_API SavedTensorDefaultHooks {
static void set_hooks(PyObject* pack_hook, PyObject* unpack_hook);
static void push_hooks(PyObject* pack_hook, PyObject* unpack_hook);
static void pop_hooks();
static std::pair<PyObject*, PyObject*> get_hooks();
static void enable();
static std::stack<std::pair<PyObject*, PyObject*>> get_stack();
static void set_stack(std::stack<std::pair<PyObject*, PyObject*>>);
};

} // namespace at
@@ -15,7 +15,8 @@ ThreadLocalState::ThreadLocalState()
functorch_tls_(functorch::getCopyOfFuncTorchTLS()),
autograd_tls_(c10::AutogradState::get_tls_state()) {
rf_tls_ = at::get_record_function_tls_();
saved_tensors_default_hooks_ = SavedTensorDefaultHooks::get_hooks();
saved_tensors_default_hooks_ = at::SavedTensorDefaultHooks::get_stack();

bumped_record_all_functions_ = at::checkRecordAllFunctions();
python_mode_state_ = at::impl::PythonModeTLS::get_state();
@@ -36,9 +37,7 @@ void ThreadLocalState::setThreadLocalState(
at::set_record_function_tls_(state.rf_tls_);

SavedTensorDefaultHooks::set_hooks(
state.saved_tensors_default_hooks_.first,
state.saved_tensors_default_hooks_.second);
at::SavedTensorDefaultHooks::set_stack(state.saved_tensors_default_hooks_);

c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_);
@@ -1,5 +1,7 @@
#pragma once

#include <stack>

#include <c10/core/InferenceMode.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/util/Exception.h>
@@ -54,7 +56,7 @@ class TORCH_API ThreadLocalState {
std::shared_ptr<TorchDispatchTypeObject> python_mode_state_;

// TLS for saved tensors default hooks
std::pair<PyObject*, PyObject*> saved_tensors_default_hooks_;
std::stack<std::pair<PyObject*, PyObject*>> saved_tensors_default_hooks_;

// Whether pre-sampling RecordFunction optimization was enabled
bool bumped_record_all_functions_ = false;
@@ -5906,12 +5906,52 @@ for shape in [(1,), ()]:
            y.sum().backward()
            self.assertEqual(a.grad, y)

    def test_setting_default_saved_variable_hooks_twice_should_fail(self):
        with self.assertRaisesRegex(RuntimeError, "Setting default hooks but they have already been set. "):
    def test_setting_default_saved_variable_hooks_twice_should_not_fail(self):
        with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
            with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
                pass

    def test_setting_default_saved_variable_hooks_twice_should_use_inner(self):
        with torch.autograd.graph.saved_tensors_hooks(lambda x: 3 * x, lambda x: 3 * x):
            b = torch.randn(5, requires_grad=True)
            with torch.autograd.graph.saved_tensors_hooks(lambda x: 5 * x, lambda x: 5 * x):
                a = torch.randn(5, requires_grad=True)
                y = a * a
            z = b * b
        y.sum().backward()
        z.sum().backward()
        self.assertEqual(2 * 5 * 5 * a, a.grad)
        self.assertEqual(2 * 3 * 3 * b, b.grad)

    def test_save_on_cpu_and_checkpoint(self):
        a = torch.randn(2, 2, requires_grad=True)

        b = a.pow(2).pow(2).pow(2).pow(2)
        b.sum().backward()
        b_grad = a.grad.clone()
        a.grad.zero_()

        with torch.autograd.graph.save_on_cpu():
            h = a.pow(2)
            h = checkpoint(lambda x: x.pow(2).pow(2), h, use_reentrant=False)
            c = h.pow(2)
        c.sum().backward()
        c_grad = a.grad.clone()
        a.grad.zero_()

        def f(a):
            h = a.pow(2)
            with torch.autograd.graph.save_on_cpu():
                h = h.pow(2).pow(2)
            return h.pow(2)

        d = checkpoint(f, a, use_reentrant=False)
        d.sum().backward()
        d_grad = a.grad.clone()

        self.assertEqual(b_grad, c_grad)
        self.assertEqual(b_grad, d_grad)

    def test_pack_hook_with_inplace_modification_should_fail(self):
        a = torch.randn(5, requires_grad=True)
@@ -93,8 +93,8 @@ def _record_function_with_args_exit(handle: torch.Tensor) -> None: ...
def _supported_activities() -> Set[ProfilerActivity]: ...
def _enable_record_function(enable: bool) -> None: ...
def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...
def _register_saved_tensors_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
def _reset_saved_tensors_default_hooks() -> None: ...
def _push_saved_tensors_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
def _pop_saved_tensors_default_hooks() -> None: ...

def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
def _disable_profiler_legacy() -> List[List[ProfilerEvent]]: ...
@@ -307,7 +307,7 @@ from torch._C._autograd import (DeviceType, ProfilerActivity, ProfilerState, Pro
    _enable_record_function, _set_empty_test_observer, kineto_available,
    _record_function_with_args_enter, _record_function_with_args_exit,
    _supported_activities, _add_metadata_json, SavedTensor,
    _register_saved_tensors_default_hooks, _reset_saved_tensors_default_hooks)
    _push_saved_tensors_default_hooks, _pop_saved_tensors_default_hooks)

from torch._C._autograd import (_ProfilerResult, _KinetoEvent,
    _prepare_profiler, _enable_profiler, _disable_profiler)
@@ -58,21 +58,21 @@ class saved_tensors_hooks():
    to undefined behavior.

    .. warning ::
        Only one pair of hooks is allowed at a time. Recursively nesting this
        context-manager is not yet supported.
        Only one pair of hooks is allowed at a time. When recursively nesting this
        context-manager, only the inner-most pair of hooks will be applied.
    """
    def __init__(self, pack_hook: Callable[[torch.Tensor], Any], unpack_hook: Callable[[Any], torch.Tensor]):
        self.pack_hook = pack_hook
        self.unpack_hook = unpack_hook

    def __enter__(self):
        torch._C._autograd._register_saved_tensors_default_hooks(self.pack_hook, self.unpack_hook)
        torch._C._autograd._push_saved_tensors_default_hooks(self.pack_hook, self.unpack_hook)

    def __exit__(self, *args: Any):
        torch._C._autograd._reset_saved_tensors_default_hooks()
        torch._C._autograd._pop_saved_tensors_default_hooks()


class save_on_cpu():
class save_on_cpu(saved_tensors_hooks):
    """Context-manager under which tensors saved by the forward pass will be
    stored on cpu, then retrieved for backward.
@@ -129,11 +129,4 @@ class save_on_cpu():
            device, tensor = packed
            return tensor.to(device, non_blocking=pin_memory)

        self.pack_hook = pack_to_cpu
        self.unpack_hook = unpack_from_cpu

    def __enter__(self):
        torch._C._autograd._register_saved_tensors_default_hooks(self.pack_hook, self.unpack_hook)

    def __exit__(self, *args: Any):
        torch._C._autograd._reset_saved_tensors_default_hooks()
        super().__init__(pack_to_cpu, unpack_from_cpu)
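With save_on_cpu now expressed as a subclass of saved_tensors_hooks, a user-defined default-hooks context manager can follow the same pattern: build the pack/unpack pair in __init__ and hand it to the base class. A minimal illustrative sketch (the logging_hooks class below is hypothetical and not part of this change):

    import torch

    class logging_hooks(torch.autograd.graph.saved_tensors_hooks):
        # Print a line whenever autograd packs or unpacks a saved tensor.
        def __init__(self):
            def pack(tensor):
                print(f"packing tensor of shape {tuple(tensor.shape)}")
                return tensor

            def unpack(tensor):
                print(f"unpacking tensor of shape {tuple(tensor.shape)}")
                return tensor

            super().__init__(pack, unpack)

    a = torch.randn(4, requires_grad=True)
    with logging_hooks():
        out = (a * a).sum()  # pack runs here for each saved tensor
    out.backward()           # unpack runs when backward uses the saved tensors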
@@ -303,11 +303,11 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
m.def("_clear_callbacks", []() {
at::clearCallbacks();
});
m.def("_register_saved_tensors_default_hooks", [](py::function &pack_hook, py::function &unpack_hook) {
torch::autograd::PyDefaultSavedVariableHooks::set_hooks(pack_hook, unpack_hook);
m.def("_push_saved_tensors_default_hooks", [](py::function &pack_hook, py::function &unpack_hook) {
torch::autograd::PyDefaultSavedVariableHooks::push_hooks(pack_hook, unpack_hook);
});
m.def("_reset_saved_tensors_default_hooks", []() {
torch::autograd::PyDefaultSavedVariableHooks::reset_hooks();
m.def("_pop_saved_tensors_default_hooks", []() {
torch::autograd::PyDefaultSavedVariableHooks::pop_hooks();
});

_C_m.def("_register_py_class_for_device", [](const std::string& device, py::object python_type_class) {
@@ -46,25 +46,21 @@ namespace torch { namespace autograd {
}
}

void PyDefaultSavedVariableHooks::set_hooks(py::function &pack_hook, py::function &unpack_hook) {
PyObject *pack_hook_(nullptr), *unpack_hook_(nullptr);
std::tie(pack_hook_, unpack_hook_) = at::SavedTensorDefaultHooks::get_hooks();
TORCH_CHECK(!pack_hook_ && !unpack_hook_,
"Setting default hooks but they have already been set. "
"Hint: only one pair of hooks is allowed at a time.");
void PyDefaultSavedVariableHooks::push_hooks(py::function &pack_hook, py::function &unpack_hook) {
at::SavedTensorDefaultHooks::enable();
at::SavedTensorDefaultHooks::set_hooks(pack_hook.release().ptr(), unpack_hook.release().ptr());
at::SavedTensorDefaultHooks::push_hooks(pack_hook.release().ptr(), unpack_hook.release().ptr());
}

void PyDefaultSavedVariableHooks::reset_hooks() {
void PyDefaultSavedVariableHooks::pop_hooks() {
PyObject *pack_hook(nullptr), *unpack_hook(nullptr);
std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks();
TORCH_INTERNAL_ASSERT(pack_hook != nullptr && unpack_hook != nullptr);
if (Py_IsInitialized()) {
py::gil_scoped_acquire gil;
Py_XDECREF(pack_hook);
Py_XDECREF(unpack_hook);
}
at::SavedTensorDefaultHooks::set_hooks(nullptr, nullptr);
at::SavedTensorDefaultHooks::pop_hooks();
}

std::unique_ptr<SavedVariableHooks> PyDefaultSavedVariableHooks::get_hooks() {
@@ -24,8 +24,8 @@ private:
};

struct PyDefaultSavedVariableHooks {
static void set_hooks(py::function &pack_hook, py::function &unpack_hook);
static void reset_hooks();
static void push_hooks(py::function &pack_hook, py::function &unpack_hook);
static void pop_hooks();
static std::unique_ptr<SavedVariableHooks> get_hooks();
};