mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Fixes https://github.com/pytorch/pytorch/issues/17112 ```python print("good", torch.randn(5,5,5).max(1)) print("terrible", torch.randn(5,5,10).max(1)) print("not as good", torch.randn(5,5,500).max(1)) print ("old behaviour = gold standard") print(tuple(torch.randn(5,5,5).max(1))) print(tuple(torch.randn(5,5,10).max(1))) print(tuple(torch.randn(5,5,500).max(1))) ``` now gives ``` >>> import torch >>> print("good", torch.randn(5,5,5).max(1)) good torch.return_types.max( values=tensor([[ 1.2821, 1.8063, 1.8075, 1.3082, -0.1267], [ 0.3437, 0.7353, 1.2619, 0.7557, 1.6662], [ 0.8583, 1.8906, 1.0246, 1.7598, 1.1184], [ 1.7821, 0.0230, 0.9452, 1.0318, 1.0823], [ 0.4116, -0.0379, -0.1843, 1.4129, 1.8796]]), indices=tensor([[4, 4, 3, 2, 1], [1, 2, 4, 1, 1], [2, 4, 0, 2, 1], [0, 2, 0, 3, 1], [0, 4, 4, 4, 4]])) >>> print("terrible", torch.randn(5,5,10).max(1)) terrible torch.return_types.max( values=tensor([[ 2.1272, 1.3664, 2.2067, 1.3974, -0.0883, 1.2505, 1.0074, 1.1217, 0.3849, 0.6936], [ 0.6288, -0.4560, 1.2748, 1.5482, 1.2777, 1.6874, 0.7151, 0.6041, 1.3572, 1.6232], [ 1.6703, 1.0075, 1.6480, 2.2839, 1.3390, 0.4938, 1.6449, 1.7628, 0.8141, 2.5714], [ 0.7079, 1.8677, 3.2478, 1.5591, 2.4870, 0.8635, -0.1450, 1.6923, 1.4924, 1.6298], [ 2.4056, 0.8002, 0.9317, 0.7455, 0.7866, 2.1191, 0.3492, 1.2095, 1.8637, 1.7470]]), indices=tensor([[1, 1, 0, 0, 0, 0, 3, 4, 4, 4], [4, 2, 2, 1, 2, 2, 3, 1, 1, 3], [0, 3, 3, 0, 2, 1, 4, 1, 0, 1], [4, 1, 3, 0, 3, 2, 0, 1, 4, 3], [1, 0, 3, 2, 1, 0, 0, 1, 0, 1]])) >>> print("not as good", torch.randn(5,5,500).max(1)) not as good torch.return_types.max( values=tensor([[ 0.3877, 0.7873, 1.8701, ..., 0.5971, 1.6103, -0.3435], [ 1.1300, 2.2418, 1.4239, ..., 1.3943, 0.3872, 1.6475], [ 2.0656, 1.3136, 0.9896, ..., 2.3918, 0.8226, 1.0517], [ 1.1054, 0.9945, 1.0561, ..., 2.1039, 1.1524, 3.0304], [ 1.5041, 2.2809, 1.0883, ..., 0.8504, 2.4774, 1.1041]]), indices=tensor([[4, 3, 1, ..., 1, 4, 0], [4, 4, 4, ..., 3, 0, 3], [3, 0, 1, ..., 2, 2, 4], [0, 
1, 1, ..., 4, 2, 2], [1, 0, 4, ..., 2, 0, 2]])) >>> print ("old behaviour = gold standard") old behaviour = gold standard >>> print(tuple(torch.randn(5,5,5).max(1))) (tensor([[ 1.1908, 1.1807, 1.3151, 1.7184, 0.3556], [ 0.3798, 0.9213, 0.3001, 1.3087, 2.2419], [ 1.4233, 1.4814, 1.9900, 1.7744, 1.3059], [ 1.0026, -0.0330, 1.3061, 1.8730, 2.0685], [ 1.3041, 1.6458, 1.3449, 1.8948, 3.6206]]), tensor([[0, 4, 3, 4, 0], [1, 1, 4, 0, 4], [4, 1, 0, 3, 3], [1, 2, 1, 4, 0], [3, 3, 0, 3, 3]])) >>> print(tuple(torch.randn(5,5,10).max(1))) (tensor([[-0.1232, 0.8275, 0.6732, 1.1223, 0.8247, 1.2851, 1.6009, 1.9979, 1.9109, 0.7313], [ 0.2260, 0.5922, 1.6928, 0.6024, 2.1158, 3.0619, 0.5653, 0.7426, 0.8316, 0.6346], [ 0.4319, 0.2231, 0.5255, 1.7620, 1.1657, 0.8875, 0.5782, 0.6506, 0.5032, 1.7097], [ 0.4137, 1.7265, 1.4260, 2.0301, 1.2244, 0.7128, 2.6345, 0.7230, 1.3553, 1.6508], [ 1.0684, 1.7195, 1.4068, 0.7076, -0.0242, 0.8474, 0.8754, 1.7108, 0.2188, 1.1584]]), tensor([[0, 1, 3, 4, 2, 3, 4, 2, 1, 0], [1, 4, 0, 0, 3, 2, 0, 0, 3, 3], [2, 3, 1, 1, 4, 0, 1, 4, 4, 4], [0, 4, 1, 3, 2, 0, 2, 0, 3, 1], [1, 0, 0, 0, 0, 3, 3, 3, 2, 0]])) >>> print(tuple(torch.randn(5,5,500).max(1))) (tensor([[0.9395, 1.5572, 1.8797, ..., 2.0494, 0.8202, 0.9623], [1.7937, 0.7225, 1.8836, ..., 0.7927, 1.4976, 1.1813], [0.8558, 1.6943, 1.4192, ..., 0.8327, 1.9661, 0.4197], [1.2993, 1.4995, 0.9357, ..., 0.7810, 1.3030, 2.6216], [1.4206, 1.8315, 1.0338, ..., 1.4312, 1.3198, 1.5233]]), tensor([[0, 4, 3, ..., 3, 0, 2], [0, 1, 0, ..., 0, 4, 3], [3, 4, 3, ..., 3, 0, 0], [3, 2, 3, ..., 1, 2, 1], [1, 2, 4, ..., 3, 1, 3]])) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/17136 Differential Revision: D14250021 Pulled By: VitalyFedyunin fbshipit-source-id: aae72f03b35980063b1ac1f07b8353eddb0c8b93
108 lines
2.6 KiB
C++
108 lines
2.6 KiB
C++
/* Copyright Python Software Foundation
|
|
*
|
|
* This file is copy-pasted from CPython source code with modifications:
|
|
* https://github.com/python/cpython/blob/master/Objects/structseq.c
|
|
* https://github.com/python/cpython/blob/2.7/Objects/structseq.c
|
|
*
|
|
* The purpose of this file is to override the default behavior
|
|
* of repr of structseq to provide better printing for returned
|
|
* structseq objects from operators, aka torch.return_types.*
|
|
*
|
|
* For more information on copyright of CPython, see:
|
|
* https://github.com/python/cpython#copyright-and-license-information
|
|
*/
|
|
|
|
#include "torch/csrc/utils/structseq.h"
|
|
#include "torch/csrc/utils/six.h"
|
|
#include "structmember.h"
|
|
#include <sstream>
|
|
|
|
namespace torch {
|
|
namespace utils {
|
|
|
|
#if PY_MAJOR_VERSION == 2
|
|
PyObject *structseq_slice(PyStructSequence *obj, Py_ssize_t low, Py_ssize_t high)
|
|
{
|
|
PyTupleObject *np;
|
|
Py_ssize_t i;
|
|
|
|
if (low < 0) {
|
|
low = 0;
|
|
}
|
|
if (high > Py_SIZE(obj)) {
|
|
high = Py_SIZE(obj);
|
|
}
|
|
if (high < low) {
|
|
high = low;
|
|
}
|
|
np = (PyTupleObject *)PyTuple_New(high-low);
|
|
if (np == nullptr) {
|
|
return nullptr;
|
|
}
|
|
for(i = low; i < high; ++i) {
|
|
PyObject *v = obj->ob_item[i];
|
|
Py_INCREF(v);
|
|
PyTuple_SET_ITEM(np, i-low, v);
|
|
}
|
|
return (PyObject *) np;
|
|
}
|
|
|
|
#define PyUnicode_AsUTF8 PyString_AsString
|
|
#endif
|
|
|
|
PyObject *returned_structseq_repr(PyStructSequence *obj) {
|
|
PyTypeObject *typ = Py_TYPE(obj);
|
|
PyObject *tup = six::toTuple(obj);
|
|
if (tup == nullptr) {
|
|
return nullptr;
|
|
}
|
|
|
|
std::stringstream ss;
|
|
ss << typ->tp_name << "(\n";
|
|
size_t num_elements = Py_SIZE(obj);
|
|
|
|
for (int i=0; i < num_elements; i++) {
|
|
PyObject *val, *repr;
|
|
const char *cname, *crepr;
|
|
|
|
cname = typ->tp_members[i].name;
|
|
if (cname == nullptr) {
|
|
PyErr_Format(PyExc_SystemError, "In structseq_repr(), member %d name is nullptr"
|
|
" for type %.500s", i, typ->tp_name);
|
|
Py_DECREF(tup);
|
|
return nullptr;
|
|
}
|
|
|
|
val = PyTuple_GetItem(tup, i);
|
|
if (val == nullptr) {
|
|
Py_DECREF(tup);
|
|
return nullptr;
|
|
}
|
|
|
|
repr = PyObject_Repr(val);
|
|
if (repr == nullptr) {
|
|
Py_DECREF(tup);
|
|
return nullptr;
|
|
}
|
|
|
|
crepr = PyUnicode_AsUTF8(repr);
|
|
Py_DECREF(repr);
|
|
if (crepr == nullptr) {
|
|
Py_DECREF(tup);
|
|
return nullptr;
|
|
}
|
|
|
|
ss << cname << '=' << crepr;
|
|
if (i < num_elements - 1) {
|
|
ss << ",\n";
|
|
}
|
|
}
|
|
ss << ")";
|
|
|
|
Py_DECREF(tup);
|
|
return PyUnicode_FromString(ss.str().c_str());
|
|
}
|
|
|
|
}
|
|
}
|