pytorch/torch/_dynamo/callback.py
Atul Jangra 6a096a0b96 [PT2] Fix callbacks to account for entire execution in compilation (#141323)
Summary:
In SJD, we register callbacks to be notified of an active compilation. Using this information, we can allow the training loop more time.

The callbacks currently do not cover the entire compilation time, and in several cases the end callback is not called at all.

This leads to a bunch of APS jobs getting terminated incorrectly: https://fburl.com/scuba/mast_hpc_job_run_status/ondwzt2w

In this diff, we install a context manager that calls the start and end callbacks, similar to how we log counters and other information.
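
A minimal sketch of why a context manager addresses the missing end callback: when the start call, the wrapped work, and the end call sit inside `try`/`finally`, the end callback runs even if the wrapped work raises. The names `notify_around` and `fake_compile` are hypothetical and are not part of this diff.

```python
from contextlib import contextmanager
from typing import Callable, Generator, List

events: List[str] = []


@contextmanager
def notify_around(
    start: Callable[[], None], end: Callable[[], None]
) -> Generator[None, None, None]:
    """Hypothetical helper: call `start` on entry and `end` on exit."""
    try:
        start()
        yield
    finally:
        # Runs on normal exit and when the body raises.
        end()


def fake_compile(should_fail: bool) -> None:
    """Hypothetical stand-in for a compilation that may fail."""
    with notify_around(
        lambda: events.append("start"), lambda: events.append("end")
    ):
        if should_fail:
            raise RuntimeError("simulated compilation failure")


fake_compile(False)
try:
    fake_compile(True)
except RuntimeError:
    pass

print(events)  # ['start', 'end', 'start', 'end']
```

Both the successful and the failing call produce a matching start/end pair, which is the property the training loop lease depends on.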

Test Plan:
```
buck2 run mode/opt //aps_models/examples/dlrm:dlrm_train_app -- --config-name train_mast_fsdp_torchdynamo launcher.data_project=apf_ai_infra launcher.fbl_entitlement=ai_infra_training_rnd_tc  launcher.hardware=TC_ANY_80G
```
Led to https://www.internalfb.com/mlhub/pipelines/runs/mast/aps-atuljangra-ef2285ba9a?job_attempt=0&version=0&env=prod

https://fburl.com/ai_infra/sv0a213y confirms that the callback was correctly called and a lease was properly installed, which takes over the training loop lease.


Differential Revision: D66347023

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141323
Approved by: https://github.com/ezyang
2024-11-24 22:31:04 +00:00

101 lines
2.9 KiB
Python

from contextlib import contextmanager
from dataclasses import dataclass, field  # noqa: F811
from typing import Any, Callable, Generator, List


@dataclass
class CompilationCallbackHandler:
    start_callbacks: List[Callable[[], None]] = field(default_factory=list)
    end_callbacks: List[Callable[[], None]] = field(default_factory=list)

    def register_start_callback(
        self, callback: Callable[[], None]
    ) -> Callable[[], None]:
        """
        Register a callback function to be called when the compilation starts.

        Args:
        - callback (Callable): The callback function to register.
        """
        self.start_callbacks.append(callback)
        return callback

    def register_end_callback(self, callback: Callable[[], None]) -> Callable[[], None]:
        """
        Register a callback function to be called when the compilation ends.

        Args:
        - callback (Callable): The callback function to register.
        """
        self.end_callbacks.append(callback)
        return callback

    def remove_start_callback(self, callback: Callable[[], None]) -> None:
        """
        Remove a registered start callback function.

        Args:
        - callback (Callable): The callback function to remove.
        """
        self.start_callbacks.remove(callback)

    def remove_end_callback(self, callback: Callable[[], None]) -> None:
        """
        Remove a registered end callback function.

        Args:
        - callback (Callable): The callback function to remove.
        """
        self.end_callbacks.remove(callback)

    def run_start_callbacks(self) -> None:
        """
        Execute all registered start callbacks.
        """
        for callback in self.start_callbacks:
            callback()

    def run_end_callbacks(self) -> None:
        """
        Execute all registered end callbacks.
        """
        for callback in self.end_callbacks:
            callback()

    @contextmanager
    def install_callbacks(self) -> Generator[None, Any, Any]:
        """
        Context manager to install the callbacks and run them when the context is exited.
        """
        try:
            self.run_start_callbacks()
            yield
        finally:
            self.run_end_callbacks()

    def clear(self) -> None:
        """
        Clear all registered callbacks.
        """
        self.start_callbacks.clear()
        self.end_callbacks.clear()


callback_handler = CompilationCallbackHandler()


def on_compile_start(callback: Callable[[], None]) -> Callable[[], None]:
    """
    Decorator to register a callback function for the start of the compilation.
    """
    callback_handler.register_start_callback(callback)
    return callback


def on_compile_end(callback: Callable[[], None]) -> Callable[[], None]:
    """
    Decorator to register a callback function for the end of the compilation.
    """
    callback_handler.register_end_callback(callback)
    return callback
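
The following usage sketch shows how the decorators and `install_callbacks` might be combined by a caller. So that it runs without PyTorch installed, it uses a trimmed stand-in (`Handler`) that mirrors only the relevant parts of `CompilationCallbackHandler`; with PyTorch available, the real names would presumably be imported from `torch._dynamo.callback` instead. The lease counter `active_compilations` and the callbacks `extend_lease` and `release_lease` are hypothetical.

```python
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Callable, Generator, List


@dataclass
class Handler:
    """Trimmed stand-in for CompilationCallbackHandler (assumption)."""

    start_callbacks: List[Callable[[], None]] = field(default_factory=list)
    end_callbacks: List[Callable[[], None]] = field(default_factory=list)

    @contextmanager
    def install_callbacks(self) -> Generator[None, None, None]:
        try:
            for callback in self.start_callbacks:
                callback()
            yield
        finally:
            for callback in self.end_callbacks:
                callback()


handler = Handler()


def on_compile_start(callback: Callable[[], None]) -> Callable[[], None]:
    handler.start_callbacks.append(callback)
    return callback


def on_compile_end(callback: Callable[[], None]) -> Callable[[], None]:
    handler.end_callbacks.append(callback)
    return callback


# Hypothetical state a job scheduler might track.
active_compilations = 0


@on_compile_start
def extend_lease() -> None:
    global active_compilations
    active_compilations += 1


@on_compile_end
def release_lease() -> None:
    global active_compilations
    active_compilations -= 1


observed: List[int] = []
with handler.install_callbacks():
    # Stand-in for the compilation work.
    observed.append(active_compilations)
observed.append(active_compilations)
print(observed)  # [1, 0]

# The decorators return the original function, so it can be removed later.
handler.start_callbacks.remove(extend_lease)
handler.end_callbacks.remove(release_lease)
```

Because removal relies on `list.remove`, removing a callback that was never registered raises `ValueError`, so callers may want to guard that call.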