Log graph breaks (#146537)
Graph breaks currently aren't logged to dynamo_compile and pt2_compile_events. We want to log them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146537
Approved by: https://github.com/c00w

This commit is contained in:
parent 0489a349e7
commit c5bf9aaf1c
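For orientation, the sketch below summarizes the flow the hunks underneath add: at the end of an inductor compile, Dynamo's graph-break counter is totalled and stamped onto the toplevel compilation metrics with overwrite=True, which is how the count reaches dynamo_compile and pt2_compile_events. This is an illustrative stand-in only; counters, get_metrics_context, and CompileEventLogger are names taken from the diff, but the standalone helper and its scaffolding are assumptions made for the example.

    # Illustrative sketch only; mirrors the logic added in the compile_fx hunk below.
    import sys
    from collections import Counter

    # Stand-in for torch._dynamo.utils.counters, which maps categories to Counters.
    counters: dict[str, Counter] = {"graph_break": Counter()}

    def current_num_graph_breaks() -> int:
        # Counter.total() only exists on Python 3.10+, hence the version check in the diff.
        if sys.version_info < (3, 10):
            return sum(counters["graph_break"].values())
        return counters["graph_break"].total()

    # The diff then records the value on the toplevel compilation metrics,
    # overwriting any value logged earlier in the same compile:
    #     CompileEventLogger.compilation_metric(
    #         overwrite=True, num_graph_breaks=current_num_graph_breaks()
    #     )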
@@ -71,7 +71,15 @@ class TestMetricsContext(TestCase):
         with self.assertRaisesRegex(RuntimeError, "already been set"):
             context.update({"m1": 7, "m3": 3})

         self.assertEqual(self.metrics, {"m1": 1, "m2": 2})

+    def test_update_allow_overwrite(self):
+        """
+        Validate that update will overwrite when given the param.
+        """
+        with MetricsContext(self._on_exit) as context:
+            context.update({"m1": 1, "m2": 2})
+            context.update({"m1": 7, "m3": 3}, overwrite=True)
+
+        self.assertEqual(self.metrics, {"m1": 7, "m2": 2, "m3": 3})
+
     def test_add_to_set(self):
         """
@@ -75,6 +75,71 @@ class TestUtils(TestCase):
             )
         )

+    @dynamo_config.patch(
+        {
+            "log_compilation_metrics": True,
+            "inline_inbuilt_nn_modules": False,
+        }
+    )
+    def test_graph_break_counting(self):
+        """
+        Run a compilation that includes a graph break and validate that the
+        graph break counter is incremented.
+        """
+
+        def run_forward_backward():
+            model = torch.compile(TestModel())
+            x = torch.rand([3], requires_grad=True)
+            output = model(x)
+            loss_fn = torch.nn.MSELoss()
+            target = torch.tensor([1.0])
+            loss = loss_fn(output, target)
+            loss.backward()
+
+        @torch.compile
+        def add(x, y):
+            return x + y
+
+        @torch.compile
+        def break_it(x):
+            y = x.sum()
+            if y > 0:
+                return x + y.item()
+            return x - y.item()
+
+        @torch.compile
+        def break_it2(x):
+            y = x.sum()
+            if y > 0:
+                if y > 1:
+                    return x * y.item()
+                return x + y.item()
+            return x - y.item()
+
+        add(torch.rand([10]), torch.rand([10]))
+        utils.reset_frame_count()
+
+        compilation_events = []
+        with mock.patch("torch._dynamo.utils.log_compilation_event") as log_event:
+            run_forward_backward()
+            compilation_events = [arg[0][0] for arg in log_event.call_args_list]
+            self.assertEqual(compilation_events[-1].num_graph_breaks, 0)
+
+            # We should fall back to normal mode and increment the graph break counter
+            torch.compile(break_it, backend="inductor")(torch.ones(3, 3))
+            compilation_events = [arg[0][0] for arg in log_event.call_args_list]
+            self.assertEqual(compilation_events[-1].num_graph_breaks, 1)
+
+            # The graph break counter should be incremented by 1 (after a reset), not 2
+            torch.compile(break_it, backend="inductor")(torch.ones(3, 3))
+            compilation_events = [arg[0][0] for arg in log_event.call_args_list]
+            self.assertEqual(compilation_events[-1].num_graph_breaks, 1)
+
+            # The graph break counter should be incremented by 2
+            torch.compile(break_it2, backend="inductor")(torch.ones(3, 3))
+            compilation_events = [arg[0][0] for arg in log_event.call_args_list]
+            self.assertEqual(compilation_events[-1].num_graph_breaks, 2)
+

 class TestModel(torch.nn.Module):
     def __init__(self):
@@ -266,6 +331,7 @@ class TestDynamoTimed(TestCase):
         'joint_graph_pass_time_us': 0,
         'log_format_version': 3,
         'non_compliant_ops': set(),
+        'num_graph_breaks': 0,
         'num_triton_bundles': None,
         'post_grad_pass_time_us': 0,
         'pre_grad_pass_time_us': 0,
@@ -348,6 +414,7 @@ class TestDynamoTimed(TestCase):
         'joint_graph_pass_time_us': None,
         'log_format_version': 3,
         'non_compliant_ops': None,
+        'num_graph_breaks': 0,
         'num_triton_bundles': None,
         'post_grad_pass_time_us': 0,
         'pre_grad_pass_time_us': None,
@@ -1223,13 +1223,11 @@ class ConvertFrame:
             # when we do not support graph breaks on bytecodes like LOAD_ATTR,
             # BUILD_SET etc. In such cases, we can fall back to eager without
             # scaring users.
-            if isinstance(e, Unsupported) and graph_break_log.isEnabledFor(
-                logging.DEBUG
-            ):
+            if soft_fail and graph_break_log.isEnabledFor(logging.DEBUG):
                 # Log this message in the graph break. Also use the string
                 # "skip: " to tell that the whole frame is falling back to
                 # eager.
-                if hasattr(e, "compile_id"):
+                if hasattr(e, "compile_id") and hasattr(e, "real_stack"):
                     with compile_context(CompileContext(e.compile_id)):  # type: ignore[attr-defined]
                         user_stack = e.real_stack
                         user_stack_formatted = "".join(
@@ -109,15 +109,16 @@ class MetricsContext:
             self._metrics[metric] = {}
         self._metrics[metric][key] = value

-    def update(self, values: dict[str, Any]) -> None:
+    def update(self, values: dict[str, Any], overwrite: bool = False) -> None:
         """
         Set multiple metrics directly. This method does NOT increment. Raises if any
-        metric has been assigned previously in the current context.
+        metric has been assigned previously in the current context and overwrite is
+        not set to True.
         """
         if self._level == 0:
             raise RuntimeError("Cannot update metrics outside of a MetricsContext")
         existing = self._metrics.keys() & values.keys()
-        if existing:
+        if existing and not overwrite:
             raise RuntimeError(
                 f"Metric(s) {existing} have already been set in the current context"
             )
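As a usage note, here is a minimal stand-in for the overwrite semantics introduced above. It is not the real MetricsContext (which also tracks whether a context is active); it only illustrates the rule that overlapping keys raise unless overwrite=True is passed.

    from typing import Any

    class _MetricsSketch:
        """Toy stand-in illustrating update(..., overwrite=...) semantics."""

        def __init__(self) -> None:
            self._metrics: dict[str, Any] = {}

        def update(self, values: dict[str, Any], overwrite: bool = False) -> None:
            # Overlapping keys are an error unless the caller opts into overwriting.
            existing = self._metrics.keys() & values.keys()
            if existing and not overwrite:
                raise RuntimeError(
                    f"Metric(s) {existing} have already been set in the current context"
                )
            self._metrics.update(values)

    sketch = _MetricsSketch()
    sketch.update({"m1": 1, "m2": 2})
    sketch.update({"m1": 7, "m3": 3}, overwrite=True)  # -> {"m1": 7, "m2": 2, "m3": 3}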
@@ -371,7 +371,12 @@ class CompileEventLogger:
         )

     @staticmethod
-    def add_data(event_name: str, log_level: CompileEventLogLevel, **metadata: object):
+    def add_data(
+        event_name: str,
+        log_level: CompileEventLogLevel,
+        overwrite: bool = False,
+        **metadata: object,
+    ):
         """
         Centralized API for adding data to various events
         Log an event to a toplevel "dynamo" event or metrics context
@@ -408,11 +413,13 @@ class CompileEventLogger:
                 )

             # TODO: should we assert that the keys of metadata are in CompilationMetrics?
-            metrics_context.update(metadata)
+            metrics_context.update(metadata, overwrite)
             chromium_log.add_event_data(event_name, **metadata)

     @staticmethod
-    def add_toplevel(log_level: CompileEventLogLevel, **metadata: object):
+    def add_toplevel(
+        log_level: CompileEventLogLevel, overwrite: bool = False, **metadata: object
+    ):
         """
         Syntactic sugar for logging to the toplevel event
         """
@@ -421,7 +428,7 @@ class CompileEventLogger:
             raise RuntimeError(
                 "No toplevel event active. Please only call this function within a dynamo_timed context."
             )
-        CompileEventLogger.add_data(top_event, log_level, **metadata)
+        CompileEventLogger.add_data(top_event, log_level, overwrite, **metadata)

     @staticmethod
     def increment(
@@ -457,7 +464,7 @@ class CompileEventLogger:
     @staticmethod
     def increment_toplevel(
         key: str,
-        value: int,
+        value: int = 1,
         log_level: CompileEventLogLevel = CompileEventLogLevel.COMPILATION_METRIC,
     ):
         """
@@ -529,7 +536,7 @@ class CompileEventLogger:
         <event_name> should be the name of a timed event span passed to `dynamo_timed`.
         """
         CompileEventLogger.add_data(
-            event_name, CompileEventLogLevel.CHROMIUM, **metadata
+            event_name, CompileEventLogLevel.CHROMIUM, overwrite=False, **metadata
         )

     @staticmethod
@@ -542,11 +549,11 @@ class CompileEventLogger:
         with log_to_pt2_compile_events=True.
         """
         CompileEventLogger.add_data(
-            event_name, CompileEventLogLevel.PT2_COMPILE, **metadata
+            event_name, CompileEventLogLevel.PT2_COMPILE, overwrite=False, **metadata
         )

     @staticmethod
-    def compilation_metric(**metadata: object):
+    def compilation_metric(overwrite: bool = False, **metadata: object):
         """
         Add <metadata> to the CompilationMetrics context. Also logs to PT2 Compile Events
         and chromium.
@@ -554,7 +561,7 @@ class CompileEventLogger:
         a column in PT2 Compile Events and Dynamo Compile, with the corresponding kwarg value.
         """
         CompileEventLogger.add_toplevel(
-            CompileEventLogLevel.COMPILATION_METRIC, **metadata
+            CompileEventLogLevel.COMPILATION_METRIC, overwrite, **metadata
         )

     @staticmethod
@@ -1216,6 +1223,7 @@ class CompilationMetrics:
     tensorify_float_failure: Optional[set[str]] = None
     guard_latency_us: Optional[float] = None
     recompile_reason: Optional[str] = None
+    num_graph_breaks: Optional[int] = None

     @classmethod
     def create(cls, metrics: dict[str, Any]):
@@ -1052,7 +1052,7 @@ class FxGraphCache:
                 "inductor_compile", cached_kernel_names=meta.cached_kernel_names
             )
             if len(meta.cached_kernel_names) > 0:
-                CompileEventLogger.increment_toplevel("num_triton_bundles", 1)
+                CompileEventLogger.increment_toplevel("num_triton_bundles")

         try:
             artifact_path = graph.after_deserialization(constants)
@@ -1303,7 +1303,7 @@ class FxGraphCache:
         if remote_cache:
             # Count remote cache hit stats
             CompileEventLogger.increment_toplevel(
-                "inductor_fx_remote_cache_hit_count", 1
+                "inductor_fx_remote_cache_hit_count"
             )
             CompileEventLogger.add_to_set_toplevel(
                 "inductor_fx_remote_cache_hit_keys", key
@@ -1324,7 +1324,7 @@ class FxGraphCache:
         if remote_cache:
             # Count remote cache miss stats
             CompileEventLogger.increment_toplevel(
-                "inductor_fx_remote_cache_miss_count", 1
+                "inductor_fx_remote_cache_miss_count"
             )
             CompileEventLogger.add_to_set_toplevel(
                 "inductor_fx_remote_cache_miss_keys", key
@@ -1015,6 +1015,17 @@ class _InProcessFxCompile(FxCompile):
                 torch._inductor.debug._inductor_post_to_pre_grad_nodes = (
                     provenance_tracking_json
                 )
+
+            metrics_context = get_metrics_context()
+            if metrics_context.in_progress():
+                # TODO: Remove this when 3.9 is no longer supported
+                if sys.version_info < (3, 10):
+                    num_graph_breaks = sum(counters["graph_break"].values())
+                else:
+                    num_graph_breaks = counters["graph_break"].total()
+                CompileEventLogger.compilation_metric(
+                    overwrite=True, num_graph_breaks=num_graph_breaks
+                )
             if config.is_fbcode():
                 try:
                     log_optimus_to_scuba(