mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add missing device to namedtensor tests (#166717)
This PR passes unused `device` argument to tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166717 Approved by: https://github.com/Skylion007
This commit is contained in:
parent
ef8d97efcf
commit
856a7a5298
|
|
@ -1002,7 +1002,7 @@ class TestNamedTensor(TestCase):
|
|||
def test_ops(op):
|
||||
for device in get_all_device_types():
|
||||
names = ('N', 'D')
|
||||
tensor = torch.rand(2, 3, names=names)
|
||||
tensor = torch.rand(2, 3, names=names, device=device)
|
||||
result = op(tensor, 0)
|
||||
self.assertEqual(result[0].names, names)
|
||||
self.assertEqual(result[1].names, names)
|
||||
|
|
@ -1012,15 +1012,15 @@ class TestNamedTensor(TestCase):
|
|||
def test_logcumsumexp(self):
|
||||
for device in get_all_device_types():
|
||||
names = ('N', 'D')
|
||||
tensor = torch.rand(2, 3, names=names)
|
||||
tensor = torch.rand(2, 3, names=names, device=device)
|
||||
result = torch.logcumsumexp(tensor, 'D')
|
||||
self.assertEqual(result.names, names)
|
||||
|
||||
def test_bitwise_not(self):
|
||||
for device in get_all_device_types():
|
||||
names = ('N', 'D')
|
||||
tensor = torch.zeros(2, 3, names=names, dtype=torch.bool)
|
||||
result = torch.empty(0, dtype=torch.bool)
|
||||
tensor = torch.zeros(2, 3, names=names, dtype=torch.bool, device=device)
|
||||
result = torch.empty(0, dtype=torch.bool, device=device)
|
||||
|
||||
self.assertEqual(tensor.bitwise_not().names, names)
|
||||
self.assertEqual(torch.bitwise_not(tensor, out=result).names, names)
|
||||
|
|
@ -1029,8 +1029,8 @@ class TestNamedTensor(TestCase):
|
|||
def test_logical_not(self):
|
||||
for device in get_all_device_types():
|
||||
names = ('N', 'D')
|
||||
tensor = torch.zeros(2, 3, names=names, dtype=torch.bool)
|
||||
result = torch.empty(0, dtype=torch.bool)
|
||||
tensor = torch.zeros(2, 3, names=names, dtype=torch.bool, device=device)
|
||||
result = torch.empty(0, dtype=torch.bool, device=device)
|
||||
|
||||
self.assertEqual(tensor.logical_not().names, names)
|
||||
self.assertEqual(torch.logical_not(tensor, out=result).names, names)
|
||||
|
|
@ -1039,8 +1039,8 @@ class TestNamedTensor(TestCase):
|
|||
def test_bernoulli(self):
|
||||
for device in get_all_device_types():
|
||||
names = ('N', 'D')
|
||||
tensor = torch.rand(2, 3, names=names)
|
||||
result = torch.empty(0)
|
||||
tensor = torch.rand(2, 3, names=names, device=device)
|
||||
result = torch.empty(0, device=device)
|
||||
self.assertEqual(tensor.bernoulli().names, names)
|
||||
|
||||
torch.bernoulli(tensor, out=result)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user