pytorch/torch/_dynamo/callback.py
Yanbo Liang 169c220bf8 [torch.compile] Provide capability to register callback on compile start/stop (#120764)
This is a requirement from Meta-internal use cases, where people want to register a callback function to detect whether a job is stuck during compilation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120764
Approved by: https://github.com/jansel
2024-02-29 07:37:52 +00:00

83 lines
2.1 KiB
Python

class CompilationCallbackHandler:
    """
    Registry of zero-argument callbacks run at compilation start and end.

    Callbacks are stored in registration order and executed in that order.
    Registering the same callable twice makes it run twice per pass.
    """

    def __init__(self):
        # Executed by run_start_callbacks / run_end_callbacks, in list order.
        self.start_callbacks = []
        self.end_callbacks = []

    def register_start_callback(self, callback):
        """
        Register a callback function to be called when the compilation starts.

        Args:
            callback (callable): Zero-argument function to register.

        Returns:
            The same ``callback``, so this method can be used as a decorator.
        """
        self.start_callbacks.append(callback)
        return callback

    def register_end_callback(self, callback):
        """
        Register a callback function to be called when the compilation ends.

        Args:
            callback (callable): Zero-argument function to register.

        Returns:
            The same ``callback``, so this method can be used as a decorator.
        """
        self.end_callbacks.append(callback)
        return callback

    def remove_start_callback(self, callback):
        """
        Remove a registered start callback function.

        Only the first matching registration is removed.

        Args:
            callback (callable): The callback function to remove.

        Raises:
            ValueError: If ``callback`` is not currently registered.
        """
        self.start_callbacks.remove(callback)

    def remove_end_callback(self, callback):
        """
        Remove a registered end callback function.

        Only the first matching registration is removed.

        Args:
            callback (callable): The callback function to remove.

        Raises:
            ValueError: If ``callback`` is not currently registered.
        """
        self.end_callbacks.remove(callback)

    def run_start_callbacks(self):
        """
        Execute all registered start callbacks, in registration order.

        Iterates over a snapshot of the registry. A callback that removes
        itself therefore does not cause a sibling to be skipped. A callback
        registered during this pass waits until the next pass. An exception
        raised by a callback propagates and stops the remaining callbacks.
        """
        # Snapshot: mutating a list while a for-loop walks it skips or adds items.
        for callback in tuple(self.start_callbacks):
            callback()

    def run_end_callbacks(self):
        """
        Execute all registered end callbacks, in registration order.

        Iterates over a snapshot of the registry. A callback that removes
        itself therefore does not cause a sibling to be skipped. A callback
        registered during this pass waits until the next pass. An exception
        raised by a callback propagates and stops the remaining callbacks.
        """
        # Snapshot: mutating a list while a for-loop walks it skips or adds items.
        for callback in tuple(self.end_callbacks):
            callback()

    def clear(self):
        """
        Clear all registered callbacks.

        The existing lists are emptied in place rather than replaced.
        """
        self.start_callbacks.clear()
        self.end_callbacks.clear()


# Shared handler used by the on_compile_start / on_compile_end decorators.
callback_handler = CompilationCallbackHandler()
def on_compile_start(callback):
    """
    Decorator: run ``callback`` whenever a compilation starts.

    Registers ``callback`` on the module-level ``callback_handler`` and hands
    back the very same function, so the decorated name is left unchanged.
    """
    return callback_handler.register_start_callback(callback)
def on_compile_end(callback):
    """
    Decorator: run ``callback`` whenever a compilation ends.

    Registers ``callback`` on the module-level ``callback_handler`` and hands
    back the very same function, so the decorated name is left unchanged.
    """
    return callback_handler.register_end_callback(callback)