Clean-up test_indexing.py after Tensor/Variable merge (#6433)

This commit is contained in:
Sam Gross 2018-04-10 14:03:14 -04:00 committed by GitHub
parent aea31131e5
commit 64e94814da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,88 +1,87 @@
from common import TestCase, run_tests from common import TestCase, run_tests
import unittest
import torch import torch
import warnings import warnings
from torch.autograd import Variable from torch import tensor
class TestIndexing(TestCase): class TestIndexing(TestCase):
def test_single_int(self): def test_single_int(self):
v = Variable(torch.randn(5, 7, 3)) v = torch.randn(5, 7, 3)
self.assertEqual(v[4].shape, (7, 3)) self.assertEqual(v[4].shape, (7, 3))
def test_multiple_int(self): def test_multiple_int(self):
v = Variable(torch.randn(5, 7, 3)) v = torch.randn(5, 7, 3)
self.assertEqual(v[4].shape, (7, 3)) self.assertEqual(v[4].shape, (7, 3))
self.assertEqual(v[4, :, 1].shape, (7,)) self.assertEqual(v[4, :, 1].shape, (7,))
def test_none(self): def test_none(self):
v = Variable(torch.randn(5, 7, 3)) v = torch.randn(5, 7, 3)
self.assertEqual(v[None].shape, (1, 5, 7, 3)) self.assertEqual(v[None].shape, (1, 5, 7, 3))
self.assertEqual(v[:, None].shape, (5, 1, 7, 3)) self.assertEqual(v[:, None].shape, (5, 1, 7, 3))
self.assertEqual(v[:, None, None].shape, (5, 1, 1, 7, 3)) self.assertEqual(v[:, None, None].shape, (5, 1, 1, 7, 3))
self.assertEqual(v[..., None].shape, (5, 7, 3, 1)) self.assertEqual(v[..., None].shape, (5, 7, 3, 1))
def test_step(self): def test_step(self):
v = Variable(torch.arange(10)) v = torch.arange(10)
self.assertEqual(v[::1], v) self.assertEqual(v[::1], v)
self.assertEqual(v[::2].data.tolist(), [0, 2, 4, 6, 8]) self.assertEqual(v[::2].tolist(), [0, 2, 4, 6, 8])
self.assertEqual(v[::3].data.tolist(), [0, 3, 6, 9]) self.assertEqual(v[::3].tolist(), [0, 3, 6, 9])
self.assertEqual(v[::11].data.tolist(), [0]) self.assertEqual(v[::11].tolist(), [0])
self.assertEqual(v[1:6:2].data.tolist(), [1, 3, 5]) self.assertEqual(v[1:6:2].tolist(), [1, 3, 5])
def test_step_assignment(self): def test_step_assignment(self):
v = Variable(torch.zeros(4, 4)) v = torch.zeros(4, 4)
v[0, 1::2] = Variable(torch.Tensor([3, 4])) v[0, 1::2] = torch.tensor([3., 4.])
self.assertEqual(v[0].data.tolist(), [0, 3, 0, 4]) self.assertEqual(v[0].tolist(), [0, 3, 0, 4])
self.assertEqual(v[1:].data.sum(), 0) self.assertEqual(v[1:].sum(), 0)
def test_byte_mask(self): def test_byte_mask(self):
v = Variable(torch.randn(5, 7, 3)) v = torch.randn(5, 7, 3)
mask = Variable(torch.ByteTensor([1, 0, 1, 1, 0])) mask = torch.ByteTensor([1, 0, 1, 1, 0])
self.assertEqual(v[mask].shape, (3, 7, 3)) self.assertEqual(v[mask].shape, (3, 7, 3))
self.assertEqual(v[mask], torch.stack([v[0], v[2], v[3]])) self.assertEqual(v[mask], torch.stack([v[0], v[2], v[3]]))
v = Variable(torch.Tensor([1])) v = torch.tensor([1.])
self.assertEqual(v[v == 0], Variable(torch.Tensor())) self.assertEqual(v[v == 0], torch.tensor([]))
def test_multiple_byte_mask(self): def test_multiple_byte_mask(self):
v = Variable(torch.randn(5, 7, 3)) v = torch.randn(5, 7, 3)
# note: these broadcast together and are transposed to the first dim # note: these broadcast together and are transposed to the first dim
mask1 = Variable(torch.ByteTensor([1, 0, 1, 1, 0])) mask1 = torch.ByteTensor([1, 0, 1, 1, 0])
mask2 = Variable(torch.ByteTensor([1, 1, 1])) mask2 = torch.ByteTensor([1, 1, 1])
self.assertEqual(v[mask1, :, mask2].shape, (3, 7)) self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
def test_byte_mask2d(self): def test_byte_mask2d(self):
v = Variable(torch.randn(5, 7, 3)) v = torch.randn(5, 7, 3)
c = Variable(torch.randn(5, 7)) c = torch.randn(5, 7)
num_ones = (c > 0).data.sum() num_ones = (c > 0).sum()
r = v[c > 0] r = v[c > 0]
self.assertEqual(r.shape, (num_ones, 3)) self.assertEqual(r.shape, (num_ones, 3))
def test_int_indices(self): def test_int_indices(self):
v = Variable(torch.randn(5, 7, 3)) v = torch.randn(5, 7, 3)
self.assertEqual(v[[0, 4, 2]].shape, (3, 7, 3)) self.assertEqual(v[[0, 4, 2]].shape, (3, 7, 3))
self.assertEqual(v[:, [0, 4, 2]].shape, (5, 3, 3)) self.assertEqual(v[:, [0, 4, 2]].shape, (5, 3, 3))
self.assertEqual(v[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3)) self.assertEqual(v[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3))
def test_int_indices2d(self): def test_int_indices2d(self):
# From the NumPy indexing example # From the NumPy indexing example
x = Variable(torch.arange(0, 12).view(4, 3)) x = torch.arange(0, 12).view(4, 3)
rows = Variable(torch.LongTensor([[0, 0], [3, 3]])) rows = torch.tensor([[0, 0], [3, 3]])
columns = Variable(torch.LongTensor([[0, 2], [0, 2]])) columns = torch.tensor([[0, 2], [0, 2]])
self.assertEqual(x[rows, columns].data.tolist(), [[0, 2], [9, 11]]) self.assertEqual(x[rows, columns].tolist(), [[0, 2], [9, 11]])
def test_int_indices_broadcast(self): def test_int_indices_broadcast(self):
# From the NumPy indexing example # From the NumPy indexing example
x = Variable(torch.arange(0, 12).view(4, 3)) x = torch.arange(0, 12).view(4, 3)
rows = Variable(torch.LongTensor([0, 3])) rows = torch.tensor([0, 3])
columns = Variable(torch.LongTensor([0, 2])) columns = torch.tensor([0, 2])
result = x[rows[:, None], columns] result = x[rows[:, None], columns]
self.assertEqual(result.data.tolist(), [[0, 2], [9, 11]]) self.assertEqual(result.tolist(), [[0, 2], [9, 11]])
def test_empty_index(self): def test_empty_index(self):
x = Variable(torch.arange(0, 12).view(4, 3)) x = torch.arange(0, 12).view(4, 3)
idx = Variable(torch.LongTensor()) idx = torch.tensor([], dtype=torch.long)
self.assertEqual(x[idx].numel(), 0) self.assertEqual(x[idx].numel(), 0)
# empty assignment should have no effect but not throw an exception # empty assignment should have no effect but not throw an exception
@ -98,7 +97,7 @@ class TestIndexing(TestCase):
true = torch.tensor(1, dtype=torch.uint8) true = torch.tensor(1, dtype=torch.uint8)
false = torch.tensor(0, dtype=torch.uint8) false = torch.tensor(0, dtype=torch.uint8)
tensors = [Variable(torch.randn(2, 3)), torch.tensor(3)] tensors = [torch.randn(2, 3), torch.tensor(3)]
for a in tensors: for a in tensors:
self.assertNotEqual(a.data_ptr(), a[True].data_ptr()) self.assertNotEqual(a.data_ptr(), a[True].data_ptr())
@ -112,7 +111,7 @@ class TestIndexing(TestCase):
true = torch.tensor(1, dtype=torch.uint8) true = torch.tensor(1, dtype=torch.uint8)
false = torch.tensor(0, dtype=torch.uint8) false = torch.tensor(0, dtype=torch.uint8)
tensors = [Variable(torch.randn(2, 3)), torch.tensor(3)] tensors = [torch.randn(2, 3), torch.tensor(3)]
for a in tensors: for a in tensors:
# prefix with a 1,1, to ensure we are compatible with numpy which cuts off prefix 1s # prefix with a 1,1, to ensure we are compatible with numpy which cuts off prefix 1s
@ -136,21 +135,21 @@ class TestIndexing(TestCase):
a[:] = neg_ones_expanded * 5 a[:] = neg_ones_expanded * 5
def test_setitem_expansion_error(self): def test_setitem_expansion_error(self):
true = torch.tensor(1, dtype=torch.uint8) true = torch.tensor(True)
a = Variable(torch.randn(2, 3)) a = torch.randn(2, 3)
# check prefix with non-1s doesn't work # check prefix with non-1s doesn't work
a_expanded = a.expand(torch.Size([5, 1]) + a.size()) a_expanded = a.expand(torch.Size([5, 1]) + a.size())
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
a[True] = a_expanded a[True] = a_expanded
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
a[true] = torch.autograd.Variable(a_expanded) a[true] = a_expanded
def test_getitem_scalars(self): def test_getitem_scalars(self):
zero = torch.tensor(0, dtype=torch.int64) zero = torch.tensor(0, dtype=torch.int64)
one = torch.tensor(1, dtype=torch.int64) one = torch.tensor(1, dtype=torch.int64)
# non-scalar indexed with scalars # non-scalar indexed with scalars
a = Variable(torch.randn(2, 3)) a = torch.randn(2, 3)
self.assertEqual(a[0], a[zero]) self.assertEqual(a[0], a[zero])
self.assertEqual(a[0][1], a[zero][one]) self.assertEqual(a[0][1], a[zero][one])
self.assertEqual(a[0, 1], a[zero, one]) self.assertEqual(a[0, 1], a[zero, one])
@ -173,10 +172,10 @@ class TestIndexing(TestCase):
zero = torch.tensor(0, dtype=torch.int64) zero = torch.tensor(0, dtype=torch.int64)
# non-scalar indexed with scalars # non-scalar indexed with scalars
a = Variable(torch.randn(2, 3)) a = torch.randn(2, 3)
a_set_with_number = a.clone() a_set_with_number = a.clone()
a_set_with_scalar = a.clone() a_set_with_scalar = a.clone()
b = Variable(torch.randn(3)) b = torch.randn(3)
a_set_with_number[0] = b a_set_with_number[0] = b
a_set_with_scalar[zero] = b a_set_with_scalar[zero] = b
@ -195,9 +194,9 @@ class TestIndexing(TestCase):
def test_basic_advanced_combined(self): def test_basic_advanced_combined(self):
# From the NumPy indexing example # From the NumPy indexing example
x = Variable(torch.arange(0, 12).view(4, 3)) x = torch.arange(0, 12).view(4, 3)
self.assertEqual(x[1:2, 1:3], x[1:2, [1, 2]]) self.assertEqual(x[1:2, 1:3], x[1:2, [1, 2]])
self.assertEqual(x[1:2, 1:3].data.tolist(), [[4, 5]]) self.assertEqual(x[1:2, 1:3].tolist(), [[4, 5]])
# Check that it is a copy # Check that it is a copy
unmodified = x.clone() unmodified = x.clone()
@ -210,33 +209,33 @@ class TestIndexing(TestCase):
self.assertNotEqual(x, unmodified) self.assertNotEqual(x, unmodified)
def test_int_assignment(self): def test_int_assignment(self):
x = Variable(torch.arange(0, 4).view(2, 2)) x = torch.arange(0, 4).view(2, 2)
x[1] = 5 x[1] = 5
self.assertEqual(x.data.tolist(), [[0, 1], [5, 5]]) self.assertEqual(x.tolist(), [[0, 1], [5, 5]])
x = Variable(torch.arange(0, 4).view(2, 2)) x = torch.arange(0, 4).view(2, 2)
x[1] = Variable(torch.arange(5, 7)) x[1] = torch.arange(5, 7)
self.assertEqual(x.data.tolist(), [[0, 1], [5, 6]]) self.assertEqual(x.tolist(), [[0, 1], [5, 6]])
def test_byte_tensor_assignment(self): def test_byte_tensor_assignment(self):
x = Variable(torch.arange(0, 16).view(4, 4)) x = torch.arange(0, 16).view(4, 4)
b = Variable(torch.ByteTensor([True, False, True, False])) b = torch.ByteTensor([True, False, True, False])
value = Variable(torch.Tensor([3, 4, 5, 6])) value = torch.tensor([3., 4., 5., 6.])
x[b] = value x[b] = value
self.assertEqual(x[0], value) self.assertEqual(x[0], value)
self.assertEqual(x[1].data, torch.arange(4, 8)) self.assertEqual(x[1], torch.arange(4, 8))
self.assertEqual(x[2], value) self.assertEqual(x[2], value)
self.assertEqual(x[3].data, torch.arange(12, 16)) self.assertEqual(x[3], torch.arange(12, 16))
def test_variable_slicing(self): def test_variable_slicing(self):
x = Variable(torch.arange(0, 16).view(4, 4)) x = torch.arange(0, 16).view(4, 4)
indices = Variable(torch.IntTensor([0, 1])) indices = torch.IntTensor([0, 1])
i, j = indices i, j = indices
self.assertEqual(x[i:j], x[0:1]) self.assertEqual(x[i:j], x[0:1])
def test_ellipsis_tensor(self): def test_ellipsis_tensor(self):
x = Variable(torch.arange(0, 9).view(3, 3)) x = torch.arange(0, 9).view(3, 3)
idx = Variable(torch.LongTensor([0, 2])) idx = torch.tensor([0, 2])
self.assertEqual(x[..., idx].tolist(), [[0, 2], self.assertEqual(x[..., idx].tolist(), [[0, 2],
[3, 5], [3, 5],
[6, 8]]) [6, 8]])
@ -244,7 +243,7 @@ class TestIndexing(TestCase):
[6, 7, 8]]) [6, 7, 8]])
def test_invalid_index(self): def test_invalid_index(self):
x = Variable(torch.arange(0, 16).view(4, 4)) x = torch.arange(0, 16).view(4, 4)
self.assertRaisesRegex(TypeError, 'slice indices', lambda: x["0":"1"]) self.assertRaisesRegex(TypeError, 'slice indices', lambda: x["0":"1"])
def test_zero_dim_index(self): def test_zero_dim_index(self):
@ -256,22 +255,6 @@ class TestIndexing(TestCase):
self.assertEqual(len(w), 1) self.assertEqual(len(w), 1)
def tensor(*args, **kwargs):
return Variable(torch.Tensor(*args, **kwargs))
def byteTensor(data):
return Variable(torch.ByteTensor(data))
def ones(*args):
return Variable(torch.ones(*args))
def zeros(*args):
return Variable(torch.zeros(*args))
# The tests below are from NumPy test_indexing.py with some modifications to # The tests below are from NumPy test_indexing.py with some modifications to
# make them compatible with PyTorch. It's licensed under the BDS license below: # make them compatible with PyTorch. It's licensed under the BDS license below:
# #
@ -309,7 +292,7 @@ def zeros(*args):
class NumpyTests(TestCase): class NumpyTests(TestCase):
def test_index_no_floats(self): def test_index_no_floats(self):
a = Variable(torch.Tensor([[[5]]])) a = torch.tensor([[[5.]]])
self.assertRaises(IndexError, lambda: a[0.0]) self.assertRaises(IndexError, lambda: a[0.0])
self.assertRaises(IndexError, lambda: a[0, 0.0]) self.assertRaises(IndexError, lambda: a[0, 0.0])
@ -348,10 +331,10 @@ class NumpyTests(TestCase):
def test_empty_fancy_index(self): def test_empty_fancy_index(self):
# Empty list index creates an empty array # Empty list index creates an empty array
a = tensor([1, 2, 3]) a = tensor([1, 2, 3])
self.assertEqual(a[[]], Variable(torch.Tensor())) self.assertEqual(a[[]], torch.tensor([]))
b = tensor([]).long() b = tensor([]).long()
self.assertEqual(a[[]], Variable(torch.LongTensor())) self.assertEqual(a[[]], torch.tensor([], dtype=torch.long))
b = tensor([]).float() b = tensor([]).float()
self.assertRaises(RuntimeError, lambda: a[b]) self.assertRaises(RuntimeError, lambda: a[b])
@ -386,8 +369,8 @@ class NumpyTests(TestCase):
[4, 5, 6], [4, 5, 6],
[7, 8, 9]]) [7, 8, 9]])
self.assertEqual(a[0].data, [1, 2, 3]) self.assertEqual(a[0], [1, 2, 3])
self.assertEqual(a[-1].data, [7, 8, 9]) self.assertEqual(a[-1], [7, 8, 9])
# Index out of bounds produces IndexError # Index out of bounds produces IndexError
self.assertRaises(IndexError, a.__getitem__, 1 << 30) self.assertRaises(IndexError, a.__getitem__, 1 << 30)
@ -404,16 +387,16 @@ class NumpyTests(TestCase):
self.assertEqual(a[False], a[None][0:0]) self.assertEqual(a[False], a[None][0:0])
def test_boolean_shape_mismatch(self): def test_boolean_shape_mismatch(self):
arr = ones((5, 4, 3)) arr = torch.ones((5, 4, 3))
# TODO: prefer IndexError # TODO: prefer IndexError
index = byteTensor([True]) index = tensor([True])
self.assertRaisesRegex(RuntimeError, 'mask', lambda: arr[index]) self.assertRaisesRegex(RuntimeError, 'mask', lambda: arr[index])
index = byteTensor([False] * 6) index = tensor([False] * 6)
self.assertRaisesRegex(RuntimeError, 'mask', lambda: arr[index]) self.assertRaisesRegex(RuntimeError, 'mask', lambda: arr[index])
index = Variable(torch.ByteTensor(4, 4)).zero_() index = torch.ByteTensor(4, 4).zero_()
self.assertRaisesRegex(RuntimeError, 'mask', lambda: arr[index]) self.assertRaisesRegex(RuntimeError, 'mask', lambda: arr[index])
self.assertRaisesRegex(RuntimeError, 'mask', lambda: arr[(slice(None), index)]) self.assertRaisesRegex(RuntimeError, 'mask', lambda: arr[(slice(None), index)])
@ -422,7 +405,7 @@ class NumpyTests(TestCase):
# Indexing a 2-dimensional array with # Indexing a 2-dimensional array with
# boolean array of length one # boolean array of length one
a = tensor([[0., 0., 0.]]) a = tensor([[0., 0., 0.]])
b = byteTensor([True]) b = tensor([True])
self.assertEqual(a[b], a) self.assertEqual(a[b], a)
# boolean assignment # boolean assignment
a[b] = 1. a[b] = 1.
@ -431,7 +414,7 @@ class NumpyTests(TestCase):
def test_boolean_assignment_value_mismatch(self): def test_boolean_assignment_value_mismatch(self):
# A boolean assignment should fail when the shape of the values # A boolean assignment should fail when the shape of the values
# cannot be broadcast to the subscription. (see also gh-3458) # cannot be broadcast to the subscription. (see also gh-3458)
a = Variable(torch.arange(0, 4)) a = torch.arange(0, 4)
def f(a, v): def f(a, v):
a[a > -1] = tensor(v) a[a > -1] = tensor(v)
@ -446,9 +429,9 @@ class NumpyTests(TestCase):
a = tensor([[1, 2, 3], a = tensor([[1, 2, 3],
[4, 5, 6], [4, 5, 6],
[7, 8, 9]]) [7, 8, 9]])
b = byteTensor([[True, False, True], b = tensor([[True, False, True],
[False, True, False], [False, True, False],
[True, False, True]]) [True, False, True]])
self.assertEqual(a[b], tensor([1, 3, 5, 7, 9])) self.assertEqual(a[b], tensor([1, 3, 5, 7, 9]))
self.assertEqual(a[b[1]], tensor([[4, 5, 6]])) self.assertEqual(a[b[1]], tensor([[4, 5, 6]]))
self.assertEqual(a[b[0]], a[b[2]]) self.assertEqual(a[b[0]], a[b[2]])
@ -461,39 +444,39 @@ class NumpyTests(TestCase):
def test_everything_returns_views(self): def test_everything_returns_views(self):
# Before `...` would return a itself. # Before `...` would return a itself.
a = tensor(5) a = tensor([5])
self.assertIsNot(a, a[()]) self.assertIsNot(a, a[()])
self.assertIsNot(a, a[...]) self.assertIsNot(a, a[...])
self.assertIsNot(a, a[:]) self.assertIsNot(a, a[:])
def test_broaderrors_indexing(self): def test_broaderrors_indexing(self):
a = zeros(5, 5) a = torch.zeros(5, 5)
self.assertRaisesRegex(RuntimeError, 'match the size', a.__getitem__, ([0, 1], [0, 1, 2])) self.assertRaisesRegex(RuntimeError, 'match the size', a.__getitem__, ([0, 1], [0, 1, 2]))
self.assertRaisesRegex(RuntimeError, 'match the size', a.__setitem__, ([0, 1], [0, 1, 2]), 0) self.assertRaisesRegex(RuntimeError, 'match the size', a.__setitem__, ([0, 1], [0, 1, 2]), 0)
def test_trivial_fancy_out_of_bounds(self): def test_trivial_fancy_out_of_bounds(self):
a = zeros(5) a = torch.zeros(5)
ind = ones(20).long() ind = torch.ones(20, dtype=torch.int64)
ind[-1] = 10 ind[-1] = 10
self.assertRaises(RuntimeError, a.__getitem__, ind) self.assertRaises(RuntimeError, a.__getitem__, ind)
self.assertRaises(RuntimeError, a.__setitem__, ind, 0) self.assertRaises(RuntimeError, a.__setitem__, ind, 0)
ind = ones(20).long() ind = torch.ones(20, dtype=torch.int64)
ind[0] = 11 ind[0] = 11
self.assertRaises(RuntimeError, a.__getitem__, ind) self.assertRaises(RuntimeError, a.__getitem__, ind)
self.assertRaises(RuntimeError, a.__setitem__, ind, 0) self.assertRaises(RuntimeError, a.__setitem__, ind, 0)
def test_index_is_larger(self): def test_index_is_larger(self):
# Simple case of fancy index broadcasting of the index. # Simple case of fancy index broadcasting of the index.
a = zeros((5, 5)) a = torch.zeros((5, 5))
a[[[0], [1], [2]], [0, 1, 2]] = tensor([2, 3, 4]) a[[[0], [1], [2]], [0, 1, 2]] = tensor([2., 3., 4.])
self.assertTrue((a[:3, :3] == tensor([2, 3, 4])).all()) self.assertTrue((a[:3, :3] == tensor([2., 3., 4.])).all())
def test_broadcast_subspace(self): def test_broadcast_subspace(self):
a = zeros((100, 100)) a = torch.zeros((100, 100))
v = Variable(torch.arange(0, 100))[:, None] v = torch.arange(0, 100)[:, None]
b = Variable(torch.arange(99, -1, -1).long()) b = torch.arange(99, -1, -1).long()
a[b] = v a[b] = v
expected = b.double().unsqueeze(1).expand(100, 100) expected = b.double().unsqueeze(1).expand(100, 100)
self.assertEqual(a, expected) self.assertEqual(a, expected)