pytorch/torch/_dynamo/callback.py
Burak Turk 01a4d86b31 add pt2 callbacks for backward pass and prevent duplicate callbacks (#145732)
Summary: This change adds callbacks for lazy backwards compilation while preventing duplicate callbacks from being fired.

Differential Revision: D68577593

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145732
Approved by: https://github.com/mlazos
2025-01-28 03:50:02 +00:00

137 lines
4.4 KiB
Python

import threading
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass, field # noqa: F811
from typing import Any, Callable, Optional
from torch._utils_internal import justknobs_check
@dataclass
class CompilationCallbackHandler:
    """
    Registry of callbacks fired around compilation.

    Start callbacks run when ``install_callbacks`` is entered and end callbacks
    run when it is exited. When the ``pytorch/dynamo:prevent_duplicate_callbacks``
    knob is enabled, nested or concurrent ``install_callbacks`` contexts on the
    same handler share a single start/end pair: start callbacks fire only when
    the first context is entered, and end callbacks fire only when the last
    open context exits.
    """

    # Callbacks invoked, in registration order, by run_start_callbacks().
    start_callbacks: list[Callable[[], None]] = field(default_factory=list)
    # Callbacks invoked, in registration order, by run_end_callbacks().
    end_callbacks: list[Callable[[], None]] = field(default_factory=list)
    # Cached value of the justknobs flag; None means it has not been read yet.
    __prevent_duplicate_callbacks: Optional[bool] = field(
        default=None, init=False, repr=False
    )
    # Number of install_callbacks() contexts currently open (dedup mode only).
    __pending_callbacks_counter: int = field(default=0, init=False, repr=False)
    # Guards __pending_callbacks_counter; callbacks are run while it is held.
    __pending_callbacks_counter_lock: threading.Lock = field(
        default_factory=threading.Lock, init=False, repr=False
    )

    def register_start_callback(
        self, callback: Callable[[], None]
    ) -> Callable[[], None]:
        """
        Register a callback function to be called when the compilation starts.

        Args:
        - callback (Callable): The callback function to register.

        Returns:
        - The same callback, unchanged, so this can be used as a decorator.
        """
        self.start_callbacks.append(callback)
        return callback

    def register_end_callback(self, callback: Callable[[], None]) -> Callable[[], None]:
        """
        Register a callback function to be called when the compilation ends.

        Args:
        - callback (Callable): The callback function to register.

        Returns:
        - The same callback, unchanged, so this can be used as a decorator.
        """
        self.end_callbacks.append(callback)
        return callback

    def remove_start_callback(self, callback: Callable[[], None]) -> None:
        """
        Remove a registered start callback function.

        Args:
        - callback (Callable): The callback function to remove.

        Raises:
        - ValueError: If the callback is not registered.
        """
        self.start_callbacks.remove(callback)

    def remove_end_callback(self, callback: Callable[[], None]) -> None:
        """
        Remove a registered end callback function.

        Args:
        - callback (Callable): The callback function to remove.

        Raises:
        - ValueError: If the callback is not registered.
        """
        self.end_callbacks.remove(callback)

    def run_start_callbacks(self) -> None:
        """
        Execute all registered start callbacks in registration order.

        Iterates over a snapshot of the list. A callback may therefore
        register or remove callbacks without causing others to be skipped.
        """
        for callback in tuple(self.start_callbacks):
            callback()

    def run_end_callbacks(self) -> None:
        """
        Execute all registered end callbacks in registration order.

        Iterates over a snapshot of the list. A callback may therefore
        register or remove callbacks without causing others to be skipped.
        """
        for callback in tuple(self.end_callbacks):
            callback()

    @property
    def prevent_duplicate_callbacks(self) -> bool:
        """
        Whether nested/concurrent contexts should share one start/end pair.

        Reads the ``pytorch/dynamo:prevent_duplicate_callbacks`` knob on first
        access and caches the result on the instance.
        """
        if self.__prevent_duplicate_callbacks is None:
            self.__prevent_duplicate_callbacks = justknobs_check(
                "pytorch/dynamo:prevent_duplicate_callbacks"
            )
        return self.__prevent_duplicate_callbacks

    @contextmanager
    def install_callbacks(self) -> Generator[None, Any, Any]:
        """
        Context manager to run start callbacks on entry and end callbacks on exit.

        End callbacks run even if a start callback or the body raised.
        When duplicate prevention is enabled, only the first context to enter
        runs start callbacks, and only the last context to exit runs end
        callbacks.
        """
        if self.prevent_duplicate_callbacks:
            try:
                with self.__pending_callbacks_counter_lock:
                    # Count this context *before* running start callbacks.
                    # If one of them raises, the ``finally`` below still sees
                    # a consistent counter. Otherwise its assertion would fail
                    # and mask the original exception.
                    self.__pending_callbacks_counter += 1
                    if self.__pending_callbacks_counter == 1:
                        self.run_start_callbacks()
                yield
            finally:
                with self.__pending_callbacks_counter_lock:
                    assert (
                        self.__pending_callbacks_counter > 0
                    ), "Pending callbacks counter cannot become negative."
                    # Decrement *before* running end callbacks. A raising end
                    # callback then cannot leave the counter stuck above zero,
                    # which would suppress every later start/end pair.
                    self.__pending_callbacks_counter -= 1
                    if self.__pending_callbacks_counter == 0:
                        self.run_end_callbacks()
        else:
            try:
                self.run_start_callbacks()
                yield
            finally:
                self.run_end_callbacks()

    def clear(self) -> None:
        """
        Clear all registered callbacks.

        Does not reset the count of currently open ``install_callbacks``
        contexts.
        """
        self.start_callbacks.clear()
        self.end_callbacks.clear()
# Module-level handler used by the on_compile_start / on_compile_end decorators.
callback_handler: CompilationCallbackHandler = CompilationCallbackHandler()
def on_compile_start(callback: Callable[[], None]) -> Callable[[], None]:
    """
    Decorator that registers ``callback`` on the module-level handler so it
    runs when compilation starts.

    The decorated function is returned unchanged and can still be called
    directly.
    """
    # register_start_callback hands back the very object it was given.
    return callback_handler.register_start_callback(callback)
def on_compile_end(callback: Callable[[], None]) -> Callable[[], None]:
    """
    Decorator that registers ``callback`` on the module-level handler so it
    runs when compilation ends.

    The decorated function is returned unchanged and can still be called
    directly.
    """
    # register_end_callback hands back the very object it was given.
    return callback_handler.register_end_callback(callback)