Fix passing frame to callback (#91170)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91170
Approved by: https://github.com/ezyang
This commit is contained in:
albanD 2022-12-22 08:00:49 -05:00 committed by PyTorch MergeBot
parent eadd557266
commit c7302075f3
3 changed files with 134 additions and 6 deletions

View File

@ -41,12 +41,12 @@ fi
export SCRIPT_HELPERS_DIR=$SCRIPT_PARENT_DIR/win-test-helpers
set +ex
grep -E -R 'PyLong_(From|As)(Unsigned|)Long\(' --exclude=python_numbers.h torch/
grep -E -R 'PyLong_(From|As)(Unsigned|)Long\(' --exclude=python_numbers.h --exclude=eval_frame.c torch/
PYLONG_API_CHECK=$?
if [[ $PYLONG_API_CHECK == 0 ]]; then
echo "Usage of PyLong_{From,As}{Unsigned}Long API may lead to overflow errors on Windows"
echo "because \`sizeof(long) == 4\` and \`sizeof(unsigned long) == 4\`."
echo "Please include \"torch/csrc/python_numbers.h\" and use the correspoding APIs instead."
echo "Please include \"torch/csrc/utils/python_numbers.h\" and use the correspoding APIs instead."
echo "PyLong_FromLong -> THPUtils_packInt32 / THPUtils_packInt64"
echo "PyLong_AsLong -> THPUtils_unpackInt (32-bit) / THPUtils_unpackLong (64-bit)"
echo "PyLong_FromUnsignedLong -> THPUtils_packUInt32 / THPUtils_packUInt64"

View File

@ -1,9 +1,17 @@
import dataclasses
import sys
import types
from typing import Callable, Dict, List, NamedTuple, Optional, OrderedDict, Union
from typing_extensions import Protocol
if sys.version_info >= (3, 11):
from torch._C._dynamo import eval_frame
DynamoFrameType = eval_frame._PyInterpreterFrame
else:
DynamoFrameType = types.FrameType
class GuardFail(NamedTuple):
# A string repr of the piece of failed guard code we eval-ed
@ -33,7 +41,9 @@ class GuardedCode:
class DynamoCallbackFn(Protocol):
def __call__(
self, frame: types.FrameType, cache_size: int
self,
frame: DynamoFrameType,
cache_size: int,
) -> Optional[GuardedCode]:
...

View File

@ -152,6 +152,106 @@ THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame) {
return 0;
}
// We need to be able to return the _PyInterpreterFrame to python so create
// a python binding for it
typedef struct THPPyInterpreterFrame {
PyObject_HEAD
_PyInterpreterFrame* frame; // Borrowed reference
} THPPyInterpreterFrame;
THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame);
#define DECLARE_PYOBJ_ATTR(name) \
static PyObject* THPPyInterpreterFrame_##name(THPPyInterpreterFrame* self, PyObject* _noargs) { \
PyObject* res = (PyObject*)self->frame->name; \
Py_XINCREF(res); \
return res; \
}
DECLARE_PYOBJ_ATTR(f_func)
DECLARE_PYOBJ_ATTR(f_globals)
DECLARE_PYOBJ_ATTR(f_builtins)
DECLARE_PYOBJ_ATTR(f_locals)
DECLARE_PYOBJ_ATTR(f_code)
DECLARE_PYOBJ_ATTR(frame_obj)
#undef DECLARE_PYOBJ_ATTR
static THPPyInterpreterFrame* THPPyInterpreterFrame_previous(THPPyInterpreterFrame* self, PyObject* _noargs) {
THPPyInterpreterFrame* res = THPPyInterpreterFrame_New(self->frame->previous);
return res;
}
// This is not a true attribute of the class but we do access it in python and it is hard to implement
// on the python side, so do it here:
static PyObject* THPPyInterpreterFrame_f_lasti(THPPyInterpreterFrame* self, PyObject* _noargs) {
return PyLong_FromLong(_PyInterpreterFrame_LASTI(self->frame));
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
static struct PyGetSetDef THPDevice_properties[] = {
{"f_func", (getter)THPPyInterpreterFrame_f_func, NULL, NULL, NULL},
{"f_globals", (getter)THPPyInterpreterFrame_f_globals, NULL, NULL, NULL},
{"f_builtins", (getter)THPPyInterpreterFrame_f_builtins, NULL, NULL, NULL},
{"f_locals", (getter)THPPyInterpreterFrame_f_locals, NULL, NULL, NULL},
{"f_code", (getter)THPPyInterpreterFrame_f_code, NULL, NULL, NULL},
{"frame_obj", (getter)THPPyInterpreterFrame_frame_obj, NULL, NULL, NULL},
{"previous", (getter)THPPyInterpreterFrame_previous, NULL, NULL, NULL},
{"f_lasti", (getter)THPPyInterpreterFrame_f_lasti, NULL, NULL, NULL},
{NULL}};
PyTypeObject THPPyInterpreterFrameType = {
PyVarObject_HEAD_INIT(NULL, 0) "torch._C.dynamo.eval_frame._PyInterpreterFrame", /* tp_name */
sizeof(THPPyInterpreterFrame), /* tp_basicsize */
0, /* tp_itemsize */
NULL, /* tp_dealloc */
0, /* tp_vectorcall_offset */
NULL, /* tp_getattr */
NULL, /* tp_setattr */
NULL, /* tp_reserved */
NULL, /* tp_repr */
NULL, /* tp_as_number */
NULL, /* tp_as_sequence */
NULL, /* tp_as_mapping */
NULL, /* tp_hash */
NULL, /* tp_call */
NULL, /* tp_str */
NULL, /* tp_getattro */
NULL, /* tp_setattro */
NULL, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
NULL, /* tp_doc */
NULL, /* tp_traverse */
NULL, /* tp_clear */
NULL, /* tp_richcompare */
0, /* tp_weaklistoffset */
NULL, /* tp_iter */
NULL, /* tp_iternext */
NULL, /* tp_methods */
NULL, /* tp_members */
THPDevice_properties, /* tp_getset */
NULL, /* tp_base */
NULL, /* tp_dict */
NULL, /* tp_descr_get */
NULL, /* tp_descr_set */
0, /* tp_dictoffset */
NULL, /* tp_init */
NULL, /* tp_alloc */
NULL, /* tp_new */
};
THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame) {
PyTypeObject* type = (PyTypeObject*)&THPPyInterpreterFrameType;
THPPyInterpreterFrame* self = (THPPyInterpreterFrame*)type->tp_alloc(type, 0);
if (!self)
return NULL;
self->frame = frame;
return self;
}
#else
#define THP_EVAL_API_FRAME_OBJECT PyFrameObject
@ -292,8 +392,14 @@ inline static void enable_eval_frame_default(PyThreadState* tstate) {
static inline PyObject* call_callback(
PyObject* callable,
PyObject* frame,
THP_EVAL_API_FRAME_OBJECT* _frame,
long cache_len) {
#if IS_PYTHON_3_11_PLUS
THPPyInterpreterFrame* frame = THPPyInterpreterFrame_New(_frame);
#else
PyFrameObject* frame = _frame;
#endif
PyObject* args = Py_BuildValue("(Ol)", frame, cache_len);
if (args == NULL) {
return NULL;
@ -573,7 +679,7 @@ static PyObject* _custom_eval_frame(
// TODO(alband): This is WRONG for python3.11+ we pass in a _PyInterpreterFrame
// that gets re-interpreted as a PyObject (which it is NOT!)
PyObject* result =
call_callback(callback, (PyObject*)frame, cache_size(extra));
call_callback(callback, frame, cache_size(extra));
if (result == NULL) {
// internal exception, returning here will leak the exception into user code
// this is useful for debugging -- but we dont want it to happen outside of
@ -761,5 +867,17 @@ PyObject* torch_c_dynamo_eval_frame_init(void) {
noargs = PyTuple_New(0);
dotzerokey = PyUnicode_InternFromString(".0");
return PyModule_Create(&_module);
PyObject* module = PyModule_Create(&_module);
#if IS_PYTHON_3_11_PLUS
if (PyType_Ready(&THPPyInterpreterFrameType) < 0) {
return NULL;
}
Py_INCREF(&THPPyInterpreterFrameType);
if (PyModule_AddObject(module, "_PyInterpreterFrame", (PyObject*)&THPPyInterpreterFrameType) != 0) {
return NULL;
}
#endif
return module;
}