pytorch/torch/csrc/utils/structseq.cpp
Gao, Xiang 11c89dde55 Allow structseq to be input of operators where tuple is expected (#17208)
Summary:
Currently, the following code raises an error on Python 2 because `ret` is a structseq, which is not a tuple:
```python
ret = a.max(dim=0)
ret1 = torch.max(a, dim=0, out=ret)
```

This PR modifies the tuple check in the Python arg parser to allow structseq objects to be inputs of operators where a tuple is expected, which makes the above code work.

Depends on: https://github.com/pytorch/pytorch/pull/17136
Partially fixes: https://github.com/pytorch/pytorch/issues/16813
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17208

Differential Revision: D14280198

Pulled By: VitalyFedyunin

fbshipit-source-id: beffebfd3951c4f5c7c8fe99a5847616a89491f3
2019-03-11 11:33:35 -07:00

103 lines
2.5 KiB
C++

/* Copyright Python Software Foundation
*
* This file is copy-pasted from CPython source code with modifications:
* https://github.com/python/cpython/blob/master/Objects/structseq.c
* https://github.com/python/cpython/blob/2.7/Objects/structseq.c
*
* The purpose of this file is to override the default repr behavior
* of structseq to provide better printing for structseq objects
* returned from operators, aka torch.return_types.*
*
* For more information on copyright of CPython, see:
* https://github.com/python/cpython#copyright-and-license-information
*/
#include "torch/csrc/utils/structseq.h"
#include "torch/csrc/utils/six.h"
#include "structmember.h"
#include <sstream>
namespace torch {
namespace utils {
#if PY_MAJOR_VERSION == 2
// Python 2 slice hook for structseq objects.
//
// Copies the items in the half-open range [low, high) into a new tuple.
// The range is clamped first:
// - `low` is raised to at least 0;
// - `high` is capped at Py_SIZE(obj);
// - the length is never negative.
//
// Returns a new reference, or nullptr if PyTuple_New fails.
PyObject *structseq_slice(PyStructSequence *obj, Py_ssize_t low, Py_ssize_t high)
{
  const Py_ssize_t size = Py_SIZE(obj);
  const Py_ssize_t begin = (low < 0) ? 0 : low;
  Py_ssize_t end = (high > size) ? size : high;
  if (end < begin) {
    // Empty (or inverted) request: produce a zero-length tuple.
    end = begin;
  }

  PyObject *result = PyTuple_New(end - begin);
  if (!result) {
    return nullptr;
  }

  for (Py_ssize_t src = begin, dst = 0; src < end; ++src, ++dst) {
    PyObject *item = obj->ob_item[src];
    // PyTuple_SET_ITEM steals a reference, so take one for the new tuple.
    Py_INCREF(item);
    PyTuple_SET_ITEM(result, dst, item);
  }
  return result;
}
// Python 2 has no PyUnicode_AsUTF8; alias it to PyString_AsString so the
// repr helper in this file compiles on both major versions.
#define PyUnicode_AsUTF8 PyString_AsString
#endif
// Replacement repr for structseq return types.
//
// On success, returns a new reference to a str laid out as:
//   <tp_name>(
//   <member0>=<repr(item0)>,
//   ...
//   <memberN>=<repr(itemN)>)
// Member names come from the type's tp_members table. Item values come from
// the tuple view produced by six::maybeAsTuple.
//
// On failure, returns nullptr with a Python exception set.
PyObject *returned_structseq_repr(PyStructSequence *obj) {
  PyTypeObject *typ = Py_TYPE(obj);
  THPObjectPtr tup = six::maybeAsTuple(obj);
  if (tup == nullptr) {
    return nullptr;
  }
  std::stringstream ss;
  ss << typ->tp_name << "(\n";
  // Use a signed count. This avoids a signed/unsigned comparison with `i`,
  // and `num_elements - 1` cannot wrap around.
  const Py_ssize_t num_elements = Py_SIZE(obj);
  for (int i = 0; i < num_elements; i++) {
    const char *cname = typ->tp_members[i].name;
    if (cname == nullptr) {
      PyErr_Format(PyExc_SystemError, "In structseq_repr(), member %d name is nullptr"
                   " for type %.500s", i, typ->tp_name);
      return nullptr;
    }
    // Borrowed reference; it stays alive as long as `tup` does.
    PyObject *val = PyTuple_GetItem(tup.get(), i);
    if (val == nullptr) {
      return nullptr;
    }
    // Keep the repr object owned until its text has been copied into `ss`.
    // PyUnicode_AsUTF8 (PyString_AsString on Python 2) returns a pointer into
    // the repr object's own buffer. Releasing `repr` before streaming `crepr`
    // would therefore read freed memory.
    THPObjectPtr repr(PyObject_Repr(val));
    if (repr.get() == nullptr) {
      return nullptr;
    }
    const char *crepr = PyUnicode_AsUTF8(repr.get());
    if (crepr == nullptr) {
      return nullptr;
    }
    ss << cname << '=' << crepr;
    if (i < num_elements - 1) {
      ss << ",\n";
    }
  }
  ss << ")";
  return PyUnicode_FromString(ss.str().c_str());
}
}
}