mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Attempt #2 for https://github.com/pytorch/pytorch/pull/117875 to fix https://github.com/pytorch/pytorch/issues/112090.
Summary of changes:
- ~Changed CacheEntry linked list into a doubly-linked list structure to support deletion.~ (done by C++ refactor)
- Added CacheEntry and ExtraState borrowed references to GuardFn so that GuardFn can tell ExtraState to delete CacheEntry when the GuardFn is invalidated.
- ~Added ExtraState raw reference to CacheEntry so that we can get ExtraState to correctly point to the first CacheEntry if it gets deleted.~ (done by C++ refactor)
- CacheEntry destructor needs to reset GuardFn refs to ExtraState/CacheEntry in order to prevent use-after-free.
- code_context values that are nn.GraphModules need to be weakrefs in order to prevent circular references.
- Added tests that check for memory leaks and cache deletion operations.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119107
Approved by: https://github.com/jansel
37 lines
986 B
C++
37 lines
986 B
C++
#include <torch/csrc/dynamo/cache_entry.h>
|
|
|
|
#include <torch/csrc/dynamo/debug_macros.h>
|
|
#include <torch/csrc/dynamo/extra_state.h>
|
|
|
|
// Builds a cache entry from a guarded-code object by capturing its
// `check_fn` and `code` attributes.
//
// guarded_code: Python object that must expose `check_fn` and `code`.
CacheEntry::CacheEntry(const py::handle& guarded_code) {
  // Attributes are read in this order: check_fn first, then code.
  check_fn = guarded_code.attr("check_fn");
  code = guarded_code.attr("code");
}
|
|
|
|
// Detaches this entry from its guard function before the entry goes away.
CacheEntry::~CacheEntry() {
  // check_fn carries `cache_entry` / `extra_state` attributes; reset both to
  // None so that invalidating check_fn later cannot touch freed memory.
  // NOTE(review): if either attribute assignment can throw, the exception
  // would escape a destructor -- confirm this is acceptable.
  auto& guard = check_fn;
  guard.attr("cache_entry") = py::none();
  guard.attr("extra_state") = py::none();
}
|
|
|
|
// Returns the entry that follows this one in the owner's cache_entry_list,
// or Python None when this entry is the last one.
//
// The successor is handed to Python with the `reference` return-value
// policy, i.e. as a reference to the element stored in the list.
py::object CacheEntry::next() {
  // Macro from debug_macros.h; presumably reports/aborts on a null owner --
  // TODO confirm against the macro definition.
  NULL_CHECK(this->_owner);

  // Work on a copy so this entry's stored position is left untouched.
  auto successor = _owner_loc;
  ++successor;

  if (successor != _owner->cache_entry_list.end()) {
    return py::cast(*successor, py::return_value_policy::reference);
  }
  return py::none();
}
|
|
|
|
// Exposes the entry's `code` object as a raw PyCodeObject pointer.
//
// e: cache entry to read from; dereferenced unconditionally, so it must not
//    be null.
// Returns the pointer held by `e->code`, reinterpreted as PyCodeObject*
// (no reference-count change is made here).
PyCodeObject* CacheEntry_get_code(CacheEntry* e) {
  PyObject* raw_code = e->code.ptr();
  return reinterpret_cast<PyCodeObject*>(raw_code);
}
|
|
|
|
// Converts a CacheEntry pointer into a PyObject* for the caller.
//
// e: cache entry to wrap; may be null.
// Returns a released (caller-owned) pointer: Python None when `e` is null,
// otherwise `e` cast with the `reference` return-value policy.
PyObject* CacheEntry_to_obj(CacheEntry* e) {
  if (e != nullptr) {
    py::object wrapped = py::cast(e, py::return_value_policy::reference);
    return wrapped.release().ptr();
  }
  return py::none().release().ptr();
}
|