Initial Python 3.12 build fixes (#106083)

This compiles with python 3.12
You can get numpy from https://anaconda.org/scientific-python-nightly-wheels/numpy/files so that you don't need to remove numpy from test files.

Basic core tests work but obviously dynamo and first class dims don't work.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106083
Approved by: https://github.com/ezyang
This commit is contained in:
albanD 2023-08-25 13:23:48 +00:00 committed by PyTorch MergeBot
parent 97a291f6bd
commit b9472decf8
7 changed files with 63 additions and 4 deletions

View File

@ -4,6 +4,16 @@
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <torch/csrc/utils/python_compat.h>
// Many APIs have changed/don't exist anymore
#if IS_PYTHON_3_12_PLUS
// Re-enable this some day
#else
#include "minpybind.h"
#include <frameobject.h>
#include <opcode.h>
@ -12,7 +22,6 @@
#include <iostream>
#include <vector>
//#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/Export.h>
#include <ATen/functorch/BatchedTensorImpl.h>
#include <ATen/functorch/DynamicLayer.h>
@ -3252,3 +3261,5 @@ PyObject* Dim_init() {
return nullptr;
}
}
#endif

View File

@ -14,3 +14,5 @@ filelock
networkx
jinja2
fsspec
# setuptools was removed from default python install
setuptools ; python_version >= "3.12"

View File

@ -1693,6 +1693,10 @@ def compile(model: Optional[Callable] = None, *,
"""
_C._log_api_usage_once("torch.compile")
# Temporary until we get proper support for python 3.12
if sys.version_info >= (3, 12):
raise RuntimeError("Dynamo is not supported on Python 3.12+")
# Decorator mode
if model is None:
def fn(model: Callable):

View File

@ -17,6 +17,13 @@
#if IS_PYTHON_3_11_PLUS
// Problem in CPython includes when mixing core and non-core build
// The fix was not backported to 3.12 so this is needed here
// https://github.com/python/cpython/issues/105268
#if IS_PYTHON_3_12_PLUS
#undef _PyGC_FINALIZED
#endif
#define Py_BUILD_CORE
#include <internal/pycore_pystate.h>
#define NEED_OPCODE_TABLES // To get _PyOpcode_Deopt
@ -28,7 +35,8 @@
// As a simple way to reduce the impact of ABI changes on the CPython side, this check forces
// us to manually re-check that the function didn't change on the next major version
#if PY_VERSION_HEX >= 0x030C0000 // 3.12
#error "Please ensure that the functions below still match the CPython implementation for 3.12"
// Spoiler alert: They don't! This will be done in a follow up.
// #error "Please ensure that the functions below still match the CPython implementation for 3.12"
#endif
// https://github.com/python/cpython/blob/a7715ccfba5b86ab09f86ec56ac3755c93b46b48/Objects/frameobject.c#L1079
@ -78,8 +86,13 @@ THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame) {
if (lasti < 0 && _Py_OPCODE(_PyCode_CODE(co)[0]) == COPY_FREE_VARS) {
/* Free vars have not been initialized -- Do that */
PyCodeObject *co = frame->f_code;
#if IS_PYTHON_3_12_PLUS
PyObject *closure = ((PyFunctionObject *)frame->f_funcobj)->func_closure;
int offset = co->co_nlocals + co->co_ncellvars;
#else
PyObject *closure = frame->f_func->func_closure;
int offset = co->co_nlocals + co->co_nplaincellvars;
#endif
for (int i = 0; i < co->co_nfreevars; ++i) {
PyObject *o = PyTuple_GET_ITEM(closure, i);
Py_INCREF(o);
@ -338,7 +351,11 @@ THP_PyFrame_Clear(_PyInterpreterFrame *frame)
}
Py_XDECREF(frame->frame_obj);
Py_XDECREF(frame->f_locals);
#if IS_PYTHON_3_12_PLUS
Py_DECREF(frame->f_funcobj);
#else
Py_DECREF(frame->f_func);
#endif
Py_DECREF(frame->f_code);
}

View File

@ -1,13 +1,12 @@
#pragma once
#include <Python.h>
#include <torch/csrc/utils/python_compat.h>
// Functions that need to be copied from the CPython source
// should go in cpython_defs.c. Copying is required when, e.g.,
// we need to call internal CPython functions that are not exposed.
#if IS_PYTHON_3_11_PLUS
#if IS_PYTHON_3_11_PLUS && !(IS_PYTHON_3_12_PLUS)
#include <internal/pycore_frame.h>

View File

@ -4,6 +4,13 @@
#include <opcode.h>
#include <stdbool.h>
// Problem in CPython includes when mixing core and non-core build
// The fix was not backported to 3.12 so this is needed here
// https://github.com/python/cpython/issues/105268
#if IS_PYTHON_3_12_PLUS
#undef _PyGC_FINALIZED
#endif
// see https://bugs.python.org/issue35886
#if PY_VERSION_HEX >= 0x03080000
#define Py_BUILD_CORE
@ -39,7 +46,11 @@ static PyObject* THPPyInterpreterFrame_##name(THPPyInterpreterFrame* self, PyObj
return res; \
}
#if IS_PYTHON_3_12_PLUS
DECLARE_PYOBJ_ATTR(f_funcobj)
#else
DECLARE_PYOBJ_ATTR(f_func)
#endif
DECLARE_PYOBJ_ATTR(f_globals)
DECLARE_PYOBJ_ATTR(f_builtins)
DECLARE_PYOBJ_ATTR(f_locals)
@ -79,7 +90,11 @@ static PyObject* THPPyInterpreterFrame_f_back(THPPyInterpreterFrame* self, PyObj
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
static struct PyGetSetDef THPPyInterpreterFrame_properties[] = {
#if IS_PYTHON_3_12_PLUS
{"f_func", (getter)THPPyInterpreterFrame_f_funcobj, NULL, NULL, NULL},
#else
{"f_func", (getter)THPPyInterpreterFrame_f_func, NULL, NULL, NULL},
#endif
{"f_globals", (getter)THPPyInterpreterFrame_f_globals, NULL, NULL, NULL},
{"f_builtins", (getter)THPPyInterpreterFrame_f_builtins, NULL, NULL, NULL},
{"f_locals", (getter)THPPyInterpreterFrame_f_locals, NULL, NULL, NULL},
@ -723,10 +738,18 @@ inline static PyObject* eval_custom_code(
// Generate Python function object and _PyInterpreterFrame in a way similar to
// https://github.com/python/cpython/blob/e715da6db1d1d70cd779dc48e1ba8110c51cc1bf/Python/ceval.c#L1130
#if IS_PYTHON_3_12_PLUS
// Most of these don't exist in 3.12 anymore.
// _PyFunction_CopyWithNewCode and _PyFrame_InitializeSpecials in particular
PyFunctionObject* func;
PyErr_SetString(PyExc_RuntimeError, "Dynamo is not supported in Python 3.12 yet");
return NULL;
#else
PyFunctionObject* func = _PyFunction_CopyWithNewCode((PyFunctionObject*) frame->f_func, code);
if (func == NULL) {
return NULL;
}
#endif
size_t size = code->co_nlocalsplus + code->co_stacksize + FRAME_SPECIALS_SIZE;
// THP_EVAL_API_FRAME_OBJECT (_PyInterpreterFrame) is a regular C struct, so
@ -739,7 +762,9 @@ inline static PyObject* eval_custom_code(
Py_INCREF(func);
// consumes reference to func
#if !(IS_PYTHON_3_12_PLUS)
_PyFrame_InitializeSpecials(shadow, func, NULL, code->co_nlocalsplus);
#endif
PyObject** fastlocals_old = frame->localsplus;
PyObject** fastlocals_new = shadow->localsplus;

View File

@ -10,6 +10,7 @@ extern "C" {
// PyTorch-only compat functions
#define IS_PYTHON_3_11_PLUS PY_VERSION_HEX >= 0x030B00C1
#define IS_PYTHON_3_12_PLUS PY_VERSION_HEX >= 0x030C0000
PYCAPI_COMPAT_STATIC_INLINE(int)
PyCode_GetNCellvars(PyCodeObject* code) {