mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] fix segfault due to dangling CacheEntry backend pointer (#156527)
Fixes https://github.com/pytorch/pytorch/issues/155057 Pull Request resolved: https://github.com/pytorch/pytorch/pull/156527 Approved by: https://github.com/anijain2305, https://github.com/jansel
This commit is contained in:
parent
e0447bb5f8
commit
6089ebcf6d
|
|
@ -6992,6 +6992,18 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
|
||||||
res = torch.compile(f, backend="aot_eager")()
|
res = torch.compile(f, backend="aot_eager")()
|
||||||
self.assertEqual(ref, res)
|
self.assertEqual(ref, res)
|
||||||
|
|
||||||
|
def test_deleted_compile_wrapper_segfault(self):
|
||||||
|
def fn(x):
|
||||||
|
return x + 1
|
||||||
|
|
||||||
|
opt_fn = torch.compile(fn, backend="eager")
|
||||||
|
# This calls cached_backend.clear() which removes any strong references
|
||||||
|
# to the callback
|
||||||
|
torch._dynamo.reset()
|
||||||
|
opt_fn(torch.randn(3))
|
||||||
|
opt_fn = torch.compile(fn, backend="eager")
|
||||||
|
opt_fn(torch.randn(3)) # possible segfault due to first opt_fn deletion
|
||||||
|
|
||||||
def test_delete_local_error(self):
|
def test_delete_local_error(self):
|
||||||
@torch.compile(backend="eager", fullgraph=True)
|
@torch.compile(backend="eager", fullgraph=True)
|
||||||
def fn(x):
|
def fn(x):
|
||||||
|
|
|
||||||
|
|
@ -269,6 +269,7 @@ def _create_delayed_compile_callback(callback, stance):
|
||||||
compiler_fn = callback._torchdynamo_orig_callable._torchdynamo_orig_callable
|
compiler_fn = callback._torchdynamo_orig_callable._torchdynamo_orig_callable
|
||||||
return _create_wrapped_callback(compiler_fn)(*args, **kwargs)
|
return _create_wrapped_callback(compiler_fn)(*args, **kwargs)
|
||||||
|
|
||||||
|
callback_fn._torchdynamo_orig_callable = callback # type: ignore[attr-defined]
|
||||||
return callback_fn
|
return callback_fn
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@
|
||||||
#include <torch/csrc/dynamo/extra_state.h>
|
#include <torch/csrc/dynamo/extra_state.h>
|
||||||
|
|
||||||
CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend)
|
CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend)
|
||||||
: backend{backend} {
|
: backend{py::cast<py::object>(get_backend(backend))} {
|
||||||
this->guard_manager = guarded_code.attr("guard_manager");
|
this->guard_manager = guarded_code.attr("guard_manager");
|
||||||
this->code = guarded_code.attr("code");
|
this->code = guarded_code.attr("code");
|
||||||
this->compile_id = guarded_code.attr("compile_id");
|
this->compile_id = guarded_code.attr("compile_id");
|
||||||
|
|
@ -52,6 +52,7 @@ void CacheEntry::invalidate(py::object deleted_guard_manager) {
|
||||||
this->guard_manager = std::move(deleted_guard_manager);
|
this->guard_manager = std::move(deleted_guard_manager);
|
||||||
this->root_mgr = nullptr;
|
this->root_mgr = nullptr;
|
||||||
this->trace_annotation = "Invalidated";
|
this->trace_annotation = "Invalidated";
|
||||||
|
this->backend = py::none();
|
||||||
}
|
}
|
||||||
|
|
||||||
void CacheEntry::update_diff_guard_root_manager() {
|
void CacheEntry::update_diff_guard_root_manager() {
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,7 @@ typedef struct VISIBILITY_HIDDEN CacheEntry {
|
||||||
// diff guard root guard manager if exists
|
// diff guard root guard manager if exists
|
||||||
void* diff_guard_root_mgr{nullptr};
|
void* diff_guard_root_mgr{nullptr};
|
||||||
// backend used to create this cache entry
|
// backend used to create this cache entry
|
||||||
PyObject* backend{nullptr};
|
py::object backend;
|
||||||
// Reference to owning ExtraState
|
// Reference to owning ExtraState
|
||||||
ExtraState* _owner{nullptr};
|
ExtraState* _owner{nullptr};
|
||||||
// Reference to this CacheEntry's location in owner's linked list
|
// Reference to this CacheEntry's location in owner's linked list
|
||||||
|
|
|
||||||
|
|
@ -152,8 +152,8 @@ void lookup(
|
||||||
for (CacheEntry& cache_entry : extra_state->cache_entry_list) {
|
for (CacheEntry& cache_entry : extra_state->cache_entry_list) {
|
||||||
// Check backend. Py_False means run only mode.
|
// Check backend. Py_False means run only mode.
|
||||||
|
|
||||||
bool valid =
|
bool valid = backend == Py_False ||
|
||||||
backend == Py_False || backend_match(cache_entry.backend, backend);
|
backend_match(cache_entry.backend.ptr(), backend);
|
||||||
|
|
||||||
if (valid) {
|
if (valid) {
|
||||||
try {
|
try {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user