Fix c++ implementation of strip_function_call (#147436)

#143063 did not handle a couple of UCS cases (the 2-byte and 4-byte string kinds) and also had some bugs in the way it dealt with errors.

- Fix all the UCS handling (and make some of the common code more common)
- Make sure all the error paths return `nullptr`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147436
Approved by: https://github.com/jansel
This commit is contained in:
Aaron Orenstein 2025-02-18 16:58:41 -08:00 committed by PyTorch MergeBot
parent af31640391
commit be0df96b50

View File

@ -36,71 +36,82 @@ using torch::dynamo::autograd::torch_c_dynamo_compiled_autograd_init;
namespace {
// Returns true only when the code-point range [start, end) is exactly the
// four characters "None". T is the character storage type; each element is
// compared against a plain ASCII literal.
template <typename T>
bool unicode_is_literal_none(const T* start, const T* end) {
  constexpr char kLiteral[] = {'N', 'o', 'n', 'e'};
  if (end - start != 4) {
    return false;
  }
  for (int idx = 0; idx < 4; ++idx) {
    if (!(start[idx] == kLiteral[idx])) {
      return false;
    }
  }
  return true;
}
template <typename T>
THPObjectPtr strip_function_call_helper(
PyObject* original,
const T* const start,
size_t length) {
// This function is... not great.
const T* const end = start + length;
const T* curr = start;
for (auto p = start; p < end; ++p) {
if (*p == ' ' || *p == '(') {
curr = p + 1;
} else if (*p == ')' || *p == ',' || *p == '[' || *p == ']') {
if ((p > curr) && !unicode_is_literal_none(curr, p) &&
(Py_UNICODE_ISALPHA(*curr) || *curr == '_')) {
return strip_function_call_helper(nullptr, curr, p - curr);
}
// The original code skipped adding these chars...
struct StripFunctionCall {
template <typename T>
static bool unicode_is_literal_none(const T* start, const T* end) {
if (end != start + 4) {
return false;
}
return start[0] == 'N' && start[1] == 'o' && start[2] == 'n' &&
start[3] == 'e';
}
// strip_getattr_getitem
auto p = start;
for (; p < end; ++p) {
if (*p == '.' || *p == '[')
break;
// Takes a raw unicode pointer and length in code points and returns a
// new/owned reference. T will be one of Py_UCS1, Py_UCS2, Py_UCS4.
template <typename T>
static THPObjectPtr apply(
PyObject* original,
const T* const start,
size_t length) {
// This function (based on the original python) is... not great.
const T* const end = start + length;
const T* curr = start;
// All the code points we are interested in have the same values across UCS
// types.
for (auto p = start; p < end; ++p) {
if (*p == ' ' || *p == '(') {
curr = p + 1;
} else if (*p == ')' || *p == ',' || *p == '[' || *p == ']') {
if ((p > curr) && !unicode_is_literal_none(curr, p) &&
(Py_UNICODE_ISALPHA(*curr) || *curr == '_')) {
return apply(nullptr, curr, p - curr);
}
// The original code skipped adding these chars...
}
}
// strip_getattr_getitem
auto p = start;
for (; p < end; ++p) {
if (*p == '.' || *p == '[')
break;
}
if (p == end && original) {
return THPObjectPtr::dup(original);
}
return THPObjectPtr(
PyUnicode_FromKindAndData(sizeof(*start), start, p - start));
}
};
if (p == end && original) {
return THPObjectPtr::dup(original);
}
return THPObjectPtr(
PyUnicode_FromKindAndData(sizeof(*start), start, p - start));
}
THPObjectPtr strip_function_call(PyObject* name) {
if (!PyUnicode_Check(name)) {
template <typename F>
THPObjectPtr _unicode_dispatch(PyObject* str) {
if (!PyUnicode_Check(str)) {
PyErr_SetString(PyExc_TypeError, "String expected");
return THPObjectPtr::none();
return THPObjectPtr();
}
if (PyUnicode_READY(name) != 0)
return THPObjectPtr::none();
// Remove this when we're 3.10+
if (PyUnicode_READY(str) != 0) {
// Returns -1 with an exception set on failure
return THPObjectPtr();
}
auto length = PyUnicode_GET_LENGTH(name);
switch (PyUnicode_KIND(name)) {
auto length = PyUnicode_GET_LENGTH(str);
switch (PyUnicode_KIND(str)) {
case PyUnicode_1BYTE_KIND:
return strip_function_call_helper(
name, PyUnicode_1BYTE_DATA(name), length);
return F::apply(str, PyUnicode_1BYTE_DATA(str), length);
case PyUnicode_2BYTE_KIND:
throw std::runtime_error("unimplemented - 2byte");
return F::apply(str, PyUnicode_2BYTE_DATA(str), length);
case PyUnicode_4BYTE_KIND:
throw std::runtime_error("unimplemented - 4byte");
return F::apply(str, PyUnicode_4BYTE_DATA(str), length);
default:
throw std::runtime_error("unimplemented - bad value");
// This should be impossible - throw to make the compiler happy.
throw std::runtime_error("unreachable");
}
}
@ -116,57 +127,40 @@ bool _checkParamCount(size_t nargs, size_t expected) {
return true;
}
template <typename T>
THPObjectPtr is_valid_var_name_helper(const T* start, size_t length) {
if (length < 1)
return THPObjectPtr::dup(Py_False);
struct IsValidVarName {
// Takes a raw unicode pointer and length in code points and returns a
// new/owned reference. T will be one of Py_UCS1, Py_UCS2, Py_UCS4.
template <typename T>
static THPObjectPtr apply(PyObject* original, const T* start, size_t length) {
if (length < 1)
return THPObjectPtr::dup(Py_False);
// TODO: the original code is a bit odd... check it. It just checked that the
// string starts with alnum. Then if it's all digits then it logs a warning.
// TODO: the original code is a bit odd... check it. It just checked that
// the string starts with alnum. Then if it's all digits then it logs a
// warning.
if (!Py_UNICODE_ISALNUM(*start))
return THPObjectPtr::dup(Py_False);
while (length-- > 0) {
if (!Py_UNICODE_ISDIGIT(*start++)) {
return THPObjectPtr::dup(Py_True);
if (!Py_UNICODE_ISALNUM(*start))
return THPObjectPtr::dup(Py_False);
while (length-- > 0) {
if (!Py_UNICODE_ISDIGIT(*start++)) {
return THPObjectPtr::dup(Py_True);
}
}
}
// 2 == warning
return THPObjectPtr(THPUtils_packInt32(2));
}
THPObjectPtr is_valid_var_name(PyObject* name) {
if (!PyUnicode_Check(name)) {
PyErr_SetString(PyExc_TypeError, "String expected");
return THPObjectPtr::none();
// 2 == warning
return THPObjectPtr(THPUtils_packInt32(2));
}
if (PyUnicode_READY(name) != 0) {
return THPObjectPtr::none();
}
auto length = PyUnicode_GET_LENGTH(name);
switch (PyUnicode_KIND(name)) {
case PyUnicode_1BYTE_KIND:
return is_valid_var_name_helper(PyUnicode_1BYTE_DATA(name), length);
case PyUnicode_2BYTE_KIND:
return is_valid_var_name_helper(PyUnicode_2BYTE_DATA(name), length);
case PyUnicode_4BYTE_KIND:
return is_valid_var_name_helper(PyUnicode_4BYTE_DATA(name), length);
default:
throw std::runtime_error("unimplemented - bad value");
}
}
};
PyObject* _strip_function_call(
PyObject* self,
PyObject* const* args,
Py_ssize_t nargs) {
if (!_checkParamCount(nargs, 1)) {
return THPObjectPtr::none().release();
return nullptr;
}
return strip_function_call(args[0]).release();
auto result = _unicode_dispatch<StripFunctionCall>(args[0]);
return result.release();
}
PyObject* _is_valid_var_name(
@ -174,9 +168,10 @@ PyObject* _is_valid_var_name(
PyObject* const* args,
Py_ssize_t nargs) {
if (!_checkParamCount(nargs, 1)) {
return THPObjectPtr::none().release();
return nullptr;
}
return is_valid_var_name(args[0]).release();
auto result = _unicode_dispatch<IsValidVarName>(args[0]);
return result.release();
}
#define PYC_FN(x) ((PyCFunction)(void (*)()) & x)