Log backward no-op to tlparse and pt2 compile events. (#154544)

Summary: Log backward no-op to tlparse and pt2 compile events.

Test Plan:
$ rm -rf /tmp/r && TORCH_TRACE=/tmp/r buck2 run //scripts/jovian:backward_noop_repro_compile

Used print statements to verify we enter the logging code region.

Differential Revision: D75231665

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154544
Approved by: https://github.com/c00w
This commit is contained in:
Jovian Anthony Jaison 2025-06-06 18:08:19 +00:00 committed by PyTorch MergeBot
parent 2e2ea7290a
commit 1ccc57e428
2 changed files with 31 additions and 0 deletions

View File

@ -2590,6 +2590,29 @@ class AOTInductorTestsTemplate:
example_inputs = (torch.randn(8, 4, 4, device=self.device),)
self.check_model(Model(), example_inputs)
@patch("torch._dynamo.utils.CompileEventLogger.log_instant_event")
def test_backward_no_op_logging(self, mock_log_instant_event):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return x
model = Model()
dummy_input = torch.randn(1, 5)
from torch._dynamo.utils import CompileEventLogLevel
from torch._inductor import compile_fx
graph_module = torch.fx.symbolic_trace(model)
compile_fx._compile_fx_inner(graph_module, (dummy_input,))
mock_log_instant_event.assert_called_once_with(
"backward no-op",
metadata={"compile_id": None},
log_level=CompileEventLogLevel.PT2_COMPILE,
)
@unittest.skipIf(IS_FBCODE, "Not runnable in fbcode")
def test_dup_unbacked_sym_decl(self):
class Model(torch.nn.Module):

View File

@ -742,9 +742,17 @@ def _compile_fx_inner(
if dynamo_utils.count_calls(gm.graph) == 0 and not aot_mode:
# trigger the real recompilation for _LazyGraphModule before returning
# the forward method.
from torch._dynamo.utils import CompileEventLogLevel
from torch.fx._lazy_graph_module import _LazyGraphModule
_LazyGraphModule.force_recompile(gm)
compile_id = torch._guards.CompileContext.current_compile_id()
CompileEventLogger.log_instant_event(
"backward no-op",
metadata={"compile_id": compile_id},
log_level=CompileEventLogLevel.PT2_COMPILE,
)
return make_boxed_func(gm.forward)
static_input_idxs: Sequence[int] = graph_kwargs.setdefault("static_input_idxs", ())