# pytorch/test/dynamo/test_callback.py
# Owner(s): ["module: dynamo"]
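# Tests for torch._dynamo.callback: registering compile start/end callbacks,
# duplicate-callback suppression, the pending-callbacks counter, and the
# triggers (Dynamo, lazy backward, CUDA graph recording) that fire them.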
import unittest
from unittest.mock import Mock

import torch
from torch._dynamo.callback import callback_handler, CallbackArgs, CallbackTrigger
from torch._dynamo.test_case import run_tests, TestCase
from torch._guards import CompileId
from torch.testing._internal.common_utils import TEST_WITH_ROCM
from torch.testing._internal.inductor_utils import HAS_CUDA
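

# Minimal usage sketch of the API exercised below (an illustration, not part
# of this test suite): register module-level callbacks, then compile as usual.
#
#     def log_compile_event(args: CallbackArgs) -> None:
#         print(f"compile event: {args}")
#
#     torch._dynamo.callback.on_compile_start(log_compile_event)
#     torch._dynamo.callback.on_compile_end(log_compile_event)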
class CallbackTests(TestCase):
    def setUp(self) -> None:
        super().setUp()
        self._on_compile_start = Mock()
        self._on_compile_end = Mock()
        callback_handler.register_start_callback(self._on_compile_start)
        callback_handler.register_end_callback(self._on_compile_end)

    def tearDown(self) -> None:
        callback_handler.clear()
        return super().tearDown()
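
    # Installing callbacks twice for the same (trigger, compile_id) should run
    # the start callback only once; the end callback also fires once, when the
    # outermost context exits.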
    def test_callbacks_with_duplicate_prevention(self) -> None:
        trigger = CallbackTrigger.DYNAMO
        compile_id = CompileId(0, 0)
        with (
            callback_handler.install_callbacks(trigger, compile_id),
            callback_handler.install_callbacks(trigger, compile_id),
        ):
            self._on_compile_start.assert_called_once()
        self._on_compile_end.assert_called_once()
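
    # The handler tracks nesting in a private counter (accessed here via its
    # name-mangled attribute): it should read 1 inside a single
    # install_callbacks context and drop back to 0 once the context exits.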
    def test_counter(self) -> None:
        trigger = CallbackTrigger.DYNAMO
        compile_id = CompileId(0, 0)
        with callback_handler.install_callbacks(trigger, compile_id):
            self.assertEqual(
                callback_handler._CompilationCallbackHandler__pending_callbacks_counter,
                1,
            )
        self.assertEqual(
            callback_handler._CompilationCallbackHandler__pending_callbacks_counter, 0
        )
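
    # Pre-decrementing the private counter makes the handler's exit-path
    # sanity check fail; the regex below matches its assertion message, and
    # the counter should be left at 0 afterwards.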
    def test_counter_assertion(self) -> None:
        callback_handler._CompilationCallbackHandler__pending_callbacks_counter -= 1
        with self.assertRaisesRegex(
            AssertionError, "Pending callbacks counter cannot become negative."
        ):
            trigger = CallbackTrigger.DYNAMO
            compile_id = CompileId(0, 0)
            with callback_handler.install_callbacks(trigger, str(compile_id)):
                pass
        self.assertEqual(
            callback_handler._CompilationCallbackHandler__pending_callbacks_counter, 0
        )
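
    # End-to-end check of which trigger fires each callback. The graph break
    # in TinyModel.forward splits the model into two graphs, so events appear
    # once per compile id (0/0 and 1/0).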
    @unittest.skipIf(
        TEST_WITH_ROCM, "ROCm outputs a different number of autotuning logs"
    )
    @unittest.skipIf(not HAS_CUDA, "requires triton")
    @torch._inductor.config.patch(force_disable_caches=True)
    def test_triggers(self) -> None:
        torch._dynamo.reset()
        order = []

        def on_start(args: CallbackArgs):
            nonlocal order
            order.append(f"start={args}")

        def on_end(args: CallbackArgs):
            nonlocal order
            order.append(f"end={args}")

        torch._dynamo.callback.on_compile_start(on_start)
        torch._dynamo.callback.on_compile_end(on_end)

        class TinyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = torch.nn.Linear(10, 10)
                self.relu = torch.nn.ReLU()
                self.fc2 = torch.nn.Linear(10, 10)

            def forward(self, x):
                temp = self.fc1(x)
                temp = self.relu(temp)
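                # Force a graph break so Dynamo compiles fc1/relu and fc2 as
                # two separate graphs (compile ids 0/0 and 1/0 below).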
                torch._dynamo.graph_break()
                return self.fc2(temp)

        model = TinyModel().to("cuda")
        compiled_model = torch.compile(model, mode="max-autotune")
        x = torch.randn(10, 10, device="cuda")
        loss = compiled_model(x).sum()
        loss.backward()
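
        # First iteration: Dynamo compiles both forward graphs, and the two
        # backward graphs are compiled lazily during .backward().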
        self.assertExpectedInline(
            "\n".join(order),
            """\
start=CallbackArgs(callback_trigger=<CallbackTrigger.DYNAMO: 1>, compile_id='0/0')
end=CallbackArgs(callback_trigger=<CallbackTrigger.DYNAMO: 1>, compile_id='0/0')
start=CallbackArgs(callback_trigger=<CallbackTrigger.DYNAMO: 1>, compile_id='1/0')
end=CallbackArgs(callback_trigger=<CallbackTrigger.DYNAMO: 1>, compile_id='1/0')
start=CallbackArgs(callback_trigger=<CallbackTrigger.LAZY_BACKWARD: 2>, compile_id='1/0')
end=CallbackArgs(callback_trigger=<CallbackTrigger.LAZY_BACKWARD: 2>, compile_id='1/0')
start=CallbackArgs(callback_trigger=<CallbackTrigger.LAZY_BACKWARD: 2>, compile_id='0/0')
end=CallbackArgs(callback_trigger=<CallbackTrigger.LAZY_BACKWARD: 2>, compile_id='0/0')""",  # noqa: B950
        )
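
        # Second iteration: everything is compiled already; with
        # mode="max-autotune" the remaining events are CUDA graph recordings.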
        order.clear()
        compiled_model.zero_grad()
        loss = compiled_model(x).sum()
        loss.backward()
        self.assertExpectedInline(
            "\n".join(order),
            """\
start=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='0/0')
end=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='0/0')
start=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='1/0')
end=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='1/0')
start=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='1/0')
end=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='1/0')
start=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='0/0')
end=CallbackArgs(callback_trigger=<CallbackTrigger.CUDAGRAPH_RECORDING: 4>, compile_id='0/0')""",  # noqa: B950
        )
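
        # Third iteration: the recorded CUDA graphs are simply replayed, so no
        # compilation-related callbacks should fire at all.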
        order.clear()
        compiled_model.zero_grad()
        loss = compiled_model(x).sum()
        loss.backward()
        self.assertEqual(len(order), 0)


if __name__ == "__main__":
    run_tests()