Log graph breaks (#146537)

Graph breaks currently aren't logged to dynamo_compile and pt2_compile_events. We want to log them.
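
For context, a minimal sketch of what is being counted (hedged: this mirrors the break_it test added below; counters is torch._dynamo.utils.counters, which the last hunk reads):

    import torch
    from torch._dynamo.utils import counters

    @torch.compile
    def break_it(x):
        y = x.sum()
        # Branching on a data-dependent value and calling .item() forces a graph break.
        if y > 0:
            return x + y.item()
        return x - y.item()

    break_it(torch.ones(3, 3))
    # Previously this count lived only in-process; with this change it is also
    # reported as num_graph_breaks in dynamo_compile and pt2_compile_events.
    print(sum(counters["graph_break"].values()))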

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146537
Approved by: https://github.com/c00w
Raymond Li 2025-02-27 11:06:33 +00:00 committed by PyTorch MergeBot
parent 0489a349e7
commit c5bf9aaf1c
7 changed files with 113 additions and 20 deletions

View File

@@ -71,7 +71,15 @@ class TestMetricsContext(TestCase):
             with self.assertRaisesRegex(RuntimeError, "already been set"):
                 context.update({"m1": 7, "m3": 3})

         self.assertEqual(self.metrics, {"m1": 1, "m2": 2})

+    def test_update_allow_overwrite(self):
+        """
+        Validate that update will overwrite when given the overwrite param.
+        """
+        with MetricsContext(self._on_exit) as context:
+            context.update({"m1": 1, "m2": 2})
+            context.update({"m1": 7, "m3": 3}, overwrite=True)
+        self.assertEqual(self.metrics, {"m1": 7, "m2": 2, "m3": 3})
+
     def test_add_to_set(self):
         """

View File

@@ -75,6 +75,71 @@ class TestUtils(TestCase):
             )
         )

+    @dynamo_config.patch(
+        {
+            "log_compilation_metrics": True,
+            "inline_inbuilt_nn_modules": False,
+        }
+    )
+    def test_graph_break_counting(self):
+        """
+        Run a compilation that includes a graph break and validate that the
+        graph break counter is incremented.
+        """
+
+        def run_forward_backward():
+            model = torch.compile(TestModel())
+            x = torch.rand([3], requires_grad=True)
+            output = model(x)
+            loss_fn = torch.nn.MSELoss()
+            target = torch.tensor([1.0])
+            loss = loss_fn(output, target)
+            loss.backward()
+
+        @torch.compile
+        def add(x, y):
+            return x + y
+
+        @torch.compile
+        def break_it(x):
+            y = x.sum()
+            if y > 0:
+                return x + y.item()
+            return x - y.item()
+
+        @torch.compile
+        def break_it2(x):
+            y = x.sum()
+            if y > 0:
+                if y > 1:
+                    return x * y.item()
+                return x + y.item()
+            return x - y.item()
+
+        add(torch.rand([10]), torch.rand([10]))
+        utils.reset_frame_count()
+        compilation_events = []
+
+        with mock.patch("torch._dynamo.utils.log_compilation_event") as log_event:
+            run_forward_backward()
+            compilation_events = [arg[0][0] for arg in log_event.call_args_list]
+            self.assertEqual(compilation_events[-1].num_graph_breaks, 0)
+
+            # We should fall back to normal mode and increment the graph break counter
+            torch.compile(break_it, backend="inductor")(torch.ones(3, 3))
+            compilation_events = [arg[0][0] for arg in log_event.call_args_list]
+            self.assertEqual(compilation_events[-1].num_graph_breaks, 1)
+
+            # Graph break counter should be incremented by 1 (after a reset), not 2
+            torch.compile(break_it, backend="inductor")(torch.ones(3, 3))
+            compilation_events = [arg[0][0] for arg in log_event.call_args_list]
+            self.assertEqual(compilation_events[-1].num_graph_breaks, 1)
+
+            # Graph break counter should be incremented by 2
+            torch.compile(break_it2, backend="inductor")(torch.ones(3, 3))
+            compilation_events = [arg[0][0] for arg in log_event.call_args_list]
+            self.assertEqual(compilation_events[-1].num_graph_breaks, 2)
+

 class TestModel(torch.nn.Module):
     def __init__(self):
@@ -266,6 +331,7 @@ class TestDynamoTimed(TestCase):
         'joint_graph_pass_time_us': 0,
         'log_format_version': 3,
         'non_compliant_ops': set(),
+        'num_graph_breaks': 0,
         'num_triton_bundles': None,
         'post_grad_pass_time_us': 0,
         'pre_grad_pass_time_us': 0,
@@ -348,6 +414,7 @@ class TestDynamoTimed(TestCase):
         'joint_graph_pass_time_us': None,
         'log_format_version': 3,
         'non_compliant_ops': None,
+        'num_graph_breaks': 0,
         'num_triton_bundles': None,
         'post_grad_pass_time_us': 0,
         'pre_grad_pass_time_us': None,

View File

@@ -1223,13 +1223,11 @@ class ConvertFrame:
             # when we do not support graph breaks on bytecodes like LOAD_ATTR,
             # BUILD_SET etc. In such case, we can fallback to eager without
             # scaring users.
-            if isinstance(e, Unsupported) and graph_break_log.isEnabledFor(
-                logging.DEBUG
-            ):
+            if soft_fail and graph_break_log.isEnabledFor(logging.DEBUG):
                 # Log this message in the graph break. Also use the string
                 # "skip: " to tell that the whole frame is falling back to
                 # eager.
-                if hasattr(e, "compile_id"):
+                if hasattr(e, "compile_id") and hasattr(e, "real_stack"):
                     with compile_context(CompileContext(e.compile_id)):  # type: ignore[attr-defined]
                         user_stack = e.real_stack
                         user_stack_formatted = "".join(

View File

@@ -109,15 +109,16 @@ class MetricsContext:
             self._metrics[metric] = {}
         self._metrics[metric][key] = value

-    def update(self, values: dict[str, Any]) -> None:
+    def update(self, values: dict[str, Any], overwrite: bool = False) -> None:
         """
         Set multiple metrics directly. This method does NOT increment. Raises if any
-        metric has been assigned previously in the current context.
+        metric has been assigned previously in the current context and overwrite is
+        not set to True.
         """
         if self._level == 0:
             raise RuntimeError("Cannot update metrics outside of a MetricsContext")
         existing = self._metrics.keys() & values.keys()
-        if existing:
+        if existing and not overwrite:
             raise RuntimeError(
                 f"Metric(s) {existing} have already been set in the current context"
             )
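
For reference, the new overwrite semantics in isolation (a hedged sketch; the on-exit callback signature is assumed from how MetricsContext reports metrics and may differ in detail):

    from torch._dynamo.metrics_context import MetricsContext

    collected = {}

    def on_exit(start_ns, end_ns, metrics, exc_type, exc_value):  # assumed signature
        collected.update(metrics)

    with MetricsContext(on_exit) as ctx:
        ctx.update({"m1": 1})
        # ctx.update({"m1": 7})                  # raises: "m1" already set
        ctx.update({"m1": 7}, overwrite=True)    # permitted as of this change
    assert collected["m1"] == 7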

View File

@@ -371,7 +371,12 @@ class CompileEventLogger:
         )

     @staticmethod
-    def add_data(event_name: str, log_level: CompileEventLogLevel, **metadata: object):
+    def add_data(
+        event_name: str,
+        log_level: CompileEventLogLevel,
+        overwrite: bool = False,
+        **metadata: object,
+    ):
         """
         Centralized API for adding data to various events
         Log an event to a toplevel "dynamo" event or metrics context
@@ -408,11 +413,13 @@ class CompileEventLogger:
             )

         # TODO: should we assert that the keys of metadata are in CompilationMetrics?
-        metrics_context.update(metadata)
+        metrics_context.update(metadata, overwrite)
         chromium_log.add_event_data(event_name, **metadata)

     @staticmethod
-    def add_toplevel(log_level: CompileEventLogLevel, **metadata: object):
+    def add_toplevel(
+        log_level: CompileEventLogLevel, overwrite: bool = False, **metadata: object
+    ):
         """
         Syntactic sugar for logging to the toplevel event
         """
@@ -421,7 +428,7 @@ class CompileEventLogger:
             raise RuntimeError(
                 "No toplevel event active. Please only call this function within a dynamo_timed context."
             )
-        CompileEventLogger.add_data(top_event, log_level, **metadata)
+        CompileEventLogger.add_data(top_event, log_level, overwrite, **metadata)

     @staticmethod
     def increment(
@@ -457,7 +464,7 @@ class CompileEventLogger:
     @staticmethod
     def increment_toplevel(
         key: str,
-        value: int,
+        value: int = 1,
         log_level: CompileEventLogLevel = CompileEventLogLevel.COMPILATION_METRIC,
     ):
         """
@@ -529,7 +536,7 @@ class CompileEventLogger:
         <event_name> should be the name of a timed event span passed to `dynamo_timed`.
         """
         CompileEventLogger.add_data(
-            event_name, CompileEventLogLevel.CHROMIUM, **metadata
+            event_name, CompileEventLogLevel.CHROMIUM, overwrite=False, **metadata
         )

     @staticmethod
@@ -542,11 +549,11 @@ class CompileEventLogger:
         with log_to_pt2_compile_events=True.
         """
         CompileEventLogger.add_data(
-            event_name, CompileEventLogLevel.PT2_COMPILE, **metadata
+            event_name, CompileEventLogLevel.PT2_COMPILE, overwrite=False, **metadata
         )

     @staticmethod
-    def compilation_metric(**metadata: object):
+    def compilation_metric(overwrite: bool = False, **metadata: object):
         """
         Add <metadata> to the CompilationMetrics context. Also logs to PT2 Compile Events
         and chromium.
@@ -554,7 +561,7 @@ class CompileEventLogger:
         a column in PT2 Compile Events and Dynamo Compile, with the corresponding kwarg value.
         """
         CompileEventLogger.add_toplevel(
-            CompileEventLogLevel.COMPILATION_METRIC, **metadata
+            CompileEventLogLevel.COMPILATION_METRIC, overwrite, **metadata
         )

     @staticmethod
@@ -1216,6 +1223,7 @@ class CompilationMetrics:
     tensorify_float_failure: Optional[set[str]] = None
     guard_latency_us: Optional[float] = None
     recompile_reason: Optional[str] = None
+    num_graph_breaks: Optional[int] = None

     @classmethod
     def create(cls, metrics: dict[str, Any]):
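
Taken together, a sketch of how callers hit the new paths (assumes an active toplevel dynamo_timed event, without which these helpers raise):

    from torch._dynamo.utils import CompileEventLogger

    # value now defaults to 1, so bare increments read more naturally:
    CompileEventLogger.increment_toplevel("num_triton_bundles")

    # overwrite=True lets a later phase replace an earlier value instead of raising:
    CompileEventLogger.compilation_metric(overwrite=True, num_graph_breaks=2)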

View File

@@ -1052,7 +1052,7 @@ class FxGraphCache:
                 "inductor_compile", cached_kernel_names=meta.cached_kernel_names
             )
             if len(meta.cached_kernel_names) > 0:
-                CompileEventLogger.increment_toplevel("num_triton_bundles", 1)
+                CompileEventLogger.increment_toplevel("num_triton_bundles")

         try:
             artifact_path = graph.after_deserialization(constants)
@@ -1303,7 +1303,7 @@ class FxGraphCache:
             if remote_cache:
                 # Count remote cache hit stats
                 CompileEventLogger.increment_toplevel(
-                    "inductor_fx_remote_cache_hit_count", 1
+                    "inductor_fx_remote_cache_hit_count"
                 )
                 CompileEventLogger.add_to_set_toplevel(
                     "inductor_fx_remote_cache_hit_keys", key
@@ -1324,7 +1324,7 @@ class FxGraphCache:
             if remote_cache:
                 # Count remote cache miss stats
                 CompileEventLogger.increment_toplevel(
-                    "inductor_fx_remote_cache_miss_count", 1
+                    "inductor_fx_remote_cache_miss_count"
                 )
                 CompileEventLogger.add_to_set_toplevel(
                     "inductor_fx_remote_cache_miss_keys", key

View File

@@ -1015,6 +1015,17 @@ class _InProcessFxCompile(FxCompile):
                 torch._inductor.debug._inductor_post_to_pre_grad_nodes = (
                     provenance_tracking_json
                 )
+            metrics_context = get_metrics_context()
+            if metrics_context.in_progress():
+                # TODO: Remove this when 3.9 is no longer supported
+                if sys.version_info < (3, 10):
+                    num_graph_breaks = sum(counters["graph_break"].values())
+                else:
+                    num_graph_breaks = counters["graph_break"].total()
+                CompileEventLogger.compilation_metric(
+                    overwrite=True, num_graph_breaks=num_graph_breaks
+                )
+
             if config.is_fbcode():
                 try:
                     log_optimus_to_scuba(
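
Aside on the version guard above: collections.Counter.total() only exists on Python >= 3.10, so the 3.9 path sums the values by hand. The two are equivalent (the reason keys below are illustrative, not real graph-break reasons):

    from collections import Counter

    breaks = Counter({"data_dependent_branch": 2, "unsupported_op": 1})
    assert sum(breaks.values()) == 3  # works on every supported Python
    assert breaks.total() == 3        # Counter.total(), Python 3.10+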