compile_worker: Make a timer class (#166465)

This subclass allows us to trigger an action after we haven't seen any activity
for a certain amount of seconds.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166465
Approved by: https://github.com/masnesral
This commit is contained in:
clr 2025-10-31 13:24:00 -07:00 committed by PyTorch MergeBot
parent 51667435f5
commit d80ae738c9
2 changed files with 109 additions and 0 deletions

View File

@ -2,12 +2,14 @@
import operator
import os
import tempfile
from threading import Event
from torch._inductor.compile_worker.subproc_pool import (
raise_testexc,
SubprocException,
SubprocPool,
)
from torch._inductor.compile_worker.timer import Timer
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import skipIfWindows
from torch.testing._internal.inductor_utils import HAS_CPU
@ -81,6 +83,59 @@ class TestCompileWorker(TestCase):
pool.shutdown()
class TestTimer(TestCase):
def test_basics(self):
done = Event()
def doit():
done.set()
t = Timer(0.1, doit)
t.sleep_time = 0.1
t.record_call()
self.assertTrue(done.wait(4))
t.quit()
def test_repeated_calls(self):
done = Event()
def doit():
done.set()
t = Timer(0.1, doit)
t.sleep_time = 0.1
for i in range(10):
t.record_call()
self.assertTrue(done.wait(4))
done.clear()
t.quit()
def test_never_fires(self):
done = Event()
def doit():
done.set()
t = Timer(999, doit)
t.sleep_time = 0.1
t.record_call()
self.assertFalse(done.wait(4))
t.quit()
def test_spammy_calls(self):
done = Event()
def doit():
done.set()
t = Timer(1, doit)
t.sleep_time = 0.1
for i in range(400):
t.record_call()
self.assertTrue(done.wait(4))
t.quit()
if __name__ == "__main__":
from torch._inductor.test_case import run_tests

View File

@ -0,0 +1,54 @@
from threading import Lock, Thread
from time import monotonic, sleep
from typing import Callable, Optional, Union
class Timer:
"""
This measures how long we have gone since last receiving an event and if it is greater than a set interval, calls a function.
"""
def __init__(
self,
duration: Union[int, float], # Duration in seconds
call: Callable[[], None], # Function to call when we expire
) -> None:
# We don't start the background thread until we actually get an event.
self.background_thread: Optional[Thread] = None
self.last_called: Optional[float] = None
self.duration = duration
self.sleep_time = 60
self.call = call
self.exit = False
self.lock = Lock()
def record_call(self) -> None:
with self.lock:
if self.background_thread is None:
self.background_thread = Thread(
target=self.check, daemon=True, name="subproc_worker_timer"
)
self.background_thread.start()
self.last_called = monotonic()
def quit(self) -> None:
with self.lock:
self.exit = True
def check(self) -> None:
while True:
# We have to be sensitive on checking here, to avoid too much impact on cpu
sleep(self.sleep_time)
with self.lock:
if self.exit:
return
assert self.last_called is not None
if self.last_called + self.duration >= monotonic():
continue
self.last_called = None
self.background_thread = None
# Releasing lock in case self.call() takes a very long time or is reentrant
self.call()
return