mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
RFC: https://github.com/pytorch/rfcs/pull/54 First commit is the contents of https://github.com/Quansight-Labs/numpy_pytorch_interop/ We have already been using this in core for the last few months as a external dependency. This PR pulls all these into core. In the next commits, I do a number of things in this order - Fix a few small issues - Make the tests that this PR adds pass - Bend backwards until lintrunner passes - Remove the optional dependency on `torch_np` and simply rely on the upstreamed code - Fix a number dynamo tests that were passing before (they were not tasting anything I think) and are not passing now. Missing from this PR (but not blocking): - Have a flag that deactivates tracing NumPy functions and simply breaks. There used to be one but after the merge stopped working and I removed it. @lezcano to investigate. - https://github.com/pytorch/pytorch/pull/106431#issuecomment-1667079543. @voznesenskym to submit a fix after we merge. All the tests in `tests/torch_np` take about 75s to run. This was a work by @ev-br, @rgommers @honno and I. I did not create this PR via ghstack (which would have been convenient) as this is a collaboration, and ghstack doesn't allow for shared contributions. Pull Request resolved: https://github.com/pytorch/pytorch/pull/106211 Approved by: https://github.com/ezyang
184 lines
7.6 KiB
Python
184 lines
7.6 KiB
Python
# Owner(s): ["module: functorch"]
|
|
|
|
"""Adapted from https://github.com/arogozhnikov/einops/blob/230ac1526c1f42c9e1f7373912c7f8047496df11/tests/test_ops.py.
|
|
|
|
MIT License
|
|
|
|
Copyright (c) 2018 Alex Rogozhnikov
|
|
|
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
of this software and associated documentation files (the "Software"), to deal
|
|
in the Software without restriction, including without limitation the rights
|
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
copies of the Software, and to permit persons to whom the Software is
|
|
furnished to do so, subject to the following conditions:
|
|
|
|
The above copyright notice and this permission notice shall be included in all
|
|
copies or substantial portions of the Software.
|
|
|
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
SOFTWARE.
|
|
"""
|
|
|
|
from typing import List, Tuple
|
|
|
|
import numpy as np
|
|
import torch
|
|
from functorch.einops import rearrange
|
|
from torch.testing._internal.common_utils import TestCase, run_tests
|
|
|
|
identity_patterns: List[str] = [
|
|
'...->...',
|
|
'a b c d e-> a b c d e',
|
|
'a b c d e ...-> ... a b c d e',
|
|
'a b c d e ...-> a ... b c d e',
|
|
'... a b c d e -> ... a b c d e',
|
|
'a ... e-> a ... e',
|
|
'a ... -> a ... ',
|
|
'a ... c d e -> a (...) c d e',
|
|
]
|
|
|
|
equivalent_rearrange_patterns: List[Tuple[str, str]] = [
|
|
('a b c d e -> (a b) c d e', 'a b ... -> (a b) ... '),
|
|
('a b c d e -> a b (c d) e', '... c d e -> ... (c d) e'),
|
|
('a b c d e -> a b c d e', '... -> ... '),
|
|
('a b c d e -> (a b c d e)', '... -> (...)'),
|
|
('a b c d e -> b (c d e) a', 'a b ... -> b (...) a'),
|
|
('a b c d e -> b (a c d) e', 'a b ... e -> b (a ...) e'),
|
|
]
|
|
|
|
|
|
class TestRearrange(TestCase):
|
|
def test_collapsed_ellipsis_errors_out(self) -> None:
|
|
x = torch.zeros([1, 1, 1, 1, 1])
|
|
rearrange(x, 'a b c d ... -> a b c ... d')
|
|
with self.assertRaises(ValueError):
|
|
rearrange(x, 'a b c d (...) -> a b c ... d')
|
|
|
|
rearrange(x, '... -> (...)')
|
|
with self.assertRaises(ValueError):
|
|
rearrange(x, '(...) -> (...)')
|
|
|
|
def test_ellipsis_ops(self) -> None:
|
|
x = torch.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6])
|
|
for pattern in identity_patterns:
|
|
torch.testing.assert_close(rearrange(x, pattern), x, msg=pattern)
|
|
|
|
for pattern1, pattern2 in equivalent_rearrange_patterns:
|
|
torch.testing.assert_close(rearrange(x, pattern1), rearrange(x, pattern2), msg=f"{pattern1} vs {pattern2}")
|
|
|
|
def test_rearrange_consistency(self) -> None:
|
|
shape = [1, 2, 3, 5, 7, 11]
|
|
x = torch.arange(int(np.prod(shape, dtype=int))).reshape(shape)
|
|
for pattern in [
|
|
'a b c d e f -> a b c d e f',
|
|
'b a c d e f -> a b d e f c',
|
|
'a b c d e f -> f e d c b a',
|
|
'a b c d e f -> (f e) d (c b a)',
|
|
'a b c d e f -> (f e d c b a)',
|
|
]:
|
|
result = rearrange(x, pattern)
|
|
self.assertEqual(len(np.setdiff1d(x, result)), 0)
|
|
self.assertIs(result.dtype, x.dtype)
|
|
|
|
result = rearrange(x, 'a b c d e f -> a (b) (c d e) f')
|
|
torch.testing.assert_close(x.flatten(), result.flatten())
|
|
|
|
result = rearrange(x, 'a aa aa1 a1a1 aaaa a11 -> a aa aa1 a1a1 aaaa a11')
|
|
torch.testing.assert_close(x, result)
|
|
|
|
result1 = rearrange(x, 'a b c d e f -> f e d c b a')
|
|
result2 = rearrange(x, 'f e d c b a -> a b c d e f')
|
|
torch.testing.assert_close(result1, result2)
|
|
|
|
result = rearrange(rearrange(x, 'a b c d e f -> (f d) c (e b) a'), '(f d) c (e b) a -> a b c d e f', b=2, d=5)
|
|
torch.testing.assert_close(x, result)
|
|
|
|
sizes = dict(zip('abcdef', shape))
|
|
temp = rearrange(x, 'a b c d e f -> (f d) c (e b) a', **sizes)
|
|
result = rearrange(temp, '(f d) c (e b) a -> a b c d e f', **sizes)
|
|
torch.testing.assert_close(x, result)
|
|
|
|
x2 = torch.arange(2 * 3 * 4).reshape([2, 3, 4])
|
|
result = rearrange(x2, 'a b c -> b c a')
|
|
self.assertEqual(x2[1, 2, 3], result[2, 3, 1])
|
|
self.assertEqual(x2[0, 1, 2], result[1, 2, 0])
|
|
|
|
def test_rearrange_permutations(self) -> None:
|
|
# tests random permutation of axes against two independent numpy ways
|
|
for n_axes in range(1, 10):
|
|
input = torch.arange(2 ** n_axes).reshape([2] * n_axes)
|
|
permutation = np.random.permutation(n_axes)
|
|
left_expression = ' '.join('i' + str(axis) for axis in range(n_axes))
|
|
right_expression = ' '.join('i' + str(axis) for axis in permutation)
|
|
expression = left_expression + ' -> ' + right_expression
|
|
result = rearrange(input, expression)
|
|
|
|
for pick in np.random.randint(0, 2, [10, n_axes]):
|
|
self.assertEqual(input[tuple(pick)], result[tuple(pick[permutation])])
|
|
|
|
for n_axes in range(1, 10):
|
|
input = torch.arange(2 ** n_axes).reshape([2] * n_axes)
|
|
permutation = np.random.permutation(n_axes)
|
|
left_expression = ' '.join('i' + str(axis) for axis in range(n_axes)[::-1])
|
|
right_expression = ' '.join('i' + str(axis) for axis in permutation[::-1])
|
|
expression = left_expression + ' -> ' + right_expression
|
|
result = rearrange(input, expression)
|
|
self.assertEqual(result.shape, input.shape)
|
|
expected_result = torch.zeros_like(input)
|
|
for original_axis, result_axis in enumerate(permutation):
|
|
expected_result |= ((input >> original_axis) & 1) << result_axis
|
|
|
|
torch.testing.assert_close(result, expected_result)
|
|
|
|
def test_concatenations_and_stacking(self) -> None:
|
|
for n_arrays in [1, 2, 5]:
|
|
shapes: List[List[int]] = [[], [1], [1, 1], [2, 3, 5, 7], [1] * 6]
|
|
for shape in shapes:
|
|
arrays1 = [torch.arange(i, i + np.prod(shape, dtype=int)).reshape(shape) for i in range(n_arrays)]
|
|
result0 = torch.stack(arrays1)
|
|
result1 = rearrange(arrays1, '...->...')
|
|
torch.testing.assert_close(result0, result1)
|
|
|
|
def test_unsqueeze(self) -> None:
|
|
x = torch.randn((2, 3, 4, 5))
|
|
actual = rearrange(x, 'b h w c -> b 1 h w 1 c')
|
|
expected = x.unsqueeze(1).unsqueeze(-2)
|
|
torch.testing.assert_close(actual, expected)
|
|
|
|
def test_squeeze(self) -> None:
|
|
x = torch.randn((2, 1, 3, 4, 1, 5))
|
|
actual = rearrange(x, 'b 1 h w 1 c -> b h w c')
|
|
expected = x.squeeze()
|
|
torch.testing.assert_close(actual, expected)
|
|
|
|
def test_0_dim_tensor(self) -> None:
|
|
x = expected = torch.tensor(1)
|
|
actual = rearrange(x, '->')
|
|
torch.testing.assert_close(actual, expected)
|
|
|
|
actual = rearrange(x, '... -> ...')
|
|
torch.testing.assert_close(actual, expected)
|
|
|
|
def test_dimension_mismatch_no_ellipsis(self) -> None:
|
|
x = torch.randn((1, 2, 3))
|
|
with self.assertRaises(ValueError):
|
|
rearrange(x, "a b -> b a")
|
|
|
|
with self.assertRaises(ValueError):
|
|
rearrange(x, "a b c d -> c d b a")
|
|
|
|
def test_dimension_mismatch_with_ellipsis(self) -> None:
|
|
x = torch.tensor(1)
|
|
with self.assertRaises(ValueError):
|
|
rearrange(x, "a ... -> ... a")
|
|
|
|
|
|
if __name__ == '__main__':
|
|
run_tests()
|