mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix c++ implementation of strip_function_call (#147436)
#143063 did not handle a couple of UCS cases and also had some bugs in the way it dealt with errors.

- Fix all the UCS handling (and share more of the common code)
- Make sure all the error paths return `nullptr`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147436
Approved by: https://github.com/jansel
This commit is contained in:
parent
af31640391
commit
be0df96b50
|
|
@ -36,71 +36,82 @@ using torch::dynamo::autograd::torch_c_dynamo_compiled_autograd_init;
|
|||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
bool unicode_is_literal_none(const T* start, const T* end) {
|
||||
if (end != start + 4) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return start[0] == 'N' && start[1] == 'o' && start[2] == 'n' &&
|
||||
start[3] == 'e';
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
THPObjectPtr strip_function_call_helper(
|
||||
PyObject* original,
|
||||
const T* const start,
|
||||
size_t length) {
|
||||
// This function is... not great.
|
||||
const T* const end = start + length;
|
||||
const T* curr = start;
|
||||
for (auto p = start; p < end; ++p) {
|
||||
if (*p == ' ' || *p == '(') {
|
||||
curr = p + 1;
|
||||
} else if (*p == ')' || *p == ',' || *p == '[' || *p == ']') {
|
||||
if ((p > curr) && !unicode_is_literal_none(curr, p) &&
|
||||
(Py_UNICODE_ISALPHA(*curr) || *curr == '_')) {
|
||||
return strip_function_call_helper(nullptr, curr, p - curr);
|
||||
}
|
||||
// The original code skipped adding these chars...
|
||||
struct StripFunctionCall {
|
||||
template <typename T>
|
||||
static bool unicode_is_literal_none(const T* start, const T* end) {
|
||||
if (end != start + 4) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return start[0] == 'N' && start[1] == 'o' && start[2] == 'n' &&
|
||||
start[3] == 'e';
|
||||
}
|
||||
|
||||
// strip_getattr_getitem
|
||||
auto p = start;
|
||||
for (; p < end; ++p) {
|
||||
if (*p == '.' || *p == '[')
|
||||
break;
|
||||
// Takes a raw unicode pointer and length in code points and returns a
|
||||
// new/owned reference. T will be one of Py_UCS1, Py_UCS2, Py_UCS4.
|
||||
template <typename T>
|
||||
static THPObjectPtr apply(
|
||||
PyObject* original,
|
||||
const T* const start,
|
||||
size_t length) {
|
||||
// This function (based on the original python) is... not great.
|
||||
const T* const end = start + length;
|
||||
const T* curr = start;
|
||||
// All the code points we are interested in have the same values across UCS
|
||||
// types.
|
||||
for (auto p = start; p < end; ++p) {
|
||||
if (*p == ' ' || *p == '(') {
|
||||
curr = p + 1;
|
||||
} else if (*p == ')' || *p == ',' || *p == '[' || *p == ']') {
|
||||
if ((p > curr) && !unicode_is_literal_none(curr, p) &&
|
||||
(Py_UNICODE_ISALPHA(*curr) || *curr == '_')) {
|
||||
return apply(nullptr, curr, p - curr);
|
||||
}
|
||||
// The original code skipped adding these chars...
|
||||
}
|
||||
}
|
||||
|
||||
// strip_getattr_getitem
|
||||
auto p = start;
|
||||
for (; p < end; ++p) {
|
||||
if (*p == '.' || *p == '[')
|
||||
break;
|
||||
}
|
||||
|
||||
if (p == end && original) {
|
||||
return THPObjectPtr::dup(original);
|
||||
}
|
||||
|
||||
return THPObjectPtr(
|
||||
PyUnicode_FromKindAndData(sizeof(*start), start, p - start));
|
||||
}
|
||||
};
|
||||
|
||||
if (p == end && original) {
|
||||
return THPObjectPtr::dup(original);
|
||||
}
|
||||
|
||||
return THPObjectPtr(
|
||||
PyUnicode_FromKindAndData(sizeof(*start), start, p - start));
|
||||
}
|
||||
|
||||
THPObjectPtr strip_function_call(PyObject* name) {
|
||||
if (!PyUnicode_Check(name)) {
|
||||
template <typename F>
|
||||
THPObjectPtr _unicode_dispatch(PyObject* str) {
|
||||
if (!PyUnicode_Check(str)) {
|
||||
PyErr_SetString(PyExc_TypeError, "String expected");
|
||||
return THPObjectPtr::none();
|
||||
return THPObjectPtr();
|
||||
}
|
||||
|
||||
if (PyUnicode_READY(name) != 0)
|
||||
return THPObjectPtr::none();
|
||||
// Remove this when we're 3.10+
|
||||
if (PyUnicode_READY(str) != 0) {
|
||||
// Returns -1 with an exception set on failure
|
||||
return THPObjectPtr();
|
||||
}
|
||||
|
||||
auto length = PyUnicode_GET_LENGTH(name);
|
||||
switch (PyUnicode_KIND(name)) {
|
||||
auto length = PyUnicode_GET_LENGTH(str);
|
||||
|
||||
switch (PyUnicode_KIND(str)) {
|
||||
case PyUnicode_1BYTE_KIND:
|
||||
return strip_function_call_helper(
|
||||
name, PyUnicode_1BYTE_DATA(name), length);
|
||||
return F::apply(str, PyUnicode_1BYTE_DATA(str), length);
|
||||
case PyUnicode_2BYTE_KIND:
|
||||
throw std::runtime_error("unimplemented - 2byte");
|
||||
return F::apply(str, PyUnicode_2BYTE_DATA(str), length);
|
||||
case PyUnicode_4BYTE_KIND:
|
||||
throw std::runtime_error("unimplemented - 4byte");
|
||||
return F::apply(str, PyUnicode_4BYTE_DATA(str), length);
|
||||
default:
|
||||
throw std::runtime_error("unimplemented - bad value");
|
||||
// This should be impossible - throw to make the compiler happy.
|
||||
throw std::runtime_error("unreachable");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -116,57 +127,40 @@ bool _checkParamCount(size_t nargs, size_t expected) {
|
|||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
THPObjectPtr is_valid_var_name_helper(const T* start, size_t length) {
|
||||
if (length < 1)
|
||||
return THPObjectPtr::dup(Py_False);
|
||||
struct IsValidVarName {
|
||||
// Takes a raw unicode pointer and length in code points and returns a
|
||||
// new/owned reference. T will be one of Py_UCS1, Py_UCS2, Py_UCS4.
|
||||
template <typename T>
|
||||
static THPObjectPtr apply(PyObject* original, const T* start, size_t length) {
|
||||
if (length < 1)
|
||||
return THPObjectPtr::dup(Py_False);
|
||||
|
||||
// TODO: the original code is a bit odd... check it. It just checked that the
|
||||
// string starts with alnum. Then if it's all digits then it logs a warning.
|
||||
// TODO: the original code is a bit odd... check it. It just checked that
|
||||
// the string starts with alnum. Then if it's all digits then it logs a
|
||||
// warning.
|
||||
|
||||
if (!Py_UNICODE_ISALNUM(*start))
|
||||
return THPObjectPtr::dup(Py_False);
|
||||
while (length-- > 0) {
|
||||
if (!Py_UNICODE_ISDIGIT(*start++)) {
|
||||
return THPObjectPtr::dup(Py_True);
|
||||
if (!Py_UNICODE_ISALNUM(*start))
|
||||
return THPObjectPtr::dup(Py_False);
|
||||
while (length-- > 0) {
|
||||
if (!Py_UNICODE_ISDIGIT(*start++)) {
|
||||
return THPObjectPtr::dup(Py_True);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2 == warning
|
||||
return THPObjectPtr(THPUtils_packInt32(2));
|
||||
}
|
||||
|
||||
THPObjectPtr is_valid_var_name(PyObject* name) {
|
||||
if (!PyUnicode_Check(name)) {
|
||||
PyErr_SetString(PyExc_TypeError, "String expected");
|
||||
return THPObjectPtr::none();
|
||||
// 2 == warning
|
||||
return THPObjectPtr(THPUtils_packInt32(2));
|
||||
}
|
||||
|
||||
if (PyUnicode_READY(name) != 0) {
|
||||
return THPObjectPtr::none();
|
||||
}
|
||||
|
||||
auto length = PyUnicode_GET_LENGTH(name);
|
||||
switch (PyUnicode_KIND(name)) {
|
||||
case PyUnicode_1BYTE_KIND:
|
||||
return is_valid_var_name_helper(PyUnicode_1BYTE_DATA(name), length);
|
||||
case PyUnicode_2BYTE_KIND:
|
||||
return is_valid_var_name_helper(PyUnicode_2BYTE_DATA(name), length);
|
||||
case PyUnicode_4BYTE_KIND:
|
||||
return is_valid_var_name_helper(PyUnicode_4BYTE_DATA(name), length);
|
||||
default:
|
||||
throw std::runtime_error("unimplemented - bad value");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
PyObject* _strip_function_call(
|
||||
PyObject* self,
|
||||
PyObject* const* args,
|
||||
Py_ssize_t nargs) {
|
||||
if (!_checkParamCount(nargs, 1)) {
|
||||
return THPObjectPtr::none().release();
|
||||
return nullptr;
|
||||
}
|
||||
return strip_function_call(args[0]).release();
|
||||
auto result = _unicode_dispatch<StripFunctionCall>(args[0]);
|
||||
return result.release();
|
||||
}
|
||||
|
||||
PyObject* _is_valid_var_name(
|
||||
|
|
@ -174,9 +168,10 @@ PyObject* _is_valid_var_name(
|
|||
PyObject* const* args,
|
||||
Py_ssize_t nargs) {
|
||||
if (!_checkParamCount(nargs, 1)) {
|
||||
return THPObjectPtr::none().release();
|
||||
return nullptr;
|
||||
}
|
||||
return is_valid_var_name(args[0]).release();
|
||||
auto result = _unicode_dispatch<IsValidVarName>(args[0]);
|
||||
return result.release();
|
||||
}
|
||||
|
||||
#define PYC_FN(x) ((PyCFunction)(void (*)()) & x)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user