#include #include #include #include CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend) : backend{py::cast(get_backend(backend))} { this->guard_manager = guarded_code.attr("guard_manager"); this->code = guarded_code.attr("code"); this->compile_id = guarded_code.attr("compile_id"); py::object trace_annotation = guarded_code.attr("trace_annotation"); const char* trace_annotation_str = PyUnicode_AsUTF8(trace_annotation.ptr()); if (trace_annotation) { this->trace_annotation = std::string(trace_annotation_str); } else { this->trace_annotation = "Unknown"; } this->root_mgr = torch::dynamo::convert_to_root_guard_manager( this->guard_manager.attr("root")); this->diff_guard_root_mgr = torch::dynamo::convert_to_root_guard_manager( this->guard_manager.attr("diff_guard_root")); } C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED( "-Wdeprecated-copy-with-user-provided-dtor") C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated-copy-dtor") // NOLINTNEXTLINE(bugprone-exception-escape) CacheEntry::~CacheEntry() { // prevent guard_manager from use-after-free when invalidating this->guard_manager.attr("cache_entry") = py::none(); this->guard_manager.attr("extra_state") = py::none(); } C10_DIAGNOSTIC_POP() C10_DIAGNOSTIC_POP() py::object CacheEntry::next() { NULL_CHECK(this->_owner); auto it = this->_owner_loc; ++it; if (it == this->_owner->cache_entry_list.end()) { return py::none(); } return py::cast(*it, py::return_value_policy::reference); } void CacheEntry::invalidate(py::object deleted_guard_manager) { // Keep the current pointer alive but make the fields as if no-op this->guard_manager.attr("cache_entry") = py::none(); this->guard_manager.attr("extra_state") = py::none(); this->code = py::none(); this->guard_manager = std::move(deleted_guard_manager); this->root_mgr = nullptr; this->trace_annotation = "Invalidated"; this->backend = py::none(); } void CacheEntry::update_diff_guard_root_manager() { this->diff_guard_root_mgr = torch::dynamo::convert_to_root_guard_manager( this->guard_manager.attr("diff_guard_root")); } PyCodeObject* CacheEntry_get_code(CacheEntry* e) { return (PyCodeObject*)e->code.ptr(); } const char* CacheEntry_get_trace_annotation(CacheEntry* e) { return e->trace_annotation.c_str(); } PyObject* CacheEntry_to_obj(CacheEntry* e) { if (!e) { return py::none().release().ptr(); } return py::cast(e, py::return_value_policy::reference).release().ptr(); } PyObject* get_backend(PyObject* callback) { py::handle handle = py::handle(callback); while (py::hasattr(handle, "_torchdynamo_orig_backend")) { handle = handle.attr("_torchdynamo_orig_backend"); } return handle.ptr(); }