diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index a9d96f1aa7c..d2d1af1a8fd 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -985,6 +985,50 @@ class TestProfiler(TestCase): ) self.assertIn("Total MFLOPs", profiler_output) + def test_override_time_units(self): + US_IN_SECOND = 1000.0 * 1000.0 + US_IN_MS = 1000.0 + + model = torch.nn.Sequential( + nn.Conv2d(16, 33, 18), + nn.ReLU(), + nn.Linear(243, 243), + nn.ReLU(), + ) + inputs = torch.randn(40, 16, 18, 260) + with _profile() as prof: + model(inputs) + + profiler_output = prof.key_averages().table(time_unit="s") + self.assertRegex(profiler_output, r".*(\.[0-9]{3}s).*") + self.assertNotRegex(profiler_output, r".*(\.[0-9]{3}ms).*") + self.assertNotRegex(profiler_output, r".*(\.[0-9]{3}us).*") + for event in prof.key_averages(): + cpu_time_str_s = f"{event.cpu_time / US_IN_SECOND:.3f}s" + cpu_time_total_str_s = f"{event.cpu_time_total / US_IN_SECOND:.3f}s" + self.assertTrue(cpu_time_str_s in profiler_output) + self.assertTrue(cpu_time_total_str_s in profiler_output) + + profiler_output = prof.key_averages().table(time_unit="ms") + self.assertNotRegex(profiler_output, r".*(\.[0-9]{3}s).*") + self.assertRegex(profiler_output, r".*(\.[0-9]{3}ms).*") + self.assertNotRegex(profiler_output, r".*(\.[0-9]{3}us).*") + for event in prof.key_averages(): + cpu_time_str_ms = f"{event.cpu_time / US_IN_MS:.3f}ms" + cpu_time_total_str_ms = f"{event.cpu_time_total / US_IN_MS:.3f}ms" + self.assertTrue(cpu_time_str_ms in profiler_output) + self.assertTrue(cpu_time_total_str_ms in profiler_output) + + profiler_output = prof.key_averages().table(time_unit="us") + self.assertNotRegex(profiler_output, r".*(\.[0-9]{3}s).*") + self.assertNotRegex(profiler_output, r".*(\.[0-9]{3}ms).*") + self.assertRegex(profiler_output, r".*(\.[0-9]{3}us).*") + for event in prof.key_averages(): + cpu_time_str_us = f"{event.cpu_time:.3f}us" + cpu_time_total_str_us = f"{event.cpu_time_total:.3f}us" + self.assertTrue(cpu_time_str_us in profiler_output) + self.assertTrue(cpu_time_total_str_us in profiler_output) + @patch.dict(os.environ, {"KINETO_USE_DAEMON": "1"}) @patch.dict(os.environ, {"KINETO_DAEMON_INIT_DELAY_S": "1"}) def test_kineto_profiler_api(self): diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index 84dd5f1013d..b789aab11c6 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -173,6 +173,7 @@ class EventList(list): max_shapes_column_width=80, header=None, top_level_events_only=False, + time_unit=None, ): """Print an EventList as a nicely formatted table. @@ -189,6 +190,8 @@ class EventList(list): display events at top level like top-level invocation of python `lstm`, python `add` or other functions, nested events like low-level cpu/cuda/xpu ops events are omitted for profiler result readability. + time_unit(str, optional): A time unit to be used for all values in the + table. Valid options are: ``s``, ``ms`` and ``us``. Returns: A string containing the table. @@ -204,6 +207,7 @@ class EventList(list): profile_memory=self._profile_memory, with_flops=self._with_flops, top_level_events_only=top_level_events_only, + time_unit=time_unit, ) def export_chrome_trace(self, path): @@ -832,6 +836,7 @@ def _build_table( with_flops=False, profile_memory=False, top_level_events_only=False, + time_unit=None, ): """Print a summary of events (which can be a list of FunctionEvent or FunctionEventAvg).""" if len(events) == 0: @@ -1039,6 +1044,18 @@ def _build_table( path = "..." + path[3:] return path + def override_time_unit(time_us, default_str, time_unit): + US_IN_SECOND = 1000.0 * 1000.0 + US_IN_MS = 1000.0 + if time_unit == "s": + return f"{time_us / US_IN_SECOND:.3f}s" + elif time_unit == "ms": + return f"{time_us / US_IN_MS:.3f}ms" + elif time_unit == "us": + return f"{time_us:.3f}us" + else: + return default_str + event_limit = 0 for evt in events: if event_limit == row_limit: @@ -1072,11 +1089,17 @@ def _build_table( row_values += [ # Self CPU total %, 0 for async events. evt.self_cpu_percent, - evt.self_cpu_time_total_str, # Self CPU total + override_time_unit( + evt.self_cpu_time_total, evt.self_cpu_time_total_str, time_unit + ), # Self CPU total # CPU total %, 0 for async events. evt.total_cpu_percent, - evt.cpu_time_total_str, # CPU total - evt.cpu_time_str, # CPU time avg + override_time_unit( + evt.cpu_time_total, evt.cpu_time_total_str, time_unit + ), # CPU total + override_time_unit( + evt.cpu_time, evt.cpu_time_str, time_unit + ), # CPU time avg ] if has_device_time: evt.total_device_percent = _format_time_share( @@ -1084,11 +1107,19 @@ def _build_table( ) row_values.extend( [ - evt.self_device_time_total_str, + override_time_unit( + evt.self_device_time_total, + evt.self_device_time_total_str, + time_unit, + ), # device time total % evt.total_device_percent, - evt.device_time_total_str, - evt.device_time_str, # device time avg + override_time_unit( + evt.device_time_total, evt.device_time_total_str, time_unit + ), + override_time_unit( + evt.device_time, evt.device_time_str, time_unit + ), # device time avg ] ) if profile_memory: @@ -1141,10 +1172,12 @@ def _build_table( append(row_format.format(*empty_headers)) append(header_sep) - append(f"Self CPU time total: {_format_time(sum_self_cpu_time_total)}") + append( + f"Self CPU time total: {override_time_unit(sum_self_cpu_time_total, _format_time(sum_self_cpu_time_total), time_unit)}" + ) if has_device_time: append( f"Self {use_device.upper() if use_device is not None else 'None'} " - f"time total: {_format_time(sum_self_device_time_total)}" + f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}" ) return "".join(result)