Summary:
In SJD, we register callbacks to get notified of an active compilation. Using this information, we can extend the time allowed for the training loop while compilation is in progress.
The callbacks currently do not account for the entire compilation time, and in several cases the end callback is not called at all.
This causes a number of APS jobs to be terminated incorrectly: https://fburl.com/scuba/mast_hpc_job_run_status/ondwzt2w
In this diff, we install a context manager that calls the start and end callbacks around compilation, similar to how we log counters and other information. A sketch of the idea is shown below.
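
For reference, here is a minimal sketch of the approach, not the actual PyTorch implementation: a context manager fires the registered start callbacks on entry and guarantees the end callbacks fire on exit, even if compilation raises. The names (`CallbackHandler`, `install_callbacks`, `run_start_callbacks`, `run_end_callbacks`, `compile_inner`) are illustrative assumptions, not the real API.
```
from contextlib import contextmanager


class CallbackHandler:
    """Hypothetical handler that holds compilation start/end callbacks."""

    def __init__(self):
        self._start_callbacks = []
        self._end_callbacks = []

    def register_start_callback(self, fn):
        self._start_callbacks.append(fn)

    def register_end_callback(self, fn):
        self._end_callbacks.append(fn)

    def run_start_callbacks(self):
        for fn in self._start_callbacks:
            fn()

    def run_end_callbacks(self):
        for fn in self._end_callbacks:
            fn()

    @contextmanager
    def install_callbacks(self):
        # Fire start callbacks on entry; the finally block guarantees the
        # end callbacks run even when compilation exits via an exception.
        self.run_start_callbacks()
        try:
            yield
        finally:
            self.run_end_callbacks()


callback_handler = CallbackHandler()

# The compilation entry point would then wrap the compile step, e.g.:
#
#     with callback_handler.install_callbacks():
#         compiled_fn = compile_inner(code)  # hypothetical compile call
```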
Test Plan:
```
buck2 run mode/opt //aps_models/examples/dlrm:dlrm_train_app -- --config-name train_mast_fsdp_torchdynamo launcher.data_project=apf_ai_infra launcher.fbl_entitlement=ai_infra_training_rnd_tc launcher.hardware=TC_ANY_80G
```
Led to https://www.internalfb.com/mlhub/pipelines/runs/mast/aps-atuljangra-ef2285ba9a?job_attempt=0&version=0&env=prod

https://fburl.com/ai_infra/sv0a213y confirms that the callback was correctly called and that a lease was properly installed, which takes over the training loop lease.
{F1965137027}
Differential Revision: D66347023
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141323
Approved by: https://github.com/ezyang