diff --git a/test/inductor/test_compile_worker.py b/test/inductor/test_compile_worker.py
index 8fde26c6acf..79f35765533 100644
--- a/test/inductor/test_compile_worker.py
+++ b/test/inductor/test_compile_worker.py
@@ -2,12 +2,14 @@
 import operator
 import os
 import tempfile
+from threading import Event
 
 from torch._inductor.compile_worker.subproc_pool import (
     raise_testexc,
     SubprocException,
     SubprocPool,
 )
+from torch._inductor.compile_worker.timer import Timer
 from torch._inductor.test_case import TestCase
 from torch.testing._internal.common_utils import skipIfWindows
 from torch.testing._internal.inductor_utils import HAS_CPU
@@ -81,6 +83,59 @@ class TestCompileWorker(TestCase):
                 pool.shutdown()
 
 
+class TestTimer(TestCase):
+    def test_basics(self):
+        done = Event()
+
+        def doit():
+            done.set()
+
+        t = Timer(0.1, doit)
+        t.sleep_time = 0.1
+        t.record_call()
+        self.assertTrue(done.wait(4))
+        t.quit()
+
+    def test_repeated_calls(self):
+        done = Event()
+
+        def doit():
+            done.set()
+
+        t = Timer(0.1, doit)
+        t.sleep_time = 0.1
+        for i in range(10):
+            t.record_call()
+            self.assertTrue(done.wait(4))
+            done.clear()
+        t.quit()
+
+    def test_never_fires(self):
+        done = Event()
+
+        def doit():
+            done.set()
+
+        t = Timer(999, doit)
+        t.sleep_time = 0.1
+        t.record_call()
+        self.assertFalse(done.wait(4))
+        t.quit()
+
+    def test_spammy_calls(self):
+        done = Event()
+
+        def doit():
+            done.set()
+
+        t = Timer(1, doit)
+        t.sleep_time = 0.1
+        for i in range(400):
+            t.record_call()
+        self.assertTrue(done.wait(4))
+        t.quit()
+
+
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
 
diff --git a/torch/_inductor/compile_worker/timer.py b/torch/_inductor/compile_worker/timer.py
new file mode 100644
index 00000000000..d4b0c0dc9e2
--- /dev/null
+++ b/torch/_inductor/compile_worker/timer.py
@@ -0,0 +1,68 @@
+from threading import Event, Lock, Thread
+from time import monotonic
+from typing import Callable, Optional, Union
+
+
+class Timer:
+    """
+    Idle timer: calls a function once no event has been recorded for a while.
+
+    Each record_call() marks the current time as the latest event. A daemon
+    background thread, started lazily on the first event, wakes up every
+    `sleep_time` seconds and invokes `call` once more than `duration` seconds
+    have elapsed since the latest event. The thread exits after firing; the
+    next record_call() starts a fresh one, so the timer can fire repeatedly.
+    """
+
+    def __init__(
+        self,
+        duration: Union[int, float],  # Duration in seconds
+        call: Callable[[], None],  # Function to call when we expire
+    ) -> None:
+        # We don't start the background thread until we actually get an event.
+        self.background_thread: Optional[Thread] = None
+        self.last_called: Optional[float] = None
+        self.duration = duration
+        # Seconds between checks; coarse by default to limit cpu impact.
+        self.sleep_time: Union[int, float] = 60
+        self.call = call
+        self.exit = False
+        # Set by quit() so that a waiting background thread wakes up at once.
+        self._exit_event = Event()
+
+        self.lock = Lock()
+
+    def record_call(self) -> None:
+        """Record an event now, starting the background thread if needed."""
+        with self.lock:
+            if self.background_thread is None:
+                self.background_thread = Thread(
+                    target=self.check, daemon=True, name="subproc_worker_timer"
+                )
+                self.background_thread.start()
+            self.last_called = monotonic()
+
+    def quit(self) -> None:
+        """Stop the timer; a callback that is already in flight may still run."""
+        with self.lock:
+            self.exit = True
+        self._exit_event.set()
+
+    def check(self) -> None:
+        """Thread body: poll until idle for too long, then fire once."""
+        while True:
+            # We have to be sensitive on checking here, to avoid too much impact on cpu
+            # Waiting on the event instead of sleeping lets quit() end the wait early.
+            self._exit_event.wait(self.sleep_time)
+            with self.lock:
+                if self.exit:
+                    return
+                assert self.last_called is not None
+                if self.last_called + self.duration >= monotonic():
+                    continue
+                self.last_called = None
+                self.background_thread = None
+
+            # Releasing lock in case self.call() takes a very long time or is reentrant
+            self.call()
+            return