# Owner(s): ["module: dynamo"] import unittest from unittest.mock import Mock import torch from torch._dynamo.callback import callback_handler, CallbackArgs, CallbackTrigger from torch._dynamo.test_case import run_tests, TestCase from torch._guards import CompileId from torch.testing._internal.common_utils import TEST_WITH_ROCM from torch.testing._internal.triton_utils import HAS_CUDA_AND_TRITON, requires_gpu device_type = ( acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu" ) class CallbackTests(TestCase): def setUp(self) -> None: super().setUp() self._on_compile_start = Mock() self._on_compile_end = Mock() callback_handler.register_start_callback(self._on_compile_start) callback_handler.register_end_callback(self._on_compile_end) def tearDown(self) -> None: callback_handler.clear() return super().tearDown() def test_callbacks_with_duplicate_prevention(self) -> None: trigger = CallbackTrigger.DYNAMO compile_id = CompileId(frame_id=0, frame_compile_id=0) with ( callback_handler.install_callbacks(trigger, compile_id), callback_handler.install_callbacks(trigger, compile_id), ): self._on_compile_start.assert_called_once() self._on_compile_end.assert_called_once() def test_counter(self) -> None: trigger = CallbackTrigger.DYNAMO compile_id = CompileId(frame_id=0, frame_compile_id=0) with callback_handler.install_callbacks(trigger, compile_id): self.assertEqual( callback_handler._CompilationCallbackHandler__pending_callbacks_counter, 1, ) self.assertEqual( callback_handler._CompilationCallbackHandler__pending_callbacks_counter, 0 ) def test_counter_assertion(self) -> None: callback_handler._CompilationCallbackHandler__pending_callbacks_counter -= 1 with self.assertRaisesRegex( AssertionError, "Pending callbacks counter cannot become negative." 
        ):
            trigger = CallbackTrigger.DYNAMO
            compile_id = CompileId(frame_id=0, frame_compile_id=0)
            with callback_handler.install_callbacks(trigger, str(compile_id)):
                pass
        self.assertEqual(
            callback_handler._CompilationCallbackHandler__pending_callbacks_counter, 0
        )

    @unittest.skipIf(
        TEST_WITH_ROCM, "ROCm outputs a different number of autotuning logs"
    )
    @requires_gpu
    @torch._inductor.config.patch(force_disable_caches=True)
    def test_triggers(self) -> None:
        torch._dynamo.reset()
        order = []

        def on_start(args: CallbackArgs):
            nonlocal order
            order.append(f"start={args}")

        def on_end(args: CallbackArgs):
            nonlocal order
            order.append(f"end={args}")

        torch._dynamo.callback.on_compile_start(on_start)
        torch._dynamo.callback.on_compile_end(on_end)

        class TinyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = torch.nn.Linear(10, 10)
                self.relu = torch.nn.ReLU()
                self.fc2 = torch.nn.Linear(10, 10)

            def forward(self, x):
                temp = self.fc1(x)
                temp = self.relu(temp)
                torch._dynamo.graph_break()
                return self.fc2(temp)

        model = TinyModel().to(device_type)
        compiled_model = torch.compile(model, mode="max-autotune")
        x = torch.randn(10, 10, device=device_type)
        loss = compiled_model(x).sum()
        loss.backward()
        self.assertExpectedInline(
            "\n".join(order),
            """\
start=CallbackArgs(callback_trigger=, compile_id='0/0')
end=CallbackArgs(callback_trigger=, compile_id='0/0')
start=CallbackArgs(callback_trigger=, compile_id='1/0')
end=CallbackArgs(callback_trigger=, compile_id='1/0')
start=CallbackArgs(callback_trigger=, compile_id='1/0')
end=CallbackArgs(callback_trigger=, compile_id='1/0')
start=CallbackArgs(callback_trigger=, compile_id='0/0')
end=CallbackArgs(callback_trigger=, compile_id='0/0')""",  # noqa: B950
        )
        order.clear()

        if not HAS_CUDA_AND_TRITON:
            return

        compiled_model.zero_grad()
        loss = compiled_model(x).sum()
        loss.backward()
        self.assertExpectedInline(
            "\n".join(order),
            """\
start=CallbackArgs(callback_trigger=, compile_id='0/0')
end=CallbackArgs(callback_trigger=, compile_id='0/0')
start=CallbackArgs(callback_trigger=, compile_id='1/0')
end=CallbackArgs(callback_trigger=, compile_id='1/0')
start=CallbackArgs(callback_trigger=, compile_id='1/0')
end=CallbackArgs(callback_trigger=, compile_id='1/0')
start=CallbackArgs(callback_trigger=, compile_id='0/0')
end=CallbackArgs(callback_trigger=, compile_id='0/0')""",  # noqa: B950
        )
        order.clear()

        compiled_model.zero_grad()
        loss = compiled_model(x).sum()
        loss.backward()
        self.assertEqual(len(order), 0)


if __name__ == "__main__":
    run_tests()