mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
2e5a8cee82
1 Commits
| Author | SHA1 | Message | Date | |
|---|---|---|---|---|
|
|
2e5a8cee82 |
Customize the printing of namedtuple return (#17136)
Summary: Fixes https://github.com/pytorch/pytorch/issues/17112 ```python print("good", torch.randn(5,5,5).max(1)) print("terrible", torch.randn(5,5,10).max(1)) print("not as good", torch.randn(5,5,500).max(1)) print ("old behaviour = gold standard") print(tuple(torch.randn(5,5,5).max(1))) print(tuple(torch.randn(5,5,10).max(1))) print(tuple(torch.randn(5,5,500).max(1))) ``` now gives ``` >>> import torch >>> print("good", torch.randn(5,5,5).max(1)) good torch.return_types.max( values=tensor([[ 1.2821, 1.8063, 1.8075, 1.3082, -0.1267], [ 0.3437, 0.7353, 1.2619, 0.7557, 1.6662], [ 0.8583, 1.8906, 1.0246, 1.7598, 1.1184], [ 1.7821, 0.0230, 0.9452, 1.0318, 1.0823], [ 0.4116, -0.0379, -0.1843, 1.4129, 1.8796]]), indices=tensor([[4, 4, 3, 2, 1], [1, 2, 4, 1, 1], [2, 4, 0, 2, 1], [0, 2, 0, 3, 1], [0, 4, 4, 4, 4]])) >>> print("terrible", torch.randn(5,5,10).max(1)) terrible torch.return_types.max( values=tensor([[ 2.1272, 1.3664, 2.2067, 1.3974, -0.0883, 1.2505, 1.0074, 1.1217, 0.3849, 0.6936], [ 0.6288, -0.4560, 1.2748, 1.5482, 1.2777, 1.6874, 0.7151, 0.6041, 1.3572, 1.6232], [ 1.6703, 1.0075, 1.6480, 2.2839, 1.3390, 0.4938, 1.6449, 1.7628, 0.8141, 2.5714], [ 0.7079, 1.8677, 3.2478, 1.5591, 2.4870, 0.8635, -0.1450, 1.6923, 1.4924, 1.6298], [ 2.4056, 0.8002, 0.9317, 0.7455, 0.7866, 2.1191, 0.3492, 1.2095, 1.8637, 1.7470]]), indices=tensor([[1, 1, 0, 0, 0, 0, 3, 4, 4, 4], [4, 2, 2, 1, 2, 2, 3, 1, 1, 3], [0, 3, 3, 0, 2, 1, 4, 1, 0, 1], [4, 1, 3, 0, 3, 2, 0, 1, 4, 3], [1, 0, 3, 2, 1, 0, 0, 1, 0, 1]])) >>> print("not as good", torch.randn(5,5,500).max(1)) not as good torch.return_types.max( values=tensor([[ 0.3877, 0.7873, 1.8701, ..., 0.5971, 1.6103, -0.3435], [ 1.1300, 2.2418, 1.4239, ..., 1.3943, 0.3872, 1.6475], [ 2.0656, 1.3136, 0.9896, ..., 2.3918, 0.8226, 1.0517], [ 1.1054, 0.9945, 1.0561, ..., 2.1039, 1.1524, 3.0304], [ 1.5041, 2.2809, 1.0883, ..., 0.8504, 2.4774, 1.1041]]), indices=tensor([[4, 3, 1, ..., 1, 4, 0], [4, 4, 4, ..., 3, 0, 3], [3, 0, 1, ..., 2, 2, 4], [0, 1, 1, ..., 4, 2, 2], [1, 0, 4, ..., 2, 0, 2]])) >>> print ("old behaviour = gold standard") old behaviour = gold standard >>> print(tuple(torch.randn(5,5,5).max(1))) (tensor([[ 1.1908, 1.1807, 1.3151, 1.7184, 0.3556], [ 0.3798, 0.9213, 0.3001, 1.3087, 2.2419], [ 1.4233, 1.4814, 1.9900, 1.7744, 1.3059], [ 1.0026, -0.0330, 1.3061, 1.8730, 2.0685], [ 1.3041, 1.6458, 1.3449, 1.8948, 3.6206]]), tensor([[0, 4, 3, 4, 0], [1, 1, 4, 0, 4], [4, 1, 0, 3, 3], [1, 2, 1, 4, 0], [3, 3, 0, 3, 3]])) >>> print(tuple(torch.randn(5,5,10).max(1))) (tensor([[-0.1232, 0.8275, 0.6732, 1.1223, 0.8247, 1.2851, 1.6009, 1.9979, 1.9109, 0.7313], [ 0.2260, 0.5922, 1.6928, 0.6024, 2.1158, 3.0619, 0.5653, 0.7426, 0.8316, 0.6346], [ 0.4319, 0.2231, 0.5255, 1.7620, 1.1657, 0.8875, 0.5782, 0.6506, 0.5032, 1.7097], [ 0.4137, 1.7265, 1.4260, 2.0301, 1.2244, 0.7128, 2.6345, 0.7230, 1.3553, 1.6508], [ 1.0684, 1.7195, 1.4068, 0.7076, -0.0242, 0.8474, 0.8754, 1.7108, 0.2188, 1.1584]]), tensor([[0, 1, 3, 4, 2, 3, 4, 2, 1, 0], [1, 4, 0, 0, 3, 2, 0, 0, 3, 3], [2, 3, 1, 1, 4, 0, 1, 4, 4, 4], [0, 4, 1, 3, 2, 0, 2, 0, 3, 1], [1, 0, 0, 0, 0, 3, 3, 3, 2, 0]])) >>> print(tuple(torch.randn(5,5,500).max(1))) (tensor([[0.9395, 1.5572, 1.8797, ..., 2.0494, 0.8202, 0.9623], [1.7937, 0.7225, 1.8836, ..., 0.7927, 1.4976, 1.1813], [0.8558, 1.6943, 1.4192, ..., 0.8327, 1.9661, 0.4197], [1.2993, 1.4995, 0.9357, ..., 0.7810, 1.3030, 2.6216], [1.4206, 1.8315, 1.0338, ..., 1.4312, 1.3198, 1.5233]]), tensor([[0, 4, 3, ..., 3, 0, 2], [0, 1, 0, ..., 0, 4, 3], [3, 4, 3, ..., 3, 0, 0], [3, 2, 3, ..., 1, 2, 1], [1, 2, 4, ..., 3, 1, 3]])) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/17136 Differential Revision: D14250021 Pulled By: VitalyFedyunin fbshipit-source-id: aae72f03b35980063b1ac1f07b8353eddb0c8b93 |