mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Attempt #2 for https://github.com/pytorch/pytorch/pull/117875 to fix https://github.com/pytorch/pytorch/issues/112090.

Summary of changes:
- ~Changed CacheEntry linked list into a doubly-linked list structure to support deletion.~ (done by C++ refactor)
- Added CacheEntry and ExtraState borrowed references to GuardFn so that GuardFn can tell ExtraState to delete CacheEntry when the GuardFn is invalidated.
- ~Added ExtraState raw reference to CacheEntry so that we can get ExtraState to correctly point to the first CacheEntry if it gets deleted.~ (done by C++ refactor)
- CacheEntry destructor needs to reset GuardFn refs to ExtraState/CacheEntry in order to prevent use-after-free.
- code_context values that are nn.GraphModules need to be weakrefs in order to prevent circular references.
- Added tests that check for memory leaks and cache deletion operations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119107
Approved by: https://github.com/jansel
55 lines · 1.7 KiB · C++
#include <torch/csrc/dynamo/init.h>
|
|
|
|
#include <torch/csrc/Exceptions.h>
|
|
#include <torch/csrc/dynamo/cache_entry.h>
|
|
#include <torch/csrc/dynamo/eval_frame.h>
|
|
#include <torch/csrc/dynamo/extra_state.h>
|
|
#include <torch/csrc/dynamo/guards.h>
|
|
#include <torch/csrc/dynamo/python_compiled_autograd.h>
|
|
|
|
static struct PyModuleDef _module =
|
|
{PyModuleDef_HEAD_INIT, "torch._C._dynamo", "", -1, nullptr};
|
|
|
|
namespace torch {
|
|
namespace dynamo {
|
|
using torch::dynamo::autograd::torch_c_dynamo_compiled_autograd_init;
|
|
|
|
void initDynamoBindings(PyObject* torch) {
|
|
PyObject* dynamo = PyModule_Create(&_module);
|
|
if (dynamo == nullptr || PyModule_AddObject(torch, "_dynamo", dynamo) != 0) {
|
|
throw python_error();
|
|
}
|
|
|
|
PyObject* eval_frame = torch_c_dynamo_eval_frame_init();
|
|
if (eval_frame == nullptr ||
|
|
PyModule_AddObject(dynamo, "eval_frame", eval_frame) != 0) {
|
|
throw python_error();
|
|
}
|
|
|
|
PyObject* guards = torch_c_dynamo_guards_init();
|
|
if (guards == nullptr || PyModule_AddObject(dynamo, "guards", guards) != 0) {
|
|
throw python_error();
|
|
}
|
|
|
|
PyObject* compiled_autograd = torch_c_dynamo_compiled_autograd_init();
|
|
if (compiled_autograd == nullptr ||
|
|
PyModule_AddObject(dynamo, "compiled_autograd", compiled_autograd) != 0) {
|
|
throw python_error();
|
|
}
|
|
|
|
auto m = py::handle(eval_frame).cast<py::module>();
|
|
|
|
py::class_<CacheEntry>(m, "_CacheEntry")
|
|
.def_readonly("check_fn", &CacheEntry::check_fn)
|
|
.def_readonly("code", &CacheEntry::code)
|
|
.def_property_readonly("next", &CacheEntry::next);
|
|
|
|
py::class_<ExtraState>(m, "_ExtraState")
|
|
.def("invalidate", &ExtraState::invalidate);
|
|
|
|
m.def("_debug_get_cache_entry_list", &_debug_get_cache_entry_list);
|
|
}
|
|
|
|
} // namespace dynamo
|
|
} // namespace torch
|