Update SavedTensorHooks TLS stack to use SafePyObject (#131700)

Previously, we had to manually manage refcounting when updating the TLS saved-variable hook stack. With this PR, reference counting is handled automatically by SafePyObject.
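For illustration only (not part of this commit): a minimal sketch of the ownership model SafePyObject provides, assuming the post-PR API shown in the diffs below (constructor taking an owned PyObject* plus interpreter, move-based push/pop, GIL-safe decref in the destructor):

#include <ATen/SavedTensorHooks.h>
#include <c10/core/SafePyObject.h>
#include <torch/csrc/PyInterpreter.h>
#include <utility>

// `raw_pack` and `raw_unpack` are owned references (e.g. from py::object::release()).
void sketch(PyObject* raw_pack, PyObject* raw_unpack) {
  at::SavedTensorDefaultHooks::lazy_initialize();
  // Wrapping transfers ownership: each SafePyObject will decref on destruction.
  c10::SafePyObject pack(raw_pack, getPyInterpreter());
  c10::SafePyObject unpack(raw_unpack, getPyInterpreter());
  // Moving into the TLS stack hands ownership over without touching refcounts.
  at::SavedTensorDefaultHooks::push_hooks(std::move(pack), std::move(unpack));
  // Popping moves ownership back out; when `hooks` goes out of scope, each
  // SafePyObject decrefs its PyObject* through the interpreter vtable,
  // acquiring the GIL as needed. No manual Py_DECREF anywhere.
  auto hooks = at::SavedTensorDefaultHooks::pop_hooks();
}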

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131700
Approved by: https://github.com/albanD
soulitzer 2024-08-02 09:24:49 -04:00 committed by PyTorch MergeBot
parent 9eeb5eebab
commit 82b6480b0a
9 changed files with 113 additions and 29 deletions

View File

@@ -2,6 +2,7 @@
 #include <c10/util/Exception.h>
 #include <stack>
 #include <utility>
+#include <c10/core/SafePyObject.h>
 
 namespace at {
@@ -57,26 +58,23 @@ void SavedTensorDefaultHooks::lazy_initialize() {
   is_initialized = true;
 }
 
-void SavedTensorDefaultHooks::push_hooks(PyObject* pack_hook, PyObject* unpack_hook) {
-  // Reference counting is handled by the caller of `push_hooks`
+void SavedTensorDefaultHooks::push_hooks(SafePyObject pack_hook, SafePyObject unpack_hook) {
   TORCH_INTERNAL_ASSERT(is_initialized);
-  TORCH_INTERNAL_ASSERT(pack_hook != nullptr && unpack_hook != nullptr);
   assertSavedTensorHooksNotDisabled();
-  tls.stack.emplace(pack_hook, unpack_hook);
+  tls.stack.emplace(std::move(pack_hook), std::move(unpack_hook));
 }
 
-std::pair<PyObject*, PyObject*> SavedTensorDefaultHooks::pop_hooks() {
-  // Reference counting is handled by the caller of `pop_hooks`
+std::pair<SafePyObject, SafePyObject> SavedTensorDefaultHooks::pop_hooks() {
   TORCH_INTERNAL_ASSERT(is_initialized && !tls.stack.empty());
-  std::pair<PyObject*, PyObject*> hooks = tls.stack.top();
+  std::pair<SafePyObject, SafePyObject> hooks = std::move(tls.stack.top());
   tls.stack.pop();
   return hooks;
 }
 
-std::pair<PyObject*, PyObject*> SavedTensorDefaultHooks::get_hooks() {
+c10::optional<std::pair<SafePyObject, SafePyObject>> SavedTensorDefaultHooks::get_hooks() {
   // For tls.is_tracing, see NOTE: [Deferring tensor pack/unpack hooks until runtime]
   if (!is_initialized || tls.stack.empty() || tls.is_tracing) {
-    return std::make_pair(nullptr, nullptr);
+    return c10::nullopt;
   }
   return tls.stack.top();
 }
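Aside, not from the diff: a hypothetical caller of the new optional-returning get_hooks(). Here `interp` is a stand-in for the c10::impl::PyInterpreter* the hooks were created with:

if (auto hooks = at::SavedTensorDefaultHooks::get_hooks()) {
  auto& [pack_hook, unpack_hook] = *hooks;
  // Borrow the raw pointer without giving up ownership.
  PyObject* raw_pack = pack_hook.ptr(interp);
  // ... invoke raw_pack under the GIL ...
} else {
  // Uninitialized, empty stack, or tracing: no default hooks to run.
}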

View File

@@ -1,5 +1,6 @@
 #pragma once
 
+#include <c10/core/SafePyObject.h>
 #include <c10/macros/Export.h>
 #include <c10/util/python_stub.h>
 #include <optional>
@@ -14,7 +15,7 @@ namespace impl {
 
 struct TORCH_API SavedTensorDefaultHooksTLS {
   // PyObject is defined in c10/util/python_stub.h
-  std::stack<std::pair<PyObject*, PyObject*>> stack;
+  std::stack<std::pair<c10::SafePyObject, c10::SafePyObject>> stack;
 
   // See NOTE: [Disabling SavedTensorDefaultHooks] for context
   // NOTE: [disabled_error_message invariant]
@@ -30,9 +31,12 @@ struct TORCH_API SavedTensorDefaultHooksTLS {
 
 } // namespace impl
 
 struct TORCH_API SavedTensorDefaultHooks {
-  static void push_hooks(PyObject* pack_hook, PyObject* unpack_hook);
-  static std::pair<PyObject*, PyObject*> pop_hooks();
-  static std::pair<PyObject*, PyObject*> get_hooks();
+  static void push_hooks(
+      c10::SafePyObject pack_hook,
+      c10::SafePyObject unpack_hook);
+  static std::pair<c10::SafePyObject, c10::SafePyObject> pop_hooks();
+  static std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
+  get_hooks();
   static void lazy_initialize();
   static const impl::SavedTensorDefaultHooksTLS& get_tls_state();

View File

@@ -29,10 +29,27 @@ struct C10_API SafePyObject {
   // For now it's not used, so we just disallow it.
   SafePyObject& operator=(SafePyObject&&) = delete;
 
-  // In principle this could be copyable if we add an incref to PyInterpreter
-  // but for now it's easier to just disallow it.
-  SafePyObject(SafePyObject const&) = delete;
-  SafePyObject& operator=(SafePyObject const&) = delete;
+  SafePyObject(SafePyObject const& other)
+      : data_(other.data_), pyinterpreter_(other.pyinterpreter_) {
+    if (data_ != nullptr) {
+      (*pyinterpreter_)->incref(data_);
+    }
+  }
+
+  SafePyObject& operator=(SafePyObject const& other) {
+    if (this == &other) {
+      return *this; // Handle self-assignment
+    }
+    if (other.data_ != nullptr) {
+      (*other.pyinterpreter_)->incref(other.data_);
+    }
+    if (data_ != nullptr) {
+      (*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
+    }
+    data_ = other.data_;
+    pyinterpreter_ = other.pyinterpreter_;
+    return *this;
+  }
 
   ~SafePyObject() {
     if (data_ != nullptr) {
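Illustrative aside (not from the diff): what the new copy operations mean for refcounts, with `obj` and `other_obj` as hypothetical owned references and `interp` the current interpreter:

c10::SafePyObject a(obj, interp);       // takes over the reference to obj
c10::SafePyObject b(a);                 // copy ctor: increfs obj via the vtable
c10::SafePyObject c(other_obj, interp);
c = a;                                  // copy assign: increfs obj, decrefs other_obj
// a, b, and c each decref obj on destruction; the last drop releases it.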

View File

@@ -9,6 +9,8 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
     return "<unloaded interpreter>";
   }
 
+  void incref(PyObject* pyobj) const override {} // do nothing
+
   void decref(PyObject* pyobj, bool has_pyobj_slot) const override {
   } // do nothing

View File

@@ -124,6 +124,8 @@ struct C10_API PyInterpreterVTable {
   // Report the name of this interpreter
   virtual std::string name() const = 0;
 
+  // Run Py_INCREF on a PyObject.
+  virtual void incref(PyObject* pyobj) const = 0;
   // Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call
   // See NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
   virtual void decref(PyObject* pyobj, bool has_pyobj_slot) const = 0;

View File

@ -2195,7 +2195,7 @@ def wrap_test_class(orig_cls):
dct = orig_cls.__dict__.copy()
for name in list(dct.keys()):
fn = dct[name]
if not callable(fn):
if not callable(fn) or name in skipped_tests:
continue
elif known_failures_re.match(name) or name in known_failing_tests:
dct[name] = unittest.expectedFailure
@@ -2231,6 +2231,13 @@
 )
 
 # Bugs needing investigation:
+skipped_tests = {
+    # These tests check that unconventional usage of saved tensor hooks does not leak or crash
+    # Running these tests succeeds, but somehow causes other tests to fail
+    "test_saved_tensor_hooks_extra_exit_during_bw_no_crash",
+    "test_saved_tensor_hooks_extra_enter_during_bw_no_leak",
+}
+
 known_failing_tests = {
     "test_current_graph_task_execution_order",  # torch._dynamo.exc.TorchRuntimeError: Failed running call_function <
     "test_input_buffer_accum",  # RuntimeError: Cannot access data pointer of Tensor that doesn't have storage

View File

@@ -555,6 +555,49 @@ class TestAutograd(TestCase):
         # if forward AD ends up being implemented for torch.igamma, choose a different op
         torch.igamma(dual_x, dual_x)
 
+    def test_saved_tensor_hooks_extra_exit_during_bw_no_crash(self):
+        # This usage of saved tensor hooks is not supported, but should not crash
+        def unpack(x):
+            ctx_1.__exit__()
+            return x
+
+        ctx_1 = torch.autograd.graph.saved_tensors_hooks(lambda x: x, unpack)
+        ctx_2 = torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x)
+
+        for i in range(10):
+            with ctx_2:
+                ctx_1.__enter__()
+                x = torch.randn(3, 3, requires_grad=True)
+                x.sin().sum().backward()
+
+        # Clean up
+        for i in range(10):
+            ctx_1.__exit__()
+
+        # Validate that there are no more hooks on the stack
+        a = torch.tensor(1.0, requires_grad=True)
+        y = a.exp()
+        y.grad_fn._raw_saved_result.register_hooks(lambda x: x, lambda x: x)
+
+    def test_saved_tensor_hooks_extra_enter_during_bw_no_leak(self):
+        # This usage of saved tensor hooks is not supported, but should not leak
+        def scope():
+            def unpack(x):
+                weak_ctx_1().__enter__()
+                return x
+
+            ctx_1 = torch.autograd.graph.saved_tensors_hooks(lambda x: x, unpack)
+            weak_ctx_1 = weakref.ref(ctx_1)
+
+            x = torch.randn(3, 3, requires_grad=True)
+            with ctx_1:
+                x.sin().sum().backward()
+
+            return weakref.ref(unpack)
+
+        with disable_gc():
+            unpack_hook_ref = scope()
+        self.assertIsNone(unpack_hook_ref())
+
     def test_will_engine_execute_node(self):
         counter = [0]

View File

@@ -44,6 +44,7 @@ struct ConcretePyInterpreterVTable final
     : public c10::impl::PyInterpreterVTable {
   std::string name() const override;
 
+  void incref(PyObject* pyobj) const override;
   void decref(PyObject* pyobj, bool has_pyobj_slot) const override;
 
   // TODO: Need to make this work for StorageImpl too. I imagine I'll want to
@@ -275,6 +276,13 @@ void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
   Py_DECREF(pyobj);
 };
 
+void ConcretePyInterpreterVTable::incref(PyObject* pyobj) const {
+  if (!Py_IsInitialized())
+    return;
+  pybind11::gil_scoped_acquire gil;
+  Py_INCREF(pyobj);
+};
+
 bool isPythonTensor(const at::Tensor& tensor) {
   return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
 }

View File

@@ -1,6 +1,8 @@
 #include <ATen/SavedTensorHooks.h>
 #include <torch/csrc/autograd/python_saved_variable_hooks.h>
 
+#include <c10/core/SafePyObject.h>
+#include <torch/csrc/PyInterpreter.h>
 #include <torch/csrc/THP.h>
 
 namespace py = pybind11;
@@ -60,27 +62,28 @@ void PyDefaultSavedVariableHooks::push_hooks(
     py::function& unpack_hook) {
   at::SavedTensorDefaultHooks::lazy_initialize();
   at::SavedTensorDefaultHooks::push_hooks(
-      pack_hook.release().ptr(), unpack_hook.release().ptr());
+      c10::SafePyObject(pack_hook.release().ptr(), getPyInterpreter()),
+      c10::SafePyObject(unpack_hook.release().ptr(), getPyInterpreter()));
 }
 
 void PyDefaultSavedVariableHooks::pop_hooks() {
   auto [pack_hook, unpack_hook] = at::SavedTensorDefaultHooks::pop_hooks();
-  TORCH_INTERNAL_ASSERT(pack_hook != nullptr && unpack_hook != nullptr);
-  if (Py_IsInitialized()) {
-    py::gil_scoped_acquire gil;
-    Py_XDECREF(pack_hook);
-    Py_XDECREF(unpack_hook);
-  }
+  TORCH_INTERNAL_ASSERT(
+      pack_hook.ptr(getPyInterpreter()) != nullptr &&
+      unpack_hook.ptr(getPyInterpreter()) != nullptr);
 }
 
 std::unique_ptr<SavedVariableHooks> PyDefaultSavedVariableHooks::get_hooks() {
-  auto [pack_hook, unpack_hook] = at::SavedTensorDefaultHooks::get_hooks();
-  if (!pack_hook || !unpack_hook) {
+  auto out = at::SavedTensorDefaultHooks::get_hooks();
+  if (!out.has_value()) {
     return nullptr;
   }
+  auto [pack_hook, unpack_hook] = *out;
   py::gil_scoped_acquire gil;
-  py::function pack_hook_ = py::reinterpret_borrow<py::function>(pack_hook);
-  py::function unpack_hook_ = py::reinterpret_borrow<py::function>(unpack_hook);
+  py::function pack_hook_ =
+      py::reinterpret_steal<py::function>(pack_hook.release());
+  py::function unpack_hook_ =
+      py::reinterpret_steal<py::function>(unpack_hook.release());
   return std::make_unique<PySavedVariableHooks>(pack_hook_, unpack_hook_);
 }
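Side note on the borrow-to-steal switch above, illustrative only: py::reinterpret_borrow increments the refcount and leaves the existing reference with its current owner, while py::reinterpret_steal adopts the reference as-is. Since SafePyObject::release() gives up ownership of its reference, stealing is the matching operation:

// `pack_hook` is a c10::SafePyObject as in get_hooks() above.
py::function f = py::reinterpret_steal<py::function>(pack_hook.release());
// With reinterpret_borrow, the reference released by SafePyObject would never
// be decref'd, leaking the hook; with reinterpret_steal, `f` adopts it and
// pybind11 decrefs it when `f` is destroyed.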