diff --git a/functorch/csrc/dim/dim.cpp b/functorch/csrc/dim/dim.cpp index db59983e5d0..b1c034aa1a3 100644 --- a/functorch/csrc/dim/dim.cpp +++ b/functorch/csrc/dim/dim.cpp @@ -21,6 +21,11 @@ #include "arena.h" #include "python_variable_simple.h" +#if IS_PYTHON_3_11_PLUS +#define Py_BUILD_CORE +#include "internal/pycore_opcode.h" +#undef Py_BUILD_CORE +#endif // C++ API functions for objects to // * construct the object, returning a ref-counted handle @@ -1433,33 +1438,6 @@ bool relevant_op(_Py_CODEUNIT c) { } } -py::object getname(PyCodeObject* code, _Py_CODEUNIT c) { - PyObject* names = NULL; - switch(_Py_OPCODE(c)) { - case STORE_NAME: - case STORE_GLOBAL: - names = code->co_names; - break; - case STORE_FAST: -#if PY_VERSION_HEX < 0x030b0000 - names = code->co_varnames; -#else - names = PyCode_GetVarnames(code); -#endif - break; - case STORE_DEREF: -#if PY_VERSION_HEX < 0x030b0000 - names = code->co_cellvars; -#else - names = PyCode_GetCellvars(code); -#endif - break; - default: - return py::object(); - } - return py::object::steal(PySequence_GetItem(names, _Py_OPARG(c))); -} - py::object create_dim(py::object name, py::handle size) { auto d = Dim::create(std::move(name)); if (!py::is_none(size)) { @@ -1487,10 +1465,54 @@ py::object create_dimlist(py::object name, py::handle size) { // Python wrappers that make new reflection primitives available for older runtimes -#if PY_VERSION_HEX < 0x030b0000 +#if !(IS_PYTHON_3_11_PLUS) #define _PyCode_CODE(CO) ((_Py_CODEUNIT*)PyBytes_AS_STRING((CO)->co_code)) #endif +struct PyInstDecoder { + PyInstDecoder(PyCodeObject* code_object, int lasti) + : code_object_(code_object), code_(_PyCode_CODE(code_object)), offset_(lasti / sizeof(_Py_CODEUNIT)) {} + void next() { + #if IS_PYTHON_3_11_PLUS + offset_ += _PyOpcode_Caches[opcode()]; + #endif + offset_ += 1; + } + int opcode() { + auto r = _Py_OPCODE(code_[offset_]); + #if IS_PYTHON_3_11_PLUS + r = _PyOpcode_Deopt[r]; + #endif + return r; + } + int oparg() { + return _Py_OPARG(code_[offset_]); + } + + py::object name() { + py::object names; + switch(opcode()) { + case STORE_NAME: + case STORE_GLOBAL: + names = py::object::borrow(code_object_->co_names); + break; + case STORE_FAST: + names = py::object::steal(PyCode_GetVarnames(code_object_)); + break; + case STORE_DEREF: + names = py::object::steal(PyCode_GetCellvars(code_object_)); + break; + default: + return py::object(); + } + return py::object::steal(PySequence_GetItem(names.ptr(), oparg())); + } +private: + PyCodeObject* code_object_; + _Py_CODEUNIT* code_; + int offset_; +}; + template static PyObject* _dims(PyObject *self, PyObject *const *args, @@ -1518,15 +1540,22 @@ static PyObject* _dims(PyObject *self, PyThreadState* state = PyThreadState_GET(); auto f = py::obj::steal(PyThreadState_GetFrame(state)); auto c = py::obj::steal(PyFrame_GetCode(f.ptr())); - auto code = _PyCode_CODE(c.ptr()); - int first = PyFrame_GetLasti(f.ptr()) / 2 + 1; - auto unpack = code[first]; - int names_start = first; - if (relevant_op(unpack)) { + auto lasti = PyFrame_GetLasti(f.ptr()); + auto decoder = PyInstDecoder(c.ptr(), lasti); + #if IS_PYTHON_3_11_PLUS + // When py3.11 adapts bytecode lasti points to the precall + // rather than the call instruction after it + if (decoder.opcode() == PRECALL) { + decoder.next(); + } + #endif + decoder.next(); + + if (relevant_op(decoder.opcode())) { found_ndims = 1; - } else if (_Py_OPCODE(unpack) == UNPACK_SEQUENCE) { - found_ndims = _Py_OPARG(unpack); - names_start++; + } else if (decoder.opcode() == UNPACK_SEQUENCE) { + found_ndims = decoder.oparg(); + decoder.next(); } if (specified_ndims == -1) { @@ -1542,11 +1571,13 @@ static PyObject* _dims(PyObject *self, auto genobject = [&](int i) -> py::object { py::object name; if (i < found_ndims) { - name = getname(c.ptr(), code[names_start + i]); + name = decoder.name(); } if (!name.ptr()) { name = py::unicode_from_format("d%d", i); found_ndims = 0; // once we fail at finding a name, we can find any more + } else { + decoder.next(); } return create_object(std::move(name), sizes != -1 ? py::sequence_view(py_sizes)[i] : py::handle(Py_None)); }; diff --git a/test/functorch/test_dims.py b/test/functorch/test_dims.py index afd4709416b..b642c8da4e1 100644 --- a/test/functorch/test_dims.py +++ b/test/functorch/test_dims.py @@ -195,6 +195,14 @@ class TestMin(TestCase): embeddings[indices[i], f] += values[i, f] + def test_adapt(self): + def f(): + ci, co = dims() + # python 3.11 adapts bytecode after a number of iterations + # check that we still match names correctly + for i in range(10): + f() + @skipIf(not TEST_CUDA, "no CUDA") def test_attn_cuda(self): # size from the BERT paper, 90% pretraining of sequence length 128 @@ -624,6 +632,5 @@ class TestMinFunctorchOnly(TestMin): for n in skip_functorch_only: setattr(TestMinFunctorchOnly, n, skip("skip_functorch_only")(lambda self: None)) - if __name__ == '__main__': run_tests()