[dynamo] fix segfault due to dangling CacheEntry backend pointer (#156527)

Fixes https://github.com/pytorch/pytorch/issues/155057

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156527
Approved by: https://github.com/anijain2305, https://github.com/jansel
This commit is contained in:
William Wen 2025-06-26 11:04:41 -07:00 committed by PyTorch MergeBot
parent e0447bb5f8
commit 6089ebcf6d
5 changed files with 18 additions and 4 deletions

View File

@ -6992,6 +6992,18 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
res = torch.compile(f, backend="aot_eager")()
self.assertEqual(ref, res)
def test_deleted_compile_wrapper_segfault(self):
def fn(x):
return x + 1
opt_fn = torch.compile(fn, backend="eager")
# This calls cached_backend.clear() which removes any strong references
# to the callback
torch._dynamo.reset()
opt_fn(torch.randn(3))
opt_fn = torch.compile(fn, backend="eager")
opt_fn(torch.randn(3)) # possible segfault due to first opt_fn deletion
def test_delete_local_error(self):
@torch.compile(backend="eager", fullgraph=True)
def fn(x):

View File

@ -269,6 +269,7 @@ def _create_delayed_compile_callback(callback, stance):
compiler_fn = callback._torchdynamo_orig_callable._torchdynamo_orig_callable
return _create_wrapped_callback(compiler_fn)(*args, **kwargs)
callback_fn._torchdynamo_orig_callable = callback # type: ignore[attr-defined]
return callback_fn

View File

@ -5,7 +5,7 @@
#include <torch/csrc/dynamo/extra_state.h>
CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend)
: backend{backend} {
: backend{py::cast<py::object>(get_backend(backend))} {
this->guard_manager = guarded_code.attr("guard_manager");
this->code = guarded_code.attr("code");
this->compile_id = guarded_code.attr("compile_id");
@ -52,6 +52,7 @@ void CacheEntry::invalidate(py::object deleted_guard_manager) {
this->guard_manager = std::move(deleted_guard_manager);
this->root_mgr = nullptr;
this->trace_annotation = "Invalidated";
this->backend = py::none();
}
void CacheEntry::update_diff_guard_root_manager() {

View File

@ -53,7 +53,7 @@ typedef struct VISIBILITY_HIDDEN CacheEntry {
// diff guard root guard manager if exists
void* diff_guard_root_mgr{nullptr};
// backend used to create this cache entry
PyObject* backend{nullptr};
py::object backend;
// Reference to owning ExtraState
ExtraState* _owner{nullptr};
// Reference to this CacheEntry's location in owner's linked list

View File

@ -152,8 +152,8 @@ void lookup(
for (CacheEntry& cache_entry : extra_state->cache_entry_list) {
// Check backend. Py_False means run only mode.
bool valid =
backend == Py_False || backend_match(cache_entry.backend, backend);
bool valid = backend == Py_False ||
backend_match(cache_entry.backend.ptr(), backend);
if (valid) {
try {