diff --git a/test/test_foreach.py b/test/test_foreach.py index 6b7d1fcdc3e..17f09fb61be 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -156,7 +156,7 @@ class TestForeach(TestCase): torch._foreach_add_(tensors1, tensors2) # different devices - if torch.cuda.is_available(): + if torch.cuda.is_available() and torch.cuda.device_count() > 1: tensor1 = torch.zeros(10, 10, device="cuda:0") tensor2 = torch.ones(10, 10, device="cuda:1") with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):