pytorch/torch/csrc/dynamo/init.cpp
William Wen ee1c2449f7 [dynamo] delete dynamo cache entry when guard function is invalidated [attempt 2] (#119107)
Attempt #2 for https://github.com/pytorch/pytorch/pull/117875 to fix https://github.com/pytorch/pytorch/issues/112090.

Summary of changes:
- ~Changed CacheEntry linked list into a doubly-linked list structure to support deletion.~ (done by C++ refactor)
- Added CacheEntry and ExtraState borrowed references to GuardFn so that GuardFn can tell ExtraState to delete CacheEntry when the GuardFn is invalidated.
- ~Added ExtraState raw reference to CacheEntry so that we can get ExtraState to correctly point to the first CacheEntry if it gets deleted.~ (done by C++ refactor)
- CacheEntry destructor needs to reset GuardFn refs to ExtraState/CacheEntry in order to prevent use-after-free.
- code_context values that are nn.GraphModules need to be weakrefs in order to prevent circular references.
- Added tests that check for memory leaks and cache deletion operations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119107
Approved by: https://github.com/jansel
2024-02-07 03:32:42 +00:00

55 lines
1.7 KiB
C++

#include <torch/csrc/dynamo/init.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/dynamo/cache_entry.h>
#include <torch/csrc/dynamo/eval_frame.h>
#include <torch/csrc/dynamo/extra_state.h>
#include <torch/csrc/dynamo/guards.h>
#include <torch/csrc/dynamo/python_compiled_autograd.h>
static struct PyModuleDef _module =
{PyModuleDef_HEAD_INIT, "torch._C._dynamo", "", -1, nullptr};
namespace torch {
namespace dynamo {
using torch::dynamo::autograd::torch_c_dynamo_compiled_autograd_init;
void initDynamoBindings(PyObject* torch) {
PyObject* dynamo = PyModule_Create(&_module);
if (dynamo == nullptr || PyModule_AddObject(torch, "_dynamo", dynamo) != 0) {
throw python_error();
}
PyObject* eval_frame = torch_c_dynamo_eval_frame_init();
if (eval_frame == nullptr ||
PyModule_AddObject(dynamo, "eval_frame", eval_frame) != 0) {
throw python_error();
}
PyObject* guards = torch_c_dynamo_guards_init();
if (guards == nullptr || PyModule_AddObject(dynamo, "guards", guards) != 0) {
throw python_error();
}
PyObject* compiled_autograd = torch_c_dynamo_compiled_autograd_init();
if (compiled_autograd == nullptr ||
PyModule_AddObject(dynamo, "compiled_autograd", compiled_autograd) != 0) {
throw python_error();
}
auto m = py::handle(eval_frame).cast<py::module>();
py::class_<CacheEntry>(m, "_CacheEntry")
.def_readonly("check_fn", &CacheEntry::check_fn)
.def_readonly("code", &CacheEntry::code)
.def_property_readonly("next", &CacheEntry::next);
py::class_<ExtraState>(m, "_ExtraState")
.def("invalidate", &ExtraState::invalidate);
m.def("_debug_get_cache_entry_list", &_debug_get_cache_entry_list);
}
} // namespace dynamo
} // namespace torch