mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix passing frame to callback (#91170)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91170 Approved by: https://github.com/ezyang
This commit is contained in:
parent
eadd557266
commit
c7302075f3
|
|
@ -41,12 +41,12 @@ fi
|
|||
export SCRIPT_HELPERS_DIR=$SCRIPT_PARENT_DIR/win-test-helpers
|
||||
|
||||
set +ex
|
||||
grep -E -R 'PyLong_(From|As)(Unsigned|)Long\(' --exclude=python_numbers.h torch/
|
||||
grep -E -R 'PyLong_(From|As)(Unsigned|)Long\(' --exclude=python_numbers.h --exclude=eval_frame.c torch/
|
||||
PYLONG_API_CHECK=$?
|
||||
if [[ $PYLONG_API_CHECK == 0 ]]; then
|
||||
echo "Usage of PyLong_{From,As}{Unsigned}Long API may lead to overflow errors on Windows"
|
||||
echo "because \`sizeof(long) == 4\` and \`sizeof(unsigned long) == 4\`."
|
||||
echo "Please include \"torch/csrc/python_numbers.h\" and use the correspoding APIs instead."
|
||||
echo "Please include \"torch/csrc/utils/python_numbers.h\" and use the correspoding APIs instead."
|
||||
echo "PyLong_FromLong -> THPUtils_packInt32 / THPUtils_packInt64"
|
||||
echo "PyLong_AsLong -> THPUtils_unpackInt (32-bit) / THPUtils_unpackLong (64-bit)"
|
||||
echo "PyLong_FromUnsignedLong -> THPUtils_packUInt32 / THPUtils_packUInt64"
|
||||
|
|
|
|||
|
|
@ -1,9 +1,17 @@
|
|||
import dataclasses
|
||||
import sys
|
||||
import types
|
||||
from typing import Callable, Dict, List, NamedTuple, Optional, OrderedDict, Union
|
||||
|
||||
from typing_extensions import Protocol
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from torch._C._dynamo import eval_frame
|
||||
|
||||
DynamoFrameType = eval_frame._PyInterpreterFrame
|
||||
else:
|
||||
DynamoFrameType = types.FrameType
|
||||
|
||||
|
||||
class GuardFail(NamedTuple):
|
||||
# A string repr of the piece of failed guard code we eval-ed
|
||||
|
|
@ -33,7 +41,9 @@ class GuardedCode:
|
|||
|
||||
class DynamoCallbackFn(Protocol):
|
||||
def __call__(
|
||||
self, frame: types.FrameType, cache_size: int
|
||||
self,
|
||||
frame: DynamoFrameType,
|
||||
cache_size: int,
|
||||
) -> Optional[GuardedCode]:
|
||||
...
|
||||
|
||||
|
|
|
|||
|
|
@ -152,6 +152,106 @@ THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
// We need to be able to return the _PyInterpreterFrame to python so create
|
||||
// a python binding for it
|
||||
|
||||
typedef struct THPPyInterpreterFrame {
|
||||
PyObject_HEAD
|
||||
_PyInterpreterFrame* frame; // Borrowed reference
|
||||
} THPPyInterpreterFrame;
|
||||
|
||||
THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame);
|
||||
|
||||
#define DECLARE_PYOBJ_ATTR(name) \
|
||||
static PyObject* THPPyInterpreterFrame_##name(THPPyInterpreterFrame* self, PyObject* _noargs) { \
|
||||
PyObject* res = (PyObject*)self->frame->name; \
|
||||
Py_XINCREF(res); \
|
||||
return res; \
|
||||
}
|
||||
|
||||
DECLARE_PYOBJ_ATTR(f_func)
|
||||
DECLARE_PYOBJ_ATTR(f_globals)
|
||||
DECLARE_PYOBJ_ATTR(f_builtins)
|
||||
DECLARE_PYOBJ_ATTR(f_locals)
|
||||
DECLARE_PYOBJ_ATTR(f_code)
|
||||
DECLARE_PYOBJ_ATTR(frame_obj)
|
||||
|
||||
#undef DECLARE_PYOBJ_ATTR
|
||||
|
||||
static THPPyInterpreterFrame* THPPyInterpreterFrame_previous(THPPyInterpreterFrame* self, PyObject* _noargs) {
|
||||
THPPyInterpreterFrame* res = THPPyInterpreterFrame_New(self->frame->previous);
|
||||
return res;
|
||||
}
|
||||
|
||||
// This is not a true attribute of the class but we do access it in python and it is hard to implement
|
||||
// on the python side, so do it here:
|
||||
static PyObject* THPPyInterpreterFrame_f_lasti(THPPyInterpreterFrame* self, PyObject* _noargs) {
|
||||
return PyLong_FromLong(_PyInterpreterFrame_LASTI(self->frame));
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
|
||||
static struct PyGetSetDef THPDevice_properties[] = {
|
||||
{"f_func", (getter)THPPyInterpreterFrame_f_func, NULL, NULL, NULL},
|
||||
{"f_globals", (getter)THPPyInterpreterFrame_f_globals, NULL, NULL, NULL},
|
||||
{"f_builtins", (getter)THPPyInterpreterFrame_f_builtins, NULL, NULL, NULL},
|
||||
{"f_locals", (getter)THPPyInterpreterFrame_f_locals, NULL, NULL, NULL},
|
||||
{"f_code", (getter)THPPyInterpreterFrame_f_code, NULL, NULL, NULL},
|
||||
{"frame_obj", (getter)THPPyInterpreterFrame_frame_obj, NULL, NULL, NULL},
|
||||
{"previous", (getter)THPPyInterpreterFrame_previous, NULL, NULL, NULL},
|
||||
{"f_lasti", (getter)THPPyInterpreterFrame_f_lasti, NULL, NULL, NULL},
|
||||
{NULL}};
|
||||
|
||||
PyTypeObject THPPyInterpreterFrameType = {
|
||||
PyVarObject_HEAD_INIT(NULL, 0) "torch._C.dynamo.eval_frame._PyInterpreterFrame", /* tp_name */
|
||||
sizeof(THPPyInterpreterFrame), /* tp_basicsize */
|
||||
0, /* tp_itemsize */
|
||||
NULL, /* tp_dealloc */
|
||||
0, /* tp_vectorcall_offset */
|
||||
NULL, /* tp_getattr */
|
||||
NULL, /* tp_setattr */
|
||||
NULL, /* tp_reserved */
|
||||
NULL, /* tp_repr */
|
||||
NULL, /* tp_as_number */
|
||||
NULL, /* tp_as_sequence */
|
||||
NULL, /* tp_as_mapping */
|
||||
NULL, /* tp_hash */
|
||||
NULL, /* tp_call */
|
||||
NULL, /* tp_str */
|
||||
NULL, /* tp_getattro */
|
||||
NULL, /* tp_setattro */
|
||||
NULL, /* tp_as_buffer */
|
||||
Py_TPFLAGS_DEFAULT, /* tp_flags */
|
||||
NULL, /* tp_doc */
|
||||
NULL, /* tp_traverse */
|
||||
NULL, /* tp_clear */
|
||||
NULL, /* tp_richcompare */
|
||||
0, /* tp_weaklistoffset */
|
||||
NULL, /* tp_iter */
|
||||
NULL, /* tp_iternext */
|
||||
NULL, /* tp_methods */
|
||||
NULL, /* tp_members */
|
||||
THPDevice_properties, /* tp_getset */
|
||||
NULL, /* tp_base */
|
||||
NULL, /* tp_dict */
|
||||
NULL, /* tp_descr_get */
|
||||
NULL, /* tp_descr_set */
|
||||
0, /* tp_dictoffset */
|
||||
NULL, /* tp_init */
|
||||
NULL, /* tp_alloc */
|
||||
NULL, /* tp_new */
|
||||
};
|
||||
|
||||
|
||||
THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame) {
|
||||
PyTypeObject* type = (PyTypeObject*)&THPPyInterpreterFrameType;
|
||||
THPPyInterpreterFrame* self = (THPPyInterpreterFrame*)type->tp_alloc(type, 0);
|
||||
if (!self)
|
||||
return NULL;
|
||||
self->frame = frame;
|
||||
return self;
|
||||
}
|
||||
|
||||
|
||||
#else
|
||||
#define THP_EVAL_API_FRAME_OBJECT PyFrameObject
|
||||
|
||||
|
|
@ -292,8 +392,14 @@ inline static void enable_eval_frame_default(PyThreadState* tstate) {
|
|||
|
||||
static inline PyObject* call_callback(
|
||||
PyObject* callable,
|
||||
PyObject* frame,
|
||||
THP_EVAL_API_FRAME_OBJECT* _frame,
|
||||
long cache_len) {
|
||||
|
||||
#if IS_PYTHON_3_11_PLUS
|
||||
THPPyInterpreterFrame* frame = THPPyInterpreterFrame_New(_frame);
|
||||
#else
|
||||
PyFrameObject* frame = _frame;
|
||||
#endif
|
||||
PyObject* args = Py_BuildValue("(Ol)", frame, cache_len);
|
||||
if (args == NULL) {
|
||||
return NULL;
|
||||
|
|
@ -573,7 +679,7 @@ static PyObject* _custom_eval_frame(
|
|||
// TODO(alband): This is WRONG for python3.11+ we pass in a _PyInterpreterFrame
|
||||
// that gets re-interpreted as a PyObject (which it is NOT!)
|
||||
PyObject* result =
|
||||
call_callback(callback, (PyObject*)frame, cache_size(extra));
|
||||
call_callback(callback, frame, cache_size(extra));
|
||||
if (result == NULL) {
|
||||
// internal exception, returning here will leak the exception into user code
|
||||
// this is useful for debugging -- but we dont want it to happen outside of
|
||||
|
|
@ -761,5 +867,17 @@ PyObject* torch_c_dynamo_eval_frame_init(void) {
|
|||
|
||||
noargs = PyTuple_New(0);
|
||||
dotzerokey = PyUnicode_InternFromString(".0");
|
||||
return PyModule_Create(&_module);
|
||||
PyObject* module = PyModule_Create(&_module);
|
||||
|
||||
#if IS_PYTHON_3_11_PLUS
|
||||
if (PyType_Ready(&THPPyInterpreterFrameType) < 0) {
|
||||
return NULL;
|
||||
}
|
||||
Py_INCREF(&THPPyInterpreterFrameType);
|
||||
if (PyModule_AddObject(module, "_PyInterpreterFrame", (PyObject*)&THPPyInterpreterFrameType) != 0) {
|
||||
return NULL;
|
||||
}
|
||||
#endif
|
||||
|
||||
return module;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user