pytorch/test/dynamo/test_metrics_context.py
Sam Larsen 40c2505f16 [logging] Log individual Triton kernel compilation times to dynamo_compile (#147022)
Summary: Gather the compilation times of individual Triton kernels and log them to dynamo_compile:
* Time compilation in `_worker_compile_triton`, pass the elapsed time back to the main process, and log it from `get_result()`.
* Added a way to track the "top N" (or N most-expensive compiles) in the metrics_context. I did this because I doubt we really care to capture potentially thousands of kernel compile times. That would be problematic for scuba logging anyway, so let's limit the number we track from the beginning. Arbitrarily chose 25 for now.
* Format the list of compile times as a json string before logging.

Test Plan:
`python benchmarks/dynamo/torchbench.py --performance --training --amp --backend inductor --device cuda --print-compilation-time --repeat 5 --cold-start-latency --only nanogpt`
Scuba: https://fburl.com/scuba/dynamo_compile/sandbox/nc4dzm3r

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147022
Approved by: https://github.com/jamesjwu
2025-03-03 19:32:17 +00:00

119 lines
3.9 KiB
Python

# Owner(s): ["module: dynamo"]
from torch._dynamo.metrics_context import MetricsContext, TopN
from torch._dynamo.test_case import run_tests, TestCase
class TestMetricsContext(TestCase):
def setUp(self):
super().setUp()
self.metrics = {}
def _on_exit(self, start_ns, end_ns, metrics, exc_type, exc_value):
# Save away the metrics to be validated in the test.
self.metrics = metrics.copy()
def test_context_exists(self):
"""
Setting a value without entering the context should raise.
"""
context = MetricsContext(self._on_exit)
with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
context.increment("m", 1)
with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
context.set("m", 1)
with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
context.update({"m", 1})
def test_nested_context(self):
"""
Only the outermost context should get an on_exit call, and it should
include everything.
"""
context = MetricsContext(self._on_exit)
with context:
with context:
context.set("m1", 1)
self.assertEqual(self.metrics, {})
context.set("m2", 2)
self.assertEqual(self.metrics, {"m1": 1, "m2": 2})
def test_set(self):
"""
Validate various ways to set metrics.
"""
with MetricsContext(self._on_exit) as context:
context.set("m1", 1)
context.set("m2", 2)
context.update({"m3": 3, "m4": 4})
self.assertEqual(self.metrics, {"m1": 1, "m2": 2, "m3": 3, "m4": 4})
def test_set_disallow_overwrite(self):
"""
Validate set won't overwrite.
"""
with MetricsContext(self._on_exit) as context:
context.set("m1", 1)
with self.assertRaisesRegex(RuntimeError, "already been set"):
context.set("m1", 2)
self.assertEqual(self.metrics, {"m1": 1})
def test_update_disallow_overwrite(self):
"""
Validate update won't overwite.
"""
with MetricsContext(self._on_exit) as context:
context.update({"m1": 1, "m2": 2})
with self.assertRaisesRegex(RuntimeError, "already been set"):
context.update({"m1": 7, "m3": 3})
def test_update_allow_overwrite(self):
"""
Validate update will overwite when given param.
"""
with MetricsContext(self._on_exit) as context:
context.update({"m1": 1, "m2": 2})
context.update({"m1": 7, "m3": 3}, overwrite=True)
self.assertEqual(self.metrics, {"m1": 7, "m2": 2, "m3": 3})
def test_add_to_set(self):
"""
Validate add_to_set.
"""
with MetricsContext(self._on_exit) as context:
context.add_to_set("m1", 1)
context.add_to_set("m1", 2)
context.add_to_set("m2", 3)
context.add_to_set("m2", 4)
self.assertEqual(self.metrics, {"m1": {1, 2}, "m2": {3, 4}})
self.assertTrue(isinstance(self.metrics["m1"], set))
self.assertTrue(isinstance(self.metrics["m2"], set))
def test_set_key_value(self):
with MetricsContext(self._on_exit) as context:
context.set_key_value("feature_usage", "k", True)
# Overrides allowed
context.set_key_value("feature_usage", "k2", True)
context.set_key_value("feature_usage", "k2", False)
self.assertEqual(self.metrics, {"feature_usage": {"k": True, "k2": False}})
def test_top_n(self):
top_n = TopN(3)
for k, v in (("seven", 7), ("four", 4), ("five", 5), ("six", 6), ("eight", 8)):
top_n.add(k, v)
self.assertEqual(len(top_n), 3)
print(list(top_n))
self.assertEqual(list(top_n), [("eight", 8), ("seven", 7), ("six", 6)])
# When this file is executed directly (rather than imported), hand control to
# the run_tests entry point imported from torch._dynamo.test_case.
if __name__ == "__main__":
    run_tests()