diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi index 360ebc031e3..33006e29c42 100644 --- a/torch/_C/_dynamo/eval_frame.pyi +++ b/torch/_C/_dynamo/eval_frame.pyi @@ -1,13 +1,9 @@ # mypy: allow-untyped-defs import types -from typing import NewType +from typing import Dict, NewType from torch._dynamo.types import DynamoCallback, DynamoGuardHook -# We implement our own FrameType-like type for Python >= 3.11. So it's not actually an alias of FrameType, but still -# exposes the same interface. -_PyInterpreterFrame = NewType("_PyInterpreterFrame", types.FrameType) - # For typechecking SkipCodeRecursiveFlag = NewType("SkipCodeRecursiveFlag", object) CacheLimitHitFlag = NewType("CacheLimitHitFlag", object) @@ -31,6 +27,17 @@ class _CacheEntry: class _ExtraState: def invalidate(self, cache_entry: _CacheEntry): ... +# This is an object that encapsulates the Python FrameType, and exposes +# properties Dynamo cares about for a frame. +class _PyInterpreterFrame: + f_code: types.CodeType + f_locals: Dict[str, object] + f_globals: Dict[str, object] + f_builtins: Dict[str, object] + f_lasti: int + f_lineo: int + f_back: types.FrameType + def _debug_get_cache_entry_list(code: types.CodeType) -> list[_CacheEntry]: ... py_opcode_caches: list[int] diff --git a/torch/_dynamo/cache_size.py b/torch/_dynamo/cache_size.py index 1d0c169345d..f3b3209591a 100644 --- a/torch/_dynamo/cache_size.py +++ b/torch/_dynamo/cache_size.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs import logging -import types import weakref from dataclasses import dataclass from typing import Tuple @@ -8,6 +7,7 @@ from typing import Tuple from torch._guards import CompileId from . import config +from .types import DynamoFrameType log = logging.getLogger(__name__) @@ -100,7 +100,7 @@ class CacheSizeRelevantForFrame: return self.num_cache_entries_with_same_id_matched_objs >= limit -def _get_weakref_from_f_locals(frame: types.FrameType, local_name: str): +def _get_weakref_from_f_locals(frame: DynamoFrameType, local_name: str): obj = frame.f_locals.get(local_name, None) weak_id = None try: @@ -110,7 +110,7 @@ def _get_weakref_from_f_locals(frame: types.FrameType, local_name: str): return weak_id -def _has_same_id_matched_objs(frame: types.FrameType, cache_entry) -> bool: +def _has_same_id_matched_objs(frame: DynamoFrameType, cache_entry) -> bool: """ Checks if the ID_MATCH'd objects saved on cache_entry are same as the ones in frame.f_locals. @@ -132,7 +132,7 @@ def _has_same_id_matched_objs(frame: types.FrameType, cache_entry) -> bool: def compute_cache_size( - frame: types.FrameType, cache_entry + frame: DynamoFrameType, cache_entry ) -> CacheSizeRelevantForFrame: # Walk the linked list to calculate the cache size num_cache_entries = 0 diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 4c52bcdfe37..89119662a85 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -21,7 +21,7 @@ import typing import warnings import weakref from pathlib import Path -from types import CodeType, FrameType, FunctionType, ModuleType +from types import CodeType, FunctionType, ModuleType from typing import Any, Callable, Dict, List, Optional, Set, TypeVar, Union from typing_extensions import ParamSpec from weakref import ReferenceType @@ -138,7 +138,7 @@ except ModuleNotFoundError: if typing.TYPE_CHECKING: from .backends.registry import CompilerFn from .repro.after_dynamo import WrapBackendDebug - from .types import BytecodeHook, CacheEntry + from .types import BytecodeHook, CacheEntry, DynamoFrameType from .variables.builder import FrameStateSizeEntry @@ -257,7 +257,7 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]: @TorchPatcher.suppress_torch_distributed_warnings -def has_tensor_in_frame(frame: FrameType) -> bool: +def has_tensor_in_frame(frame: DynamoFrameType) -> bool: """Check if the frame has torch.* related bits""" # Check if the function was decorated using torch._dynamo.optimize if frame.f_code in always_optimize_code_objects: @@ -338,7 +338,7 @@ def has_tensor_in_frame(frame: FrameType) -> bool: def exception_handler( e: Exception, code: CodeType, - frame: Optional[FrameType] = None, + frame: Optional[DynamoFrameType] = None, export: bool = False, ) -> None: record_filename = None @@ -450,7 +450,7 @@ class ConvertFrameAssert: def __call__( self, - frame: FrameType, + frame: DynamoFrameType, cache_entry: Optional[CacheEntry], hooks: Hooks, frame_state: Dict[str, Union[int, FrameStateSizeEntry]], @@ -609,7 +609,7 @@ def _compile( hooks: Hooks, cache_entry: Optional[CacheEntry], cache_size: CacheSizeRelevantForFrame, - frame: Optional[FrameType] = None, + frame: Optional[DynamoFrameType] = None, frame_state: Optional[Dict[str, Union[int, FrameStateSizeEntry]]] = None, *, compile_id: CompileId, @@ -1165,7 +1165,7 @@ class ConvertFrame: def __call__( self, - frame: FrameType, + frame: DynamoFrameType, cache_entry: Optional[CacheEntry], hooks: Hooks, frame_state: Dict[str, Union[int, FrameStateSizeEntry]], @@ -1310,7 +1310,7 @@ def first_real_inst_idx(code: CodeType) -> int: class ConvertFrameProtocol(typing.Protocol): def __call__( self, - frame: FrameType, + frame: DynamoFrameType, cache_entry: Optional[CacheEntry], hooks: Hooks, frame_state: Dict[str, Union[int, FrameStateSizeEntry]], @@ -1328,7 +1328,7 @@ class CatchErrorsWrapper: def __call__( self, - frame: FrameType, + frame: DynamoFrameType, cache_entry: Optional[CacheEntry], frame_state: Dict[str, Union[int, FrameStateSizeEntry]], ) -> Optional[GuardedCode]: diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index e87432f4998..2de64f2f241 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -108,7 +108,14 @@ from .source import ( UnspecializedParamBufferSource, WeakRefCallSource, ) -from .types import CacheEntry, ExtraState, GuardedCode, GuardFail, GuardFn # noqa: F401 +from .types import ( # noqa: F401 + CacheEntry, + DynamoFrameType, + ExtraState, + GuardedCode, + GuardFail, + GuardFn, +) from .utils import ( common_constant_types, dict_keys_repr, @@ -2600,7 +2607,7 @@ def get_guard_fail_reason( def get_and_maybe_log_recompilation_reason( - cache_entry, frame: types.FrameType + cache_entry, frame: DynamoFrameType ) -> List[str]: """ Return the list of guard failure reasons using cache_entry. diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 406093119c1..ee7fb48d2ab 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -35,6 +35,7 @@ from .bytecode_transformation import ( transform_code_object, ) from .guards import CheckFunctionManager, CompileId, GuardedCode +from .types import DynamoFrameType from .utils import same @@ -164,7 +165,7 @@ def debug_dump(name: str, code: types.CodeType, extra: str = "") -> None: def debug_insert_nops( - frame: types.FrameType, cache_size: int, hooks: Any, _: Any, *, skip: int = 0 + frame: DynamoFrameType, cache_size: int, hooks: Any, _: Any, *, skip: int = 0 ) -> Optional[GuardedCode]: """used to debug jump updates""" diff --git a/torch/_dynamo/types.py b/torch/_dynamo/types.py index 298741a4e95..16278315847 100644 --- a/torch/_dynamo/types.py +++ b/torch/_dynamo/types.py @@ -1,5 +1,4 @@ import dataclasses -import sys import types from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, Union @@ -7,16 +6,11 @@ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, Un from torch._C._dynamo.eval_frame import ( _CacheEntry as CacheEntry, _ExtraState as ExtraState, + _PyInterpreterFrame as DynamoFrameType, ) from torch._guards import CompileId -if sys.version_info >= (3, 11): - from torch._C._dynamo.eval_frame import _PyInterpreterFrame as DynamoFrameType -else: - from types import FrameType as DynamoFrameType - - # We use a dict to store additional data per frame. FrameState = Dict[Any, Any] diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c index 8a30cbc536c..08132d23ac8 100644 --- a/torch/csrc/dynamo/eval_frame.c +++ b/torch/csrc/dynamo/eval_frame.c @@ -38,17 +38,20 @@ static void eval_frame_callback_set(PyObject* obj) { // https://docs.python.org/3/c-api/init.html#c._PyFrameEvalFunction #if IS_PYTHON_3_11_PLUS #define THP_EVAL_API_FRAME_OBJECT _PyInterpreterFrame +#else +#define THP_EVAL_API_FRAME_OBJECT PyFrameObject +#endif // IS_PYTHON_3_11_PLUS // We need to be able to return the _PyInterpreterFrame to python so create // a python binding for it typedef struct THPPyInterpreterFrame { PyObject_HEAD - _PyInterpreterFrame* frame; // Borrowed reference + THP_EVAL_API_FRAME_OBJECT* frame; // Borrowed reference PyObject* locals; } THPPyInterpreterFrame; -THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame); +THPPyInterpreterFrame* THPPyInterpreterFrame_New(THP_EVAL_API_FRAME_OBJECT* frame); #define DECLARE_PYOBJ_ATTR(name) \ static PyObject* THPPyInterpreterFrame_##name(THPPyInterpreterFrame* self, PyObject* _noargs) { \ @@ -57,12 +60,6 @@ static PyObject* THPPyInterpreterFrame_##name(THPPyInterpreterFrame* self, PyObj return res; \ } -#if IS_PYTHON_3_12_PLUS -DECLARE_PYOBJ_ATTR(f_funcobj) -#else -DECLARE_PYOBJ_ATTR(f_func) -#endif - DECLARE_PYOBJ_ATTR(f_globals) DECLARE_PYOBJ_ATTR(f_builtins) @@ -78,22 +75,20 @@ DECLARE_PYOBJ_ATTR(f_executable) DECLARE_PYOBJ_ATTR(f_code) #endif -DECLARE_PYOBJ_ATTR(frame_obj) - #undef DECLARE_PYOBJ_ATTR -static THPPyInterpreterFrame* THPPyInterpreterFrame_previous(THPPyInterpreterFrame* self, PyObject* _noargs) { - THPPyInterpreterFrame* res = THPPyInterpreterFrame_New(self->frame->previous); - return res; -} - // This is not a true attribute of the class but we do access it in python and it is hard to implement // on the python side, so do it here: static PyObject* THPPyInterpreterFrame_f_lasti(THPPyInterpreterFrame* self, PyObject* _noargs) { +#if IS_PYTHON_3_11_PLUS return PyLong_FromLong(_PyInterpreterFrame_LASTI(self->frame)); +#else + return PyLong_FromLong(self->frame->f_lasti); +#endif // IS_PYTHON_3_11_PLUS } static PyObject* THPPyInterpreterFrame_f_lineno(THPPyInterpreterFrame* self, PyObject* _noargs) { +#if IS_PYTHON_3_11_PLUS if (!self->frame->frame_obj) { return PyLong_FromLong(F_CODE(self->frame)->co_firstlineno); } @@ -102,22 +97,24 @@ static PyObject* THPPyInterpreterFrame_f_lineno(THPPyInterpreterFrame* self, PyO Py_RETURN_NONE; } return PyLong_FromLong(lineno); +#else + return PyLong_FromLong(self->frame->f_lineno); +#endif // IS_PYTHON_3_11_PLUS } static PyObject* THPPyInterpreterFrame_f_back(THPPyInterpreterFrame* self, PyObject* _noargs) { +#if IS_PYTHON_3_11_PLUS if (!self->frame->frame_obj) { Py_RETURN_NONE; } return (PyObject*)PyFrame_GetBack(self->frame->frame_obj); +#else + return Py_XNewRef(self->frame->f_back); +#endif // IS_PYTHON_3_11_PLUS } // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) static struct PyGetSetDef THPPyInterpreterFrame_properties[] = { -#if IS_PYTHON_3_12_PLUS - {"f_func", (getter)THPPyInterpreterFrame_f_funcobj, NULL, NULL, NULL}, -#else - {"f_func", (getter)THPPyInterpreterFrame_f_func, NULL, NULL, NULL}, -#endif {"f_globals", (getter)THPPyInterpreterFrame_f_globals, NULL, NULL, NULL}, {"f_builtins", (getter)THPPyInterpreterFrame_f_builtins, NULL, NULL, NULL}, {"f_locals", (getter)THPPyInterpreterFrame_f_locals, NULL, NULL, NULL}, @@ -126,8 +123,6 @@ static struct PyGetSetDef THPPyInterpreterFrame_properties[] = { #else {"f_code", (getter)THPPyInterpreterFrame_f_code, NULL, NULL, NULL}, #endif - {"frame_obj", (getter)THPPyInterpreterFrame_frame_obj, NULL, NULL, NULL}, - {"previous", (getter)THPPyInterpreterFrame_previous, NULL, NULL, NULL}, {"f_lasti", (getter)THPPyInterpreterFrame_f_lasti, NULL, NULL, NULL}, {"f_lineno", (getter)THPPyInterpreterFrame_f_lineno, NULL, NULL, NULL}, {"f_back", (getter)THPPyInterpreterFrame_f_back, NULL, NULL, NULL}, @@ -142,7 +137,7 @@ static PyTypeObject THPPyInterpreterFrameType = { }; -THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame) { +THPPyInterpreterFrame* THPPyInterpreterFrame_New(THP_EVAL_API_FRAME_OBJECT* frame) { PyTypeObject* type = (PyTypeObject*)&THPPyInterpreterFrameType; THPPyInterpreterFrame* self = (THPPyInterpreterFrame*)type->tp_alloc(type, 0); if (!self) @@ -152,13 +147,6 @@ THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame) { return self; } - -#else - -#define THP_EVAL_API_FRAME_OBJECT PyFrameObject - -#endif - static PyObject* dynamo__custom_eval_frame_shim( PyThreadState* tstate, THP_EVAL_API_FRAME_OBJECT* frame, @@ -246,6 +234,8 @@ static const char* get_frame_name(THP_EVAL_API_FRAME_OBJECT* frame) { return PyUnicode_AsUTF8(F_CODE(frame)->co_name); } +// Remember to update the type signature for DynamoCallbackFn.__call__ in +// torch/_dynamo/types.py if this function's signature changes. static PyObject* dynamo_call_callback( PyObject* callable, THP_EVAL_API_FRAME_OBJECT* _frame, @@ -253,18 +243,11 @@ static PyObject* dynamo_call_callback( CacheEntry* cache_entry, FrameState* frame_state) { -// remember to update the type signature for DynamoCallbackFn.__call__ in torch/_dynamo/types.py -// if this function changes -#if IS_PYTHON_3_11_PLUS THPPyInterpreterFrame* frame = THPPyInterpreterFrame_New(_frame); if (frame == NULL) { return NULL; } frame->locals = locals; -#else - PyObject* frame = Py_NewRef(_frame); -#endif - PyObject* cache_entry_pyobj = CacheEntry_to_obj(cache_entry); PyObject* res = PyObject_CallFunction( callable, @@ -716,7 +699,7 @@ static PyObject* dynamo__custom_eval_frame( } } -#else // IS_PYTHON_3_14_PLUS +#else // !(IS_PYTHON_3_14_PLUS) // Fake definitions for everything we removed @@ -738,7 +721,7 @@ static PyTypeObject THPPyInterpreterFrameType = { .tp_getset = THPPyInterpreterFrame_properties, }; -#endif // CPython 3.14 +#endif // !(IS_PYTHON_3_14_PLUS) static PyObject* increment_working_threads(PyThreadState* tstate) { active_dynamo_threads = active_dynamo_threads + 1; @@ -909,7 +892,6 @@ PyObject* torch_c_dynamo_eval_frame_init(void) { PyUnstable_Module_SetGIL(module, Py_MOD_GIL_NOT_USED); #endif -#if IS_PYTHON_3_11_PLUS if (PyType_Ready(&THPPyInterpreterFrameType) < 0) { return NULL; } @@ -917,7 +899,6 @@ PyObject* torch_c_dynamo_eval_frame_init(void) { if (PyModule_AddObject(module, "_PyInterpreterFrame", (PyObject*)&THPPyInterpreterFrameType) != 0) { return NULL; } -#endif skip_code_recursive_flag = PyObject_New(PyObject, &PyBaseObject_Type); if (skip_code_recursive_flag == NULL) {