mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: This change adds callbacks for lazy backwards compilation while preventing duplicate callbacks from being fired. Differential Revision: D68577593 Pull Request resolved: https://github.com/pytorch/pytorch/pull/145732 Approved by: https://github.com/mlazos
68 lines
2.3 KiB
Python
68 lines
2.3 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
|
|
from unittest.mock import Mock
|
|
|
|
from torch._dynamo.callback import callback_handler
|
|
from torch._dynamo.test_case import run_tests, TestCase
|
|
|
|
|
|
class CallbackTests(TestCase):
    """Tests for the module-level dynamo ``callback_handler``.

    ``callback_handler`` is a single shared object imported at module level,
    and these tests read and overwrite its name-mangled private attributes
    directly. ``setUp`` snapshots that private state and ``tearDown`` restores
    it, so a test that fails midway (or deliberately corrupts the counter)
    cannot leak state into later tests.
    """

    def setUp(self) -> None:
        super().setUp()
        # Snapshot the private handler state that the tests below overwrite.
        self._saved_prevent_duplicate_callbacks = (
            callback_handler._CompilationCallbackHandler__prevent_duplicate_callbacks
        )
        self._saved_pending_callbacks_counter = (
            callback_handler._CompilationCallbackHandler__pending_callbacks_counter
        )
        # Fresh mocks per test so call counts reflect only the current test.
        self._on_compile_start = Mock()
        self._on_compile_end = Mock()
        callback_handler.register_start_callback(self._on_compile_start)
        callback_handler.register_end_callback(self._on_compile_end)

    def tearDown(self) -> None:
        try:
            # Put the private state back to the snapshot taken in setUp.
            callback_handler._CompilationCallbackHandler__prevent_duplicate_callbacks = (
                self._saved_prevent_duplicate_callbacks
            )
            callback_handler._CompilationCallbackHandler__pending_callbacks_counter = (
                self._saved_pending_callbacks_counter
            )
            # clear() must run before returning; it is expected to drop the
            # callbacks registered in setUp so mocks do not pile up across
            # tests.
            callback_handler.clear()
        finally:
            # Always run the base-class cleanup, even if the reset above fails.
            super().tearDown()

    def test_callbacks_without_duplicate_prevention(self) -> None:
        callback_handler._CompilationCallbackHandler__prevent_duplicate_callbacks = (
            False
        )

        # With prevention off, each nested install_callbacks() fires the
        # callbacks independently: two starts are seen inside the context and
        # two ends once both contexts have exited.
        with callback_handler.install_callbacks(), callback_handler.install_callbacks():
            self.assertEqual(self._on_compile_start.call_count, 2)
        self.assertEqual(self._on_compile_end.call_count, 2)

    def test_callbacks_with_duplicate_prevention(self) -> None:
        callback_handler._CompilationCallbackHandler__prevent_duplicate_callbacks = True

        # With prevention on, nesting collapses to a single start callback
        # inside the context and a single end callback after it exits.
        with callback_handler.install_callbacks(), callback_handler.install_callbacks():
            self._on_compile_start.assert_called_once()
        self._on_compile_end.assert_called_once()

    def test_counter(self) -> None:
        callback_handler._CompilationCallbackHandler__prevent_duplicate_callbacks = True

        # The pending counter is 1 while one context is active and returns to
        # 0 once it exits.
        with callback_handler.install_callbacks():
            self.assertEqual(
                callback_handler._CompilationCallbackHandler__pending_callbacks_counter,
                1,
            )
        self.assertEqual(
            callback_handler._CompilationCallbackHandler__pending_callbacks_counter, 0
        )

    def test_counter_assertion(self) -> None:
        callback_handler._CompilationCallbackHandler__prevent_duplicate_callbacks = True
        # Push the counter one below its starting value so that
        # install_callbacks() raises the negative-counter AssertionError
        # checked below. tearDown restores the exact snapshot from setUp, so no
        # manual (and possibly miscounted) restore is needed here.
        callback_handler._CompilationCallbackHandler__pending_callbacks_counter -= 1

        with self.assertRaises(
            AssertionError
        ) as e, callback_handler.install_callbacks():
            pass

        self.assertIn(
            "Pending callbacks counter cannot become negative.",
            str(e.exception),
        )
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly as a script: hand control to the dynamo test runner.
    run_tests()
|