pytorch/torch/csrc/utils/python_numbers.h
gchanan 18ed2160b0
Use Index rather than Long for IntList parsing (#6674)
* Use Index rather than Long for IntList, so floating-point types convertible to ints fail the parsing.

Basically, our unpackLong code works with floating-point types that are convertible to ints, but this isn't often what you want (because of truncation).
What you actually want is to convert to an index, which will usually find such issues.

I made this the minimal change I could because:
1) I didn't want to change unpackLong because the existing code calls checkLong before unpackLong, so this should be a non-issue most of the time. Fixing this properly would require calling checkLong again, which would slow everything down.
2) The exception to the above is IntList, which only checks that 1) it is a tuple or 2) it is a varargs tuple (e.g. torch.ones(1, 2, 3)).

* Fix bug.

* Don't conflict tensor and IntList bindings.

* Change function to be consistent between python 2 and 3.

* Check Index.

* Move IntList overloads in legacy new functions to below Tensor overloads.
2018-04-26 19:13:23 -04:00

120 lines
3.0 KiB
C++

#pragma once
#include "torch/csrc/python_headers.h"
#include <stdint.h>
#include <stdexcept>
#include "torch/csrc/Exceptions.h"
// 2^53: every integer with magnitude up to this value has an exact double
// representation, but 2^53 + 1 does not. It is therefore the largest integer
// up to which doubles can represent all integers consecutively.
const int64_t DOUBLE_INT_MAX = int64_t{1} << 53;
// Boxes a signed 64-bit integer as a new Python integer object.
//
// On Python 2, a plain `int` (PyInt) is produced when the platform's C `long`
// is as wide as int64_t, or when the value lies in the int32 range. Otherwise
// a `long` is produced. On Python 3 the result is always PyLong_FromLongLong.
// Returns the new reference produced by the CPython constructor.
inline PyObject* THPUtils_packInt64(int64_t value) {
#if PY_MAJOR_VERSION == 2
  const bool long_holds_int64 = sizeof(long) == sizeof(int64_t);
  const bool fits_in_int32 = value >= INT32_MIN && value <= INT32_MAX;
  if (long_holds_int64 || fits_in_int32) {
    return PyInt_FromLong(static_cast<long>(value));
  }
#endif
  return PyLong_FromLongLong(value);
}
// Boxes an unsigned 64-bit integer as a new Python integer object.
//
// On Python 2, values no larger than INT32_MAX become a plain `int` (PyInt).
// Everything else, and every value on Python 3, goes through
// PyLong_FromUnsignedLongLong.
inline PyObject* THPUtils_packUInt64(uint64_t value) {
#if PY_MAJOR_VERSION == 2
  if (value > static_cast<uint64_t>(INT32_MAX)) {
    return PyLong_FromUnsignedLongLong(value);
  }
  return PyInt_FromLong(static_cast<long>(value));
#else
  return PyLong_FromUnsignedLongLong(value);
#endif
}
// Converts a double to a new Python integer object.
//
// On Python 2, doubles within [INT32_MIN, INT32_MAX] are cast to `long` and
// returned as a plain `int` (PyInt). All other values, and every value on
// Python 3, are passed to PyLong_FromDouble. For NaN/inf, the NULL-return
// behaviour of PyLong_FromDouble presumably applies; callers should check for
// NULL (TODO confirm at call sites).
inline PyObject* THPUtils_packDoubleAsInt(double value) {
#if PY_MAJOR_VERSION == 2
  const bool in_int32_range = value >= INT32_MIN && value <= INT32_MAX;
  if (in_int32_range) {
    return PyInt_FromLong(static_cast<long>(value));
  }
#endif
  return PyLong_FromDouble(value);
}
// Returns true if `obj` is a Python integer (Python 2 `int`/`long`, Python 3
// `int`) that is not a `bool`.
inline bool THPUtils_checkLong(PyObject* obj) {
  // bool is a subclass of int in Python, so it has to be excluded explicitly.
  if (PyBool_Check(obj)) {
    return false;
  }
#if PY_MAJOR_VERSION == 2
  return PyInt_Check(obj) || PyLong_Check(obj);
#else
  return PyLong_Check(obj);
#endif
}
// Reads a Python integer into an int64_t.
//
// Throws python_error if the CPython conversion raised an exception. Throws
// std::runtime_error if the value does not fit in a `long long`. The
// conversion is performed by PyLong_AsLongLongAndOverflow; callers typically
// validate the object with THPUtils_checkLong first.
inline int64_t THPUtils_unpackLong(PyObject* obj) {
  int overflow = 0;
  const long long result = PyLong_AsLongLongAndOverflow(obj, &overflow);
  // -1 is also a legitimate value, so only treat it as an error when a Python
  // exception is actually pending.
  if (result == -1 && PyErr_Occurred()) {
    throw python_error();
  }
  if (overflow) {
    throw std::runtime_error("Overflow when unpacking long");
  }
  return static_cast<int64_t>(result);
}
// Returns true if `obj` can be used as an integer index: a non-bool Python
// integer, or any object that PyNumber_Index accepts (e.g. one implementing
// __index__). Floats are rejected because PyNumber_Index fails on them.
// Any Python error raised by the probe is cleared before returning false.
inline bool THPUtils_checkIndex(PyObject *obj) {
  if (PyBool_Check(obj)) {
    return false;
  }
  if (THPUtils_checkLong(obj)) {
    return true;
  }
  THPObjectPtr as_index(PyNumber_Index(obj));
  if (as_index) {
    return true;
  }
  // The failed conversion left an exception set; this is only a query.
  PyErr_Clear();
  return false;
}
// Reads `obj` as an int64_t using Python's index protocol.
//
// Plain (non-bool) Python integers are unpacked directly. Any other object is
// first converted with PyNumber_Index, which rejects floats rather than
// truncating them. Throws python_error if that conversion fails. Overflow
// handling is that of THPUtils_unpackLong.
inline int64_t THPUtils_unpackIndex(PyObject* obj) {
  if (!THPUtils_checkLong(obj)) {
    // PyNumber_Index returns a new reference, owned by `index`.
    auto index = THPObjectPtr(PyNumber_Index(obj));
    if (index == nullptr) {
      throw python_error();
    }
    // Must unpack before `index` goes out of scope: its destructor drops
    // the only reference, so the converted object may be freed immediately
    // afterwards (use-after-free in the previous version).
    return THPUtils_unpackLong(index.get());
  }
  return THPUtils_unpackLong(obj);
}
// Returns true if `obj` is a Python float or a Python integer (Python 2
// `int`/`long`, Python 3 `int`). Unlike THPUtils_checkLong, bools are not
// excluded here, since PyLong_Check accepts them.
inline bool THPUtils_checkDouble(PyObject* obj) {
  if (PyFloat_Check(obj) || PyLong_Check(obj)) {
    return true;
  }
#if PY_MAJOR_VERSION == 2
  return PyInt_Check(obj);
#else
  return false;
#endif
}
// Reads a Python number into a C double.
//
// - Floats are returned unchanged.
// - Integers (Python 3 `int`; Python 2 `int` and `long`) are accepted only
//   when their magnitude is at most DOUBLE_INT_MAX (2^53), so the conversion
//   never silently loses precision.
//   - Values that do not fit in a `long long` throw
//     std::runtime_error("Overflow when unpacking double").
//   - Larger in-range values throw
//     std::runtime_error("Precision loss when unpacking double").
// - Any other object falls back to PyFloat_AsDouble. A Python error raised
//   there, or while reading an integer, is rethrown as python_error.
inline double THPUtils_unpackDouble(PyObject* obj) {
  if (PyFloat_Check(obj)) {
    return PyFloat_AS_DOUBLE(obj);
  }
  if (PyLong_Check(obj)) {
    int overflow;
    long long value = PyLong_AsLongLongAndOverflow(obj, &overflow);
    // -1 is a valid result; it only signals failure if an exception is set.
    if (value == -1 && PyErr_Occurred()) {
      throw python_error();
    }
    if (overflow != 0) {
      throw std::runtime_error("Overflow when unpacking double");
    }
    if (value > DOUBLE_INT_MAX || value < -DOUBLE_INT_MAX) {
      throw std::runtime_error("Precision loss when unpacking double");
    }
    return static_cast<double>(value);
  }
#if PY_MAJOR_VERSION == 2
  if (PyInt_Check(obj)) {
    // A Python 2 int stores a C long, which is 64 bits wide on LP64
    // platforms. It therefore needs the same precision check as the PyLong
    // branch above; otherwise Python 2 would silently round values that
    // Python 3 rejects.
    long value = PyInt_AS_LONG(obj);
    if (value > DOUBLE_INT_MAX || value < -DOUBLE_INT_MAX) {
      throw std::runtime_error("Precision loss when unpacking double");
    }
    return static_cast<double>(value);
  }
#endif
  double value = PyFloat_AsDouble(obj);
  if (value == -1 && PyErr_Occurred()) {
    throw python_error();
  }
  return value;
}