Summary: Gather the compilation time of individual triton kernels and log them to dynamo_compile:

* Time compilation in `_worker_compile_triton`, pass it back to the main process, and log it from `get_result()`.
* Added a way to track the "top N" (i.e., the N most expensive compiles) in the metrics_context. I did this because I doubt we really care to capture potentially thousands of kernel compile times; that would be problematic for scuba logging anyway, so let's limit the number we track from the beginning. Arbitrarily chose 25 for now.
* Format the list of compile times as a json string before logging.

Test Plan:
`python benchmarks/dynamo/torchbench.py --performance --training --amp --backend inductor --device cuda --print-compilation-time --repeat 5 --cold-start-latency --only nanogpt`

Scuba: https://fburl.com/scuba/dynamo_compile/sandbox/nc4dzm3r

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147022
Approved by: https://github.com/jamesjwu
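For illustration only, here is a minimal sketch (not code from this PR) of how the `TopN` helper exercised by the test below could be used to keep just the most expensive kernel compiles and serialize them to a JSON string. The kernel names, timings, and the `json.dumps` step are assumptions made for the example; the actual dynamo_compile logging path and field names may differ.

```python
import json

from torch._dynamo.metrics_context import TopN

# Hypothetical (kernel_name, compile_time_seconds) pairs, e.g. gathered from
# Triton worker compile results in the main process. Values are made up.
kernel_compile_times = [
    ("triton_poi_fused_add_0", 0.42),
    ("triton_red_fused_sum_1", 1.37),
    ("triton_per_fused_mean_2", 0.08),
]

# Keep only the N most expensive compiles; the PR arbitrarily caps this at 25.
top_kernels = TopN(25)
for name, seconds in kernel_compile_times:
    top_kernels.add(name, seconds)

# Iterating a TopN yields (key, value) pairs sorted from largest to smallest
# value (see test_top_n below), so it can be serialized directly as JSON.
compile_times_json = json.dumps(dict(top_kernels))
print(compile_times_json)
```

Capping the collection at construction time keeps the logged payload bounded no matter how many kernels a single compile produces.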
# Owner(s): ["module: dynamo"]

from torch._dynamo.metrics_context import MetricsContext, TopN
from torch._dynamo.test_case import run_tests, TestCase


class TestMetricsContext(TestCase):
    def setUp(self):
        super().setUp()
        self.metrics = {}

    def _on_exit(self, start_ns, end_ns, metrics, exc_type, exc_value):
        # Save away the metrics to be validated in the test.
        self.metrics = metrics.copy()

    def test_context_exists(self):
        """
        Setting a value without entering the context should raise.
        """
        context = MetricsContext(self._on_exit)
        with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
            context.increment("m", 1)

        with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
            context.set("m", 1)

        with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
            context.update({"m": 1})

    def test_nested_context(self):
        """
        Only the outermost context should get an on_exit call, and it should
        include everything.
        """
        context = MetricsContext(self._on_exit)
        with context:
            with context:
                context.set("m1", 1)
            self.assertEqual(self.metrics, {})
            context.set("m2", 2)
        self.assertEqual(self.metrics, {"m1": 1, "m2": 2})

    def test_set(self):
        """
        Validate various ways to set metrics.
        """
        with MetricsContext(self._on_exit) as context:
            context.set("m1", 1)
            context.set("m2", 2)
            context.update({"m3": 3, "m4": 4})

        self.assertEqual(self.metrics, {"m1": 1, "m2": 2, "m3": 3, "m4": 4})

    def test_set_disallow_overwrite(self):
        """
        Validate set won't overwrite.
        """
        with MetricsContext(self._on_exit) as context:
            context.set("m1", 1)
            with self.assertRaisesRegex(RuntimeError, "already been set"):
                context.set("m1", 2)

        self.assertEqual(self.metrics, {"m1": 1})

    def test_update_disallow_overwrite(self):
        """
        Validate update won't overwrite.
        """
        with MetricsContext(self._on_exit) as context:
            context.update({"m1": 1, "m2": 2})
            with self.assertRaisesRegex(RuntimeError, "already been set"):
                context.update({"m1": 7, "m3": 3})

    def test_update_allow_overwrite(self):
        """
        Validate update will overwrite when the overwrite param is given.
        """
        with MetricsContext(self._on_exit) as context:
            context.update({"m1": 1, "m2": 2})
            context.update({"m1": 7, "m3": 3}, overwrite=True)

        self.assertEqual(self.metrics, {"m1": 7, "m2": 2, "m3": 3})

    def test_add_to_set(self):
        """
        Validate add_to_set.
        """
        with MetricsContext(self._on_exit) as context:
            context.add_to_set("m1", 1)
            context.add_to_set("m1", 2)
            context.add_to_set("m2", 3)
            context.add_to_set("m2", 4)

        self.assertEqual(self.metrics, {"m1": {1, 2}, "m2": {3, 4}})
        self.assertTrue(isinstance(self.metrics["m1"], set))
        self.assertTrue(isinstance(self.metrics["m2"], set))

    def test_set_key_value(self):
        with MetricsContext(self._on_exit) as context:
            context.set_key_value("feature_usage", "k", True)
            # Overwrites of an existing key are allowed.
            context.set_key_value("feature_usage", "k2", True)
            context.set_key_value("feature_usage", "k2", False)

        self.assertEqual(self.metrics, {"feature_usage": {"k": True, "k2": False}})

    def test_top_n(self):
        top_n = TopN(3)
        for k, v in (("seven", 7), ("four", 4), ("five", 5), ("six", 6), ("eight", 8)):
            top_n.add(k, v)

        self.assertEqual(len(top_n), 3)
        print(list(top_n))
        self.assertEqual(list(top_n), [("eight", 8), ("seven", 7), ("six", 6)])


if __name__ == "__main__":
    run_tests()