#include #include #include #include #include #include #include #include #if IS_PYTHON_3_12_PLUS #define _PyCode_GetExtra PyUnstable_Code_GetExtra #define _PyCode_SetExtra PyUnstable_Code_SetExtra #endif Py_ssize_t extra_index = -1; CacheEntry* ExtraState::get_first_entry() { if (this->cache_entry_list.empty()) { return nullptr; } return &this->cache_entry_list.front(); } ExtraState::ExtraState(PyCodeObject* orig_code_arg) : orig_code(orig_code_arg) {} void ExtraState::move_to_front(CacheEntry* cache_entry) { CHECK(cache_entry->_owner == this); CHECK(!this->cache_entry_list.empty()); CHECK(cache_entry == &*cache_entry->_owner_loc); this->cache_entry_list.splice( this->cache_entry_list.begin(), this->cache_entry_list, cache_entry->_owner_loc); } void ExtraState::move_to_back(CacheEntry* cache_entry) { CHECK(cache_entry->_owner == this); CHECK(!this->cache_entry_list.empty()); CHECK(cache_entry == &*cache_entry->_owner_loc); this->cache_entry_list.splice( this->cache_entry_list.end(), this->cache_entry_list, cache_entry->_owner_loc); } void ExtraState::invalidate( CacheEntry* cache_entry, py::object deleted_guard_manager) { // Sometimes setting the cache_entry->code to None causes the orig_code to be // freed. This calls destroy_extra_state, which deletes the extra_state and // all the cache_entries. This causes the `this` pointer to be a dangling // pointer, causing a segfault. So, we manually inc/dec ref the original code // pointer to prevent triggering of destroy_extra_state while the invalidate // function is running. Py_INCREF(this->orig_code); CHECK(cache_entry->_owner == this); CHECK(!this->cache_entry_list.empty()); CHECK(cache_entry == &*cache_entry->_owner_loc); cache_entry->invalidate(std::move(deleted_guard_manager)); // Move the cache entry to the end of the list because these will always // return False. cache_entry->_owner->move_to_back(cache_entry); Py_DECREF(this->orig_code); } CacheEntry* extract_cache_entry(ExtraState* extra_state) { if (extra_state == nullptr) { return nullptr; } return extra_state->get_first_entry(); } FrameState* extract_frame_state(ExtraState* extra_state) { if (extra_state == nullptr) { return nullptr; } return (FrameState*)extra_state->frame_state.ptr(); } FrameExecStrategy extra_state_get_exec_strategy(ExtraState* extra_state) { return extra_state->strategy; } void extra_state_set_exec_strategy( ExtraState* extra_state, FrameExecStrategy strategy) { extra_state->strategy = strategy; } ExtraState* get_extra_state(PyCodeObject* code) { ExtraState* extra = nullptr; _PyCode_GetExtra((PyObject*)code, extra_index, (void**)&extra); return extra; } void destroy_extra_state(void* obj) { ExtraState* extra = (ExtraState*)obj; delete extra; } void set_extra_state(PyCodeObject* code, ExtraState* extra_state) { ExtraState* old_extra_state = get_extra_state(code); CHECK(extra_state == nullptr || old_extra_state != extra_state); _PyCode_SetExtra((PyObject*)code, extra_index, extra_state); } ExtraState* init_and_set_extra_state(PyCodeObject* code) { // Invariant - Extra state should not have been set before, therefore it // should be nullptr. CHECK(get_extra_state(code) == nullptr); ExtraState* extra_state = new ExtraState(code); NULL_CHECK(extra_state); set_extra_state(code, extra_state); // freed by destroy_extra_state (since we need to pass these objects to C) // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) return extra_state; } static bool backend_match(PyObject* saved_backend, PyObject* backend) { // Pointer equality check for common case if (saved_backend != backend) { int result = PyObject_RichCompareBool(saved_backend, backend, Py_EQ); // Check for exception if (result == -1) { PyErr_Clear(); return false; } return (result == 1); } return true; } void lookup( ExtraState* extra_state, FrameLocalsMapping* f_locals, PyObject* backend, PyObject** maybe_cached_code, const char** trace_annotation, bool is_skip_guard_eval_unsafe) { size_t index = 0; CacheEntry* found = nullptr; for (const auto& entry : extra_state->precompile_entries) { if (torch::dynamo::run_root_guard_manager(entry.root_mgr, f_locals)) { *maybe_cached_code = entry.code.ptr(); return; } } for (CacheEntry& cache_entry : extra_state->cache_entry_list) { // Check backend. Py_False means run only mode. bool valid = backend == Py_False || backend_match(cache_entry.backend.ptr(), backend); if (valid) { try { if (is_skip_guard_eval_unsafe) { valid = torch::dynamo::run_root_guard_manager( cache_entry.diff_guard_root_mgr, f_locals); } else { valid = torch::dynamo::run_root_guard_manager( cache_entry.root_mgr, f_locals); } } catch (py::error_already_set& e) { if (guard_error_hook) { py::handle guard_error_hook_handle(guard_error_hook); py::handle f_locals_dict = (PyObject*)f_locals->to_dict(); guard_error_hook_handle( cache_entry.guard_manager, cache_entry.code, f_locals_dict, index, index == extra_state->cache_entry_list.size() - 1); } // this function is called from C, so we cannot repropagate // the exception e.restore(); *maybe_cached_code = nullptr; return; } } if (valid) { found = &cache_entry; break; } ++index; } if (found) { extra_state->move_to_front(found); *maybe_cached_code = found->code.ptr(); *trace_annotation = found->trace_annotation.c_str(); return; } *maybe_cached_code = py::none().ptr(); } CacheEntry* create_cache_entry( ExtraState* extra_state, PyObject* guarded_code, PyObject* backend) { extra_state->cache_entry_list.emplace_front(guarded_code, backend); auto new_iter = extra_state->cache_entry_list.begin(); new_iter->_owner = extra_state; new_iter->_owner_loc = new_iter; // Set guard_manager references to extra_state and CacheEntry // Warning: lifetime is controlled by C++! py::handle guard_manager = py::handle(guarded_code).attr("guard_manager"); guard_manager.attr("cache_entry") = py::cast(*new_iter, py::return_value_policy::reference); guard_manager.attr("extra_state") = py::cast(extra_state, py::return_value_policy::reference); return &*new_iter; } py::list _debug_get_cache_entry_list(const py::handle& code_obj) { if (!py::isinstance(code_obj, py::module::import("types").attr("CodeType"))) { throw py::type_error("expected a code object!"); } PyCodeObject* code = (PyCodeObject*)code_obj.ptr(); ExtraState* extra = get_extra_state(code); py::list result; if (extra != nullptr) { for (CacheEntry& e : extra->cache_entry_list) { result.append(py::cast(e, py::return_value_policy::reference)); } } return result; } PrecompileEntry::PrecompileEntry(py::object gm, py::object c) : guard_manager(std::move(gm)), code(std::move(c)) { TORCH_CHECK( PyCode_Check(code.ptr()), "Expecting CodeType from PrecompileEntry."); root_mgr = torch::dynamo::convert_to_root_guard_manager(guard_manager.attr("root")); } void _reset_precompile_entries(const py::handle& code_obj) { if (!py::isinstance(code_obj, py::module::import("types").attr("CodeType"))) { throw py::type_error("expected a code object!"); } PyCodeObject* code = (PyCodeObject*)code_obj.ptr(); ExtraState* extra = get_extra_state(code); py::list result; if (extra != nullptr) { extra->precompile_entries.clear(); } } void _load_precompile_entry( const py::handle& code_obj, py::object guard_manager, py::object dynamo_code) { if (!py::isinstance(code_obj, py::module::import("types").attr("CodeType"))) { throw py::type_error("expected a code object!"); } PyCodeObject* code = (PyCodeObject*)code_obj.ptr(); ExtraState* extra = get_extra_state(code); py::list result; if (extra == nullptr) { extra = init_and_set_extra_state(code); } auto entry = PrecompileEntry(std::move(guard_manager), std::move(dynamo_code)); extra->precompile_entries.push_back(std::move(entry)); } py::list _debug_get_precompile_entries(const py::handle& code_obj) { if (!py::isinstance(code_obj, py::module::import("types").attr("CodeType"))) { throw py::type_error("expected a code object!"); } PyCodeObject* code = (PyCodeObject*)code_obj.ptr(); ExtraState* extra = get_extra_state(code); py::list result; if (extra != nullptr) { for (PrecompileEntry& e : extra->precompile_entries) { result.append(py::cast(e, py::return_value_policy::reference)); } } return result; }