mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: In SJD, we register callbacks to get notified of an active compilation. Using this information, we can allow increased time for the training loop. The callbacks currently do not account for the entire time, and in several cases the end callback is not called at all. This leads to a number of APS jobs getting terminated incorrectly: https://fburl.com/scuba/mast_hpc_job_run_status/ondwzt2w In this diff, we install a context manager which calls the start and end callbacks, similar to how we log counters and other information. Test Plan: ``` buck2 run mode/opt //aps_models/examples/dlrm:dlrm_train_app -- --config-name train_mast_fsdp_torchdynamo launcher.data_project=apf_ai_infra launcher.fbl_entitlement=ai_infra_training_rnd_tc launcher.hardware=TC_ANY_80G ``` Led to https://www.internalfb.com/mlhub/pipelines/runs/mast/aps-atuljangra-ef2285ba9a?job_attempt=0&version=0&env=prod https://fburl.com/ai_infra/sv0a213y confirms that the callback was correctly called and a lease was properly installed, which takes over the training loop lease. {F1965137027} Differential Revision: D66347023 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141323 Approved by: https://github.com/ezyang
101 lines
2.9 KiB
Python
101 lines
2.9 KiB
Python
from contextlib import contextmanager
|
|
from dataclasses import dataclass, field # noqa: F811
|
|
from typing import Any, Callable, Generator, List
|
|
|
|
|
|
@dataclass
class CompilationCallbackHandler:
    """Registry of zero-argument callbacks run at compilation start and end.

    Callbacks are stored in two lists and invoked in registration order.
    ``install_callbacks`` wraps a region of code so that the start callbacks
    run on entry and the end callbacks run on exit, even if the region raises.
    """

    # Callbacks invoked by run_start_callbacks(), in registration order.
    start_callbacks: List[Callable[[], None]] = field(default_factory=list)
    # Callbacks invoked by run_end_callbacks(), in registration order.
    end_callbacks: List[Callable[[], None]] = field(default_factory=list)

    def register_start_callback(
        self, callback: Callable[[], None]
    ) -> Callable[[], None]:
        """
        Register a callback function to be called when the compilation starts.

        Args:
        - callback (Callable): The callback function to register.

        Returns:
        - The same callback, unchanged, so this method can be used as a decorator.
        """
        self.start_callbacks.append(callback)
        return callback

    def register_end_callback(
        self, callback: Callable[[], None]
    ) -> Callable[[], None]:
        """
        Register a callback function to be called when the compilation ends.

        Args:
        - callback (Callable): The callback function to register.

        Returns:
        - The same callback, unchanged, so this method can be used as a decorator.
        """
        self.end_callbacks.append(callback)
        return callback

    def remove_start_callback(self, callback: Callable[[], None]) -> None:
        """
        Remove a registered start callback function.

        Only the first matching registration is removed.

        Args:
        - callback (Callable): The callback function to remove.

        Raises:
        - ValueError: If the callback is not currently registered.
        """
        self.start_callbacks.remove(callback)

    def remove_end_callback(self, callback: Callable[[], None]) -> None:
        """
        Remove a registered end callback function.

        Only the first matching registration is removed.

        Args:
        - callback (Callable): The callback function to remove.

        Raises:
        - ValueError: If the callback is not currently registered.
        """
        self.end_callbacks.remove(callback)

    def run_start_callbacks(self) -> None:
        """
        Execute all registered start callbacks.

        The callbacks registered at the time of the call are run in
        registration order. A callback may register or remove callbacks
        (including itself) while running; such changes take effect on the
        next invocation and never cause another callback to be skipped.
        """
        # Iterate over a snapshot: mutating the live list during iteration
        # (e.g. a one-shot callback removing itself) would skip entries.
        for callback in tuple(self.start_callbacks):
            callback()

    def run_end_callbacks(self) -> None:
        """
        Execute all registered end callbacks.

        The callbacks registered at the time of the call are run in
        registration order. A callback may register or remove callbacks
        (including itself) while running; such changes take effect on the
        next invocation and never cause another callback to be skipped.
        """
        # Snapshot for the same reason as in run_start_callbacks.
        for callback in tuple(self.end_callbacks):
            callback()

    @contextmanager
    def install_callbacks(self) -> Generator[None, Any, Any]:
        """
        Context manager that runs the start callbacks on entry and the end
        callbacks on exit.

        The end callbacks run in a ``finally`` clause, so they are executed
        when the managed block raises and also when a start callback raises.
        """
        try:
            self.run_start_callbacks()
            yield
        finally:
            self.run_end_callbacks()

    def clear(self) -> None:
        """
        Clear all registered callbacks.

        Both lists are emptied in place, so existing references to them
        observe the change.
        """
        self.start_callbacks.clear()
        self.end_callbacks.clear()
|
|
|
|
|
|
# Module-level handler instance; on_compile_start / on_compile_end register
# their callbacks on this object.
callback_handler = CompilationCallbackHandler()
|
|
|
|
|
|
def on_compile_start(callback: Callable[[], None]) -> Callable[[], None]:
    """
    Decorator that registers ``callback`` on the module-level handler so it
    runs at the start of the compilation.

    The decorated function is returned unchanged.
    """
    register = callback_handler.register_start_callback
    register(callback)
    return callback
|
|
|
|
|
|
def on_compile_end(callback: Callable[[], None]) -> Callable[[], None]:
    """
    Decorator that registers ``callback`` on the module-level handler so it
    runs at the end of the compilation.

    The decorated function is returned unchanged.
    """
    register = callback_handler.register_end_callback
    register(callback)
    return callback
|