# Owner(s): ["module: intel"] import sys import unittest import torch from torch.testing._internal.common_utils import NoTest, run_tests, TEST_XPU, TestCase if not TEST_XPU: print("XPU not available, skipping tests", file=sys.stderr) TestCase = NoTest # noqa: F811 TEST_MULTIXPU = torch.xpu.device_count() > 1 class TestXpu(TestCase): def test_device_behavior(self): current_device = torch.xpu.current_device() torch.xpu.set_device(current_device) self.assertEqual(current_device, torch.xpu.current_device()) @unittest.skipIf(not TEST_MULTIXPU, "only one GPU detected") def test_multi_device_behavior(self): current_device = torch.xpu.current_device() target_device = (current_device + 1) % torch.xpu.device_count() with torch.xpu.device(target_device): self.assertEqual(target_device, torch.xpu.current_device()) self.assertEqual(current_device, torch.xpu.current_device()) with torch.xpu._DeviceGuard(target_device): self.assertEqual(target_device, torch.xpu.current_device()) self.assertEqual(current_device, torch.xpu.current_device()) def test_get_device_properties(self): current_device = torch.xpu.current_device() device_properties = torch.xpu.get_device_properties(current_device) self.assertEqual(device_properties, torch.xpu.get_device_properties(None)) self.assertEqual(device_properties, torch.xpu.get_device_properties()) device_name = torch.xpu.get_device_name(current_device) self.assertEqual(device_name, torch.xpu.get_device_name(None)) self.assertEqual(device_name, torch.xpu.get_device_name()) device_capability = torch.xpu.get_device_capability(current_device) self.assertTrue(device_capability["max_work_group_size"] > 0) self.assertTrue(device_capability["max_num_sub_groups"] > 0) def test_wrong_xpu_fork(self): stderr = TestCase.runWithPytorchAPIUsageStderr( """\ import torch from torch.multiprocessing import Process def run(rank): torch.xpu.set_device(rank) if __name__ == "__main__": size = 2 processes = [] for rank in range(size): # it would work fine without the line below torch.xpu.set_device(0) p = Process(target=run, args=(rank,)) p.start() processes.append(p) for p in processes: p.join() """ ) self.assertRegex(stderr, "Cannot re-initialize XPU in forked subprocess.") def test_streams(self): s0 = torch.xpu.Stream() torch.xpu.set_stream(s0) s1 = torch.xpu.current_stream() self.assertEqual(s0, s1) s2 = torch.xpu.Stream() self.assertFalse(s0 == s2) torch.xpu.set_stream(s2) with torch.xpu.stream(s0): self.assertEqual(s0, torch.xpu.current_stream()) self.assertEqual(s2, torch.xpu.current_stream()) def test_stream_priority(self): low, high = torch.xpu.Stream.priority_range() s0 = torch.xpu.Stream(device=0, priority=low) self.assertEqual(low, s0.priority) self.assertEqual(torch.device("xpu:0"), s0.device) s1 = torch.xpu.Stream(device=0, priority=high) self.assertEqual(high, s1.priority) self.assertEqual(torch.device("xpu:0"), s1.device) def test_stream_event_repr(self): s = torch.xpu.current_stream() self.assertTrue("torch.xpu.Stream" in str(s)) e = torch.xpu.Event() self.assertTrue("torch.xpu.Event(uninitialized)" in str(e)) s.record_event(e) self.assertTrue("torch.xpu.Event" in str(e)) def test_events(self): stream = torch.xpu.current_stream() event = torch.xpu.Event() self.assertTrue(event.query()) stream.record_event(event) event.synchronize() self.assertTrue(event.query()) def test_generator(self): torch.manual_seed(2024) g_state0 = torch.xpu.get_rng_state() torch.manual_seed(1234) g_state1 = torch.xpu.get_rng_state() self.assertNotEqual(g_state0, g_state1) torch.xpu.manual_seed(2024) g_state2 = torch.xpu.get_rng_state() self.assertEqual(g_state0, g_state2) torch.xpu.set_rng_state(g_state1) self.assertEqual(g_state1, torch.xpu.get_rng_state()) torch.manual_seed(1234) torch.xpu.set_rng_state(g_state0) self.assertEqual(2024, torch.xpu.initial_seed()) if __name__ == "__main__": run_tests()