pytorch/torch/_dynamo/metrics_context.py
Sam Larsen edba20b853 [logging] Fix duration logging for dynamo_compile (#151749)
Summary: There are a few issues I'm solving:
1. It's too hard to measure total pt2 overhead using the dynamo_compile table because users need to know the columns representing all the top-level events (dynamo_cumulative_compile_time_us, etc.). Instead, let's populate the existing duration_us field for all top-level events. The complication is that runtime events in particular (Triton autotuning, cudagraphify) can be collapsed into a single row, with gaps in between, so we can't simply use `end_time - start_time` in all cases. Instead, we'll sum durations for all outer events when updating the compile-time or runtime metrics context. Introduce a 'depth' counter in TLS to track the nesting of CompilationMetrics events.
2. The existing implementation relies on callers of dynamo_timed to specify whether the event is a runtime or compile-time event. That doesn't work because some methods can be called in both situations, e.g., `CachingAutotuner.benchmark_all_configs`: for example, `TORCHINDUCTOR_BENCHMARK_FUSION=1` enables benchmarking during compile time. Instead, we can figure out automatically whether we're measuring a compile-time or runtime event and log accordingly (see the sketch after this list).
3. If `log_compilation_events` were to throw an exception, we'd fail to clear the aggregated counters for runtime logs and they could be attributed to the wrong compile ID. I didn't actually find evidence of this in practice, but I added exception handling for extra safety.
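
To make the shape of (1) and (2) concrete, here is a rough sketch of the idea. This is not the actual `dynamo_timed` implementation; `timed_event`, `_tls`, `compile_metrics`, and `runtime_metrics` are placeholder names. A per-thread depth counter ensures only outermost events contribute to the summed `duration_us`, and the target context is chosen by checking whether a compile-time `MetricsContext` is currently in progress:

```python
import threading
import time

from torch._dynamo.metrics_context import MetricsContext, RuntimeMetricsContext

_tls = threading.local()

def timed_event(compile_metrics: MetricsContext,
                runtime_metrics: RuntimeMetricsContext,
                work):
    """Run `work()` and attribute its wall time to the right metrics context."""
    depth = getattr(_tls, "depth", 0)
    _tls.depth = depth + 1
    start_ns = time.time_ns()
    try:
        return work()
    finally:
        _tls.depth = depth
        if depth == 0:
            # Only outermost events are summed into duration_us, so nested
            # events (and the gaps between collapsed runtime events) aren't
            # double-counted.
            duration_us = (time.time_ns() - start_ns) // 1000
            if compile_metrics.in_progress():
                # A compilation is underway: this is a compile-time event.
                compile_metrics.increment("duration_us", duration_us)
            else:
                # Otherwise it's a runtime event (e.g. Triton autotuning).
                runtime_metrics.increment("duration_us", duration_us)
```
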

Test Plan:
Ran internal models and compared dynamo_compile to pt2_compile_events:
`TORCHINDUCTOR_BENCHMARK_FUSION=0`
* tlparse: https://fburl.com/itciwnxc
* dynamo_compile: https://fburl.com/scuba/dynamo_compile/yvkif5vb
* pt2_compile_events: https://fburl.com/scuba/pt2_compile_events/segijet7

`TORCHINDUCTOR_BENCHMARK_FUSION=1`
* tlparse: https://fburl.com/jgurcvkw
* dynamo_compile: https://fburl.com/scuba/dynamo_compile/uum91ceb
* pt2_compile_events: https://fburl.com/scuba/pt2_compile_events/x4xnisez

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151749
Approved by: https://github.com/Skylion007
2025-04-22 03:29:13 +00:00


"""Metrics collection and management system for Dynamo.
This module provides context managers for gathering and reporting metrics during
compilation and runtime.
It includes two main components:
- MetricsContext: A context manager for collecting metrics during compilation, supporting
nested contexts and various metric types (counters, sets, key-value pairs)
- RuntimeMetricsContext: A specialized context for runtime metrics collection that doesn't
require explicit context management
The metrics system enables comprehensive monitoring and analysis of both compilation and
execution performance.
"""
import heapq
import logging
import time
from collections.abc import Iterator
from typing import Any, Callable, Optional
from typing_extensions import TypeAlias
log = logging.getLogger(__name__)
class TopN:
"""
Helper to record a list of metrics, keeping only the top N "most expensive" elements.
"""
def __init__(self, at_most: int = 25):
self.at_most = at_most
self.heap: list[tuple[int, Any]] = []
def add(self, key: Any, val: int) -> None:
# Push if we haven't reached the max size, else push and pop the smallest
fn = heapq.heappush if len(self.heap) < self.at_most else heapq.heappushpop
fn(self.heap, (val, key))
def __len__(self) -> int:
return len(self.heap)
def __iter__(self) -> Iterator[tuple[Any, int]]:
return ((key, val) for val, key in sorted(self.heap, reverse=True))
OnExitType: TypeAlias = Callable[
[int, int, dict[str, Any], Optional[type[BaseException]], Optional[BaseException]],
None,
]
class MetricsContext:
def __init__(self, on_exit: OnExitType):
"""
Use this class as a context manager to create a context under which to accumulate
a set of metrics, e.g., metrics gathered during a compilation. On exit of the
context manager, call the provided 'on_exit' function and pass a dictionary of
all metrics set during the lifetime of the context manager.
"""
self._on_exit = on_exit
self._metrics: dict[str, Any] = {}
self._start_time_ns: int = 0
self._level: int = 0
def __enter__(self) -> "MetricsContext":
"""
Initialize metrics recording.
"""
if self._level == 0:
# In case of recursion, track at the outermost context.
self._metrics = {}
self._start_time_ns = time.time_ns()
self._level += 1
return self
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
_traceback: Any,
) -> None:
"""
At exit, call the provided on_exit function.
"""
self._level -= 1
assert self._level >= 0
if self._level == 0:
try:
end_time_ns = time.time_ns()
self._on_exit(
self._start_time_ns, end_time_ns, self._metrics, exc_type, exc_value
)
except Exception:
log.exception("Unexpected exception logging compilation metrics")
def in_progress(self) -> bool:
"""
True if we've entered the context.
"""
return self._level > 0
def increment(self, metric: str, value: int) -> None:
"""
Increment a metric by a given amount.
"""
if self._level == 0:
raise RuntimeError(f"Cannot increment {metric} outside of a MetricsContext")
if metric not in self._metrics:
self._metrics[metric] = 0
self._metrics[metric] += value
def set(self, metric: str, value: Any, overwrite: bool = False) -> None:
"""
Set a metric to a given value. Raises if the metric has been assigned previously
in the current context.
"""
if self._level == 0:
raise RuntimeError(f"Cannot set {metric} outside of a MetricsContext")
if metric in self._metrics and not overwrite:
raise RuntimeError(
f"Metric '{metric}' has already been set in the current context"
)
self._metrics[metric] = value
def set_key_value(self, metric: str, key: str, value: Any) -> None:
"""
Treats a given metric as a dictionary and sets the key and value within it.
Note that the metric must be a dictionary or not present.
We allow this to be called multiple times (i.e. for features, it's not uncommon
for them to be used multiple times within a single compilation).
"""
if self._level == 0:
raise RuntimeError(f"Cannot set {metric} outside of a MetricsContext")
if metric not in self._metrics:
self._metrics[metric] = {}
self._metrics[metric][key] = value
def update(self, values: dict[str, Any], overwrite: bool = False) -> None:
"""
Set multiple metrics directly. This method does NOT increment. Raises if any
metric has been assigned previously in the current context and overwrite is
not set to True.
"""
if self._level == 0:
raise RuntimeError("Cannot update metrics outside of a MetricsContext")
existing = self._metrics.keys() & values.keys()
if existing and not overwrite:
raise RuntimeError(
f"Metric(s) {existing} have already been set in the current context"
)
self._metrics.update(values)
def update_outer(self, values: dict[str, Any]) -> None:
"""
Update, but only when at the outermost context.
"""
if self._level == 0:
raise RuntimeError("Cannot update metrics outside of a MetricsContext")
if self._level == 1:
self.update(values)
def add_to_set(self, metric: str, value: Any) -> None:
"""
Records a metric as a set() of values.
"""
if self._level == 0:
raise RuntimeError(f"Cannot add {metric} outside of a MetricsContext")
if metric not in self._metrics:
self._metrics[metric] = set()
self._metrics[metric].add(value)
def add_top_n(self, metric: str, key: Any, val: int) -> None:
"""
Records a metric as a TopN set of values.
"""
if self._level == 0:
return
if metric not in self._metrics:
self._metrics[metric] = TopN()
self._metrics[metric].add(key, val)
class RuntimeMetricsContext:
def __init__(self, on_exit: OnExitType):
"""
Similar to MetricsContext, but used to gather the runtime metrics that are
decoupled from compilation, where there's not a natural place to insert a
context manager.
"""
self._on_exit = on_exit
self._metrics: dict[str, Any] = {}
self._start_time_ns: int = 0
def increment(
self, metric: str, value: int, extra: Optional[dict[str, Any]] = None
) -> None:
"""
Increment a metric by a given amount.
"""
if not self._metrics:
# Start timing on the first entry
self._start_time_ns = time.time_ns()
if metric not in self._metrics:
self._metrics[metric] = 0
self._metrics[metric] += value
if extra:
for k, v in extra.items():
if k not in self._metrics and v is not None:
self._metrics[k] = v
def finish(self) -> None:
"""
Call the on_exit function with the metrics gathered so far and reset.
"""
if self._metrics:
try:
end_time_ns = time.time_ns()
self._on_exit(
self._start_time_ns, end_time_ns, self._metrics, None, None
)
except Exception:
log.exception("Unexpected exception logging runtime metrics")
finally:
self._metrics = {}