Log backward no-op to tlparse and pt2 compile events. (#154544)
Summary: Log backward no-op to tlparse and pt2 compile events.

Test Plan:
$ rm -rf /tmp/r && TORCH_TRACE=/tmp/r buck2 run //scripts/jovian:backward_noop_repro_compile
Used print statements to verify we enter the logging code region.

Differential Revision: D75231665

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154544
Approved by: https://github.com/c00w
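For readers without access to the internal repro script, a minimal open-source sketch of the same scenario (Identity is an illustrative stand-in, and _compile_fx_inner is a private API that may change):

import torch
from torch._inductor import compile_fx

class Identity(torch.nn.Module):
    def forward(self, x):
        return x

# An identity graph contains no call nodes, so _compile_fx_inner takes the
# no-op path touched by this commit and logs the "backward no-op" event.
gm = torch.fx.symbolic_trace(Identity())
compiled_fn = compile_fx._compile_fx_inner(gm, (torch.randn(1, 5),))

Running this under TORCH_TRACE=<some directory> should make the event visible in the resulting tlparse output.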
This commit is contained in:
parent 2e2ea7290a
commit 1ccc57e428
@@ -2590,6 +2590,29 @@ class AOTInductorTestsTemplate:
         example_inputs = (torch.randn(8, 4, 4, device=self.device),)
         self.check_model(Model(), example_inputs)
 
+    @patch("torch._dynamo.utils.CompileEventLogger.log_instant_event")
+    def test_backward_no_op_logging(self, mock_log_instant_event):
+        class Model(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+
+            def forward(self, x):
+                return x
+
+        model = Model()
+        dummy_input = torch.randn(1, 5)
+
+        from torch._dynamo.utils import CompileEventLogLevel
+        from torch._inductor import compile_fx
+
+        graph_module = torch.fx.symbolic_trace(model)
+        compile_fx._compile_fx_inner(graph_module, (dummy_input,))
+        mock_log_instant_event.assert_called_once_with(
+            "backward no-op",
+            metadata={"compile_id": None},
+            log_level=CompileEventLogLevel.PT2_COMPILE,
+        )
+
     @unittest.skipIf(IS_FBCODE, "Not runnable in fbcode")
     def test_dup_unbacked_sym_decl(self):
         class Model(torch.nn.Module):
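A note on the expected metadata above: the test invokes _compile_fx_inner directly rather than through torch.compile, so no CompileContext is active and current_compile_id() yields None. A sketch using the same torch._guards call the production change makes:

import torch._guards

# With no compilation in flight there is no active CompileContext, which
# is why the test expects the logged event to carry compile_id=None.
assert torch._guards.CompileContext.current_compile_id() is None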
@@ -742,9 +742,17 @@ def _compile_fx_inner(
     if dynamo_utils.count_calls(gm.graph) == 0 and not aot_mode:
         # trigger the real recompilation for _LazyGraphModule before returning
         # the forward method.
+        from torch._dynamo.utils import CompileEventLogLevel
         from torch.fx._lazy_graph_module import _LazyGraphModule
 
         _LazyGraphModule.force_recompile(gm)
+        compile_id = torch._guards.CompileContext.current_compile_id()
+        CompileEventLogger.log_instant_event(
+            "backward no-op",
+            metadata={"compile_id": compile_id},
+            log_level=CompileEventLogLevel.PT2_COMPILE,
+        )
+
         return make_boxed_func(gm.forward)
 
     static_input_idxs: Sequence[int] = graph_kwargs.setdefault("static_input_idxs", ())
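For context on the guard in this hunk, a short sketch of why an empty graph takes this branch (Identity is illustrative; count_calls is the existing torch._dynamo.utils helper that tallies call nodes in an FX graph):

import torch
from torch._dynamo.utils import count_calls

class Identity(torch.nn.Module):
    def forward(self, x):
        return x

gm = torch.fx.symbolic_trace(Identity())
# Only placeholder and output nodes remain after tracing, so count_calls
# returns 0 and the branch above hands back the boxed forward untouched.
print(count_calls(gm.graph))  # 0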