Add missing device to namedtensor tests (#166717)

This PR passes unused `device` argument to tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166717
Approved by: https://github.com/Skylion007
This commit is contained in:
Yuanyuan Chen 2025-10-31 20:04:38 +00:00 committed by PyTorch MergeBot
parent ef8d97efcf
commit 856a7a5298

View File

@ -1002,7 +1002,7 @@ class TestNamedTensor(TestCase):
def test_ops(op):
for device in get_all_device_types():
names = ('N', 'D')
tensor = torch.rand(2, 3, names=names)
tensor = torch.rand(2, 3, names=names, device=device)
result = op(tensor, 0)
self.assertEqual(result[0].names, names)
self.assertEqual(result[1].names, names)
@ -1012,15 +1012,15 @@ class TestNamedTensor(TestCase):
def test_logcumsumexp(self):
for device in get_all_device_types():
names = ('N', 'D')
tensor = torch.rand(2, 3, names=names)
tensor = torch.rand(2, 3, names=names, device=device)
result = torch.logcumsumexp(tensor, 'D')
self.assertEqual(result.names, names)
def test_bitwise_not(self):
for device in get_all_device_types():
names = ('N', 'D')
tensor = torch.zeros(2, 3, names=names, dtype=torch.bool)
result = torch.empty(0, dtype=torch.bool)
tensor = torch.zeros(2, 3, names=names, dtype=torch.bool, device=device)
result = torch.empty(0, dtype=torch.bool, device=device)
self.assertEqual(tensor.bitwise_not().names, names)
self.assertEqual(torch.bitwise_not(tensor, out=result).names, names)
@ -1029,8 +1029,8 @@ class TestNamedTensor(TestCase):
def test_logical_not(self):
for device in get_all_device_types():
names = ('N', 'D')
tensor = torch.zeros(2, 3, names=names, dtype=torch.bool)
result = torch.empty(0, dtype=torch.bool)
tensor = torch.zeros(2, 3, names=names, dtype=torch.bool, device=device)
result = torch.empty(0, dtype=torch.bool, device=device)
self.assertEqual(tensor.logical_not().names, names)
self.assertEqual(torch.logical_not(tensor, out=result).names, names)
@ -1039,8 +1039,8 @@ class TestNamedTensor(TestCase):
def test_bernoulli(self):
for device in get_all_device_types():
names = ('N', 'D')
tensor = torch.rand(2, 3, names=names)
result = torch.empty(0)
tensor = torch.rand(2, 3, names=names, device=device)
result = torch.empty(0, device=device)
self.assertEqual(tensor.bernoulli().names, names)
torch.bernoulli(tensor, out=result)