mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Update SavedTensorHooks TLS stack to use SafePyObject (#131700)
Previously, we must manually manage refcounting when updating the TLS saved variable stack. With this PR, things should be handled automatically by the SafePyObject. Pull Request resolved: https://github.com/pytorch/pytorch/pull/131700 Approved by: https://github.com/albanD
This commit is contained in:
parent
9eeb5eebab
commit
82b6480b0a
|
|
@ -2,6 +2,7 @@
|
|||
#include <c10/util/Exception.h>
|
||||
#include <stack>
|
||||
#include <utility>
|
||||
#include <c10/core/SafePyObject.h>
|
||||
|
||||
namespace at {
|
||||
|
||||
|
|
@ -57,26 +58,23 @@ void SavedTensorDefaultHooks::lazy_initialize() {
|
|||
is_initialized = true;
|
||||
}
|
||||
|
||||
void SavedTensorDefaultHooks::push_hooks(PyObject* pack_hook, PyObject* unpack_hook) {
|
||||
// Reference counting is handled by the caller of `push_hooks`
|
||||
void SavedTensorDefaultHooks::push_hooks(SafePyObject pack_hook, SafePyObject unpack_hook) {
|
||||
TORCH_INTERNAL_ASSERT(is_initialized);
|
||||
TORCH_INTERNAL_ASSERT(pack_hook != nullptr && unpack_hook != nullptr);
|
||||
assertSavedTensorHooksNotDisabled();
|
||||
tls.stack.emplace(pack_hook, unpack_hook);
|
||||
tls.stack.emplace(std::move(pack_hook), std::move(unpack_hook));
|
||||
}
|
||||
|
||||
std::pair<PyObject*, PyObject*> SavedTensorDefaultHooks::pop_hooks() {
|
||||
// Reference counting is handled by the caller of `pop_hooks`
|
||||
std::pair<SafePyObject, SafePyObject> SavedTensorDefaultHooks::pop_hooks() {
|
||||
TORCH_INTERNAL_ASSERT(is_initialized && !tls.stack.empty());
|
||||
std::pair<PyObject*, PyObject*> hooks = tls.stack.top();
|
||||
std::pair<SafePyObject, SafePyObject> hooks = std::move(tls.stack.top());
|
||||
tls.stack.pop();
|
||||
return hooks;
|
||||
}
|
||||
|
||||
std::pair<PyObject*, PyObject*> SavedTensorDefaultHooks::get_hooks() {
|
||||
c10::optional<std::pair<SafePyObject, SafePyObject>> SavedTensorDefaultHooks::get_hooks() {
|
||||
// For tls.is_tracing, see NOTE: [Deferring tensor pack/unpack hooks until runtime]
|
||||
if (!is_initialized || tls.stack.empty() || tls.is_tracing) {
|
||||
return std::make_pair(nullptr, nullptr);
|
||||
return c10::nullopt;
|
||||
}
|
||||
return tls.stack.top();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
#pragma once
|
||||
|
||||
#include <c10/core/SafePyObject.h>
|
||||
#include <c10/macros/Export.h>
|
||||
#include <c10/util/python_stub.h>
|
||||
#include <optional>
|
||||
|
|
@ -14,7 +15,7 @@ namespace impl {
|
|||
|
||||
struct TORCH_API SavedTensorDefaultHooksTLS {
|
||||
// PyObject is defined in c10/util/python_stub.h
|
||||
std::stack<std::pair<PyObject*, PyObject*>> stack;
|
||||
std::stack<std::pair<c10::SafePyObject, c10::SafePyObject>> stack;
|
||||
|
||||
// See NOTE: [Disabling SavedTensorDefaultHooks] for context
|
||||
// NOTE: [disabled_error_message invariant]
|
||||
|
|
@ -30,9 +31,12 @@ struct TORCH_API SavedTensorDefaultHooksTLS {
|
|||
} // namespace impl
|
||||
|
||||
struct TORCH_API SavedTensorDefaultHooks {
|
||||
static void push_hooks(PyObject* pack_hook, PyObject* unpack_hook);
|
||||
static std::pair<PyObject*, PyObject*> pop_hooks();
|
||||
static std::pair<PyObject*, PyObject*> get_hooks();
|
||||
static void push_hooks(
|
||||
c10::SafePyObject pack_hook,
|
||||
c10::SafePyObject unpack_hook);
|
||||
static std::pair<c10::SafePyObject, c10::SafePyObject> pop_hooks();
|
||||
static std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
|
||||
get_hooks();
|
||||
static void lazy_initialize();
|
||||
|
||||
static const impl::SavedTensorDefaultHooksTLS& get_tls_state();
|
||||
|
|
|
|||
|
|
@ -29,10 +29,27 @@ struct C10_API SafePyObject {
|
|||
// For now it's not used, so we just disallow it.
|
||||
SafePyObject& operator=(SafePyObject&&) = delete;
|
||||
|
||||
// In principle this could be copyable if we add an incref to PyInterpreter
|
||||
// but for now it's easier to just disallow it.
|
||||
SafePyObject(SafePyObject const&) = delete;
|
||||
SafePyObject& operator=(SafePyObject const&) = delete;
|
||||
SafePyObject(SafePyObject const& other)
|
||||
: data_(other.data_), pyinterpreter_(other.pyinterpreter_) {
|
||||
if (data_ != nullptr) {
|
||||
(*pyinterpreter_)->incref(data_);
|
||||
}
|
||||
}
|
||||
|
||||
SafePyObject& operator=(SafePyObject const& other) {
|
||||
if (this == &other) {
|
||||
return *this; // Handle self-assignment
|
||||
}
|
||||
if (other.data_ != nullptr) {
|
||||
(*other.pyinterpreter_)->incref(other.data_);
|
||||
}
|
||||
if (data_ != nullptr) {
|
||||
(*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
|
||||
}
|
||||
data_ = other.data_;
|
||||
pyinterpreter_ = other.pyinterpreter_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
~SafePyObject() {
|
||||
if (data_ != nullptr) {
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
|
|||
return "<unloaded interpreter>";
|
||||
}
|
||||
|
||||
void incref(PyObject* pyobj) const override {} // do nothing
|
||||
|
||||
void decref(PyObject* pyobj, bool has_pyobj_slot) const override {
|
||||
} // do nothing
|
||||
|
||||
|
|
|
|||
|
|
@ -124,6 +124,8 @@ struct C10_API PyInterpreterVTable {
|
|||
// Report the name of this interpreter
|
||||
virtual std::string name() const = 0;
|
||||
|
||||
// Run Py_INCREF on a PyObject.
|
||||
virtual void incref(PyObject* pyobj) const = 0;
|
||||
// Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call
|
||||
// See NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
|
||||
virtual void decref(PyObject* pyobj, bool has_pyobj_slot) const = 0;
|
||||
|
|
|
|||
|
|
@ -2195,7 +2195,7 @@ def wrap_test_class(orig_cls):
|
|||
dct = orig_cls.__dict__.copy()
|
||||
for name in list(dct.keys()):
|
||||
fn = dct[name]
|
||||
if not callable(fn):
|
||||
if not callable(fn) or name in skipped_tests:
|
||||
continue
|
||||
elif known_failures_re.match(name) or name in known_failing_tests:
|
||||
dct[name] = unittest.expectedFailure
|
||||
|
|
@ -2231,6 +2231,13 @@ known_failures_re = re.compile(
|
|||
)
|
||||
|
||||
# Bugs needing investigation:
|
||||
skipped_tests = {
|
||||
# These test unconventional usage of saved tensor hooks do not leak or crash
|
||||
# Running these tests succeed, but somehow cause other tests to fail
|
||||
"test_saved_tensor_hooks_extra_exit_during_bw_no_crash",
|
||||
"test_saved_tensor_hooks_extra_enter_during_bw_no_leak",
|
||||
}
|
||||
|
||||
known_failing_tests = {
|
||||
"test_current_graph_task_execution_order", # torch._dynamo.exc.TorchRuntimeError: Failed running call_function <
|
||||
"test_input_buffer_accum", # RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
|
||||
|
|
|
|||
|
|
@ -555,6 +555,49 @@ class TestAutograd(TestCase):
|
|||
# if forward AD ends up being implemented for torch.igamma, choose a different op
|
||||
torch.igamma(dual_x, dual_x)
|
||||
|
||||
def test_saved_tensor_hooks_extra_exit_during_bw_no_crash(self):
|
||||
# This usage of saved tensor is not supported, but should not crash
|
||||
def unpack(x):
|
||||
ctx_1.__exit__()
|
||||
return x
|
||||
|
||||
ctx_1 = torch.autograd.graph.saved_tensors_hooks(lambda x: x, unpack)
|
||||
ctx_2 = torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x)
|
||||
|
||||
for i in range(10):
|
||||
with ctx_2:
|
||||
ctx_1.__enter__()
|
||||
x = torch.randn(3, 3, requires_grad=True)
|
||||
x.sin().sum().backward()
|
||||
|
||||
# Clean up
|
||||
for i in range(10):
|
||||
ctx_1.__exit__()
|
||||
|
||||
# Validate there are no more hooks on the stack
|
||||
a = torch.tensor(1.0, requires_grad=True)
|
||||
y = a.exp()
|
||||
y.grad_fn._raw_saved_result.register_hooks(lambda x: x, lambda x: x)
|
||||
|
||||
def test_saved_tensor_hooks_extra_enter_during_bw_no_leak(self):
|
||||
# This usage of saved tensor is not supported, but should not leak
|
||||
def scope():
|
||||
def unpack(x):
|
||||
weak_ctx_1().__enter__()
|
||||
return x
|
||||
|
||||
ctx_1 = torch.autograd.graph.saved_tensors_hooks(lambda x: x, unpack)
|
||||
weak_ctx_1 = weakref.ref(ctx_1)
|
||||
|
||||
x = torch.randn(3, 3, requires_grad=True)
|
||||
with ctx_1:
|
||||
x.sin().sum().backward()
|
||||
return weakref.ref(unpack)
|
||||
|
||||
with disable_gc():
|
||||
unpack_hook_ref = scope()
|
||||
self.assertIsNone(unpack_hook_ref())
|
||||
|
||||
def test_will_engine_execute_node(self):
|
||||
counter = [0]
|
||||
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ struct ConcretePyInterpreterVTable final
|
|||
: public c10::impl::PyInterpreterVTable {
|
||||
std::string name() const override;
|
||||
|
||||
void incref(PyObject* pyobj) const override;
|
||||
void decref(PyObject* pyobj, bool has_pyobj_slot) const override;
|
||||
|
||||
// TODO: Need to make this work for StorageImpl too. I imagine I'll want to
|
||||
|
|
@ -275,6 +276,13 @@ void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
|
|||
Py_DECREF(pyobj);
|
||||
};
|
||||
|
||||
void ConcretePyInterpreterVTable::incref(PyObject* pyobj) const {
|
||||
if (!Py_IsInitialized())
|
||||
return;
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
Py_INCREF(pyobj);
|
||||
};
|
||||
|
||||
bool isPythonTensor(const at::Tensor& tensor) {
|
||||
return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
#include <ATen/SavedTensorHooks.h>
|
||||
#include <torch/csrc/autograd/python_saved_variable_hooks.h>
|
||||
|
||||
#include <c10/core/SafePyObject.h>
|
||||
#include <torch/csrc/PyInterpreter.h>
|
||||
#include <torch/csrc/THP.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
|
@ -60,27 +62,28 @@ void PyDefaultSavedVariableHooks::push_hooks(
|
|||
py::function& unpack_hook) {
|
||||
at::SavedTensorDefaultHooks::lazy_initialize();
|
||||
at::SavedTensorDefaultHooks::push_hooks(
|
||||
pack_hook.release().ptr(), unpack_hook.release().ptr());
|
||||
c10::SafePyObject(pack_hook.release().ptr(), getPyInterpreter()),
|
||||
c10::SafePyObject(unpack_hook.release().ptr(), getPyInterpreter()));
|
||||
}
|
||||
|
||||
void PyDefaultSavedVariableHooks::pop_hooks() {
|
||||
auto [pack_hook, unpack_hook] = at::SavedTensorDefaultHooks::pop_hooks();
|
||||
TORCH_INTERNAL_ASSERT(pack_hook != nullptr && unpack_hook != nullptr);
|
||||
if (Py_IsInitialized()) {
|
||||
py::gil_scoped_acquire gil;
|
||||
Py_XDECREF(pack_hook);
|
||||
Py_XDECREF(unpack_hook);
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
pack_hook.ptr(getPyInterpreter()) != nullptr &&
|
||||
unpack_hook.ptr(getPyInterpreter()) != nullptr);
|
||||
}
|
||||
|
||||
std::unique_ptr<SavedVariableHooks> PyDefaultSavedVariableHooks::get_hooks() {
|
||||
auto [pack_hook, unpack_hook] = at::SavedTensorDefaultHooks::get_hooks();
|
||||
if (!pack_hook || !unpack_hook) {
|
||||
auto out = at::SavedTensorDefaultHooks::get_hooks();
|
||||
if (!out.has_value()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto [pack_hook, unpack_hook] = *out;
|
||||
py::gil_scoped_acquire gil;
|
||||
py::function pack_hook_ = py::reinterpret_borrow<py::function>(pack_hook);
|
||||
py::function unpack_hook_ = py::reinterpret_borrow<py::function>(unpack_hook);
|
||||
py::function pack_hook_ =
|
||||
py::reinterpret_steal<py::function>(pack_hook.release());
|
||||
py::function unpack_hook_ =
|
||||
py::reinterpret_steal<py::function>(unpack_hook.release());
|
||||
return std::make_unique<PySavedVariableHooks>(pack_hook_, unpack_hook_);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user