Allow Custom Time Unit When Printing Profiler Table (#157913)

## Overview
This PR adds a `time_unit` kwarg to the profiler's `table()` method, allowing users to specify a single time unit for all results in the profiling table. The available options are `s`, `ms` and `us`. If no unit or an invalid unit is provided, a unit is selected for each value based on its magnitude (the current default behaviour).

## Testing
A unit test has been added to verify this works correctly.

## Documentation
I couldn't find any documentation specific to the `table()` function beyond its docstring, which has been updated.

## Example Output
```
import torch
from torch.profiler import profile

with profile() as prof:
    res = torch.mm(torch.rand(1024, 1024), torch.rand(1024, 1024))

print(prof.key_averages().table(time_unit="s"))
print(prof.key_averages().table(time_unit="ms"))
print(prof.key_averages().table(time_unit="us"))
print(prof.key_averages().table())

```

```
----------------------  ------------  ------------  ------------  ------------  ------------  ------------
                  Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
----------------------  ------------  ------------  ------------  ------------  ------------  ------------
            aten::rand         0.04%        0.000s        10.36%        0.014s        0.007s             2
           aten::empty         0.04%        0.000s         0.04%        0.000s        0.000s             2
        aten::uniform_        10.27%        0.014s        10.27%        0.014s        0.007s             2
              aten::mm        89.64%        0.119s        89.64%        0.119s        0.119s             1
    aten::resolve_conj         0.00%        0.000s         0.00%        0.000s        0.000s             3
----------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 0.133s

----------------------  ------------  ------------  ------------  ------------  ------------  ------------
                  Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
----------------------  ------------  ------------  ------------  ------------  ------------  ------------
            aten::rand         0.04%       0.055ms        10.36%      13.735ms       6.868ms             2
           aten::empty         0.04%       0.054ms         0.04%       0.054ms       0.027ms             2
        aten::uniform_        10.27%      13.626ms        10.27%      13.626ms       6.813ms             2
              aten::mm        89.64%     118.892ms        89.64%     118.896ms     118.896ms             1
    aten::resolve_conj         0.00%       0.004ms         0.00%       0.004ms       0.001ms             3
----------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 132.631ms

----------------------  ------------  ------------  ------------  ------------  ------------  ------------
                  Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
----------------------  ------------  ------------  ------------  ------------  ------------  ------------
            aten::rand         0.04%      55.495us        10.36%   13735.202us    6867.601us             2
           aten::empty         0.04%      54.121us         0.04%      54.121us      27.061us             2
        aten::uniform_        10.27%   13625.586us        10.27%   13625.586us    6812.793us             2
              aten::mm        89.64%  118892.284us        89.64%  118895.981us  118895.981us             1
    aten::resolve_conj         0.00%       3.697us         0.00%       3.697us       1.232us             3
----------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 132631.183us

----------------------  ------------  ------------  ------------  ------------  ------------  ------------
                  Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
----------------------  ------------  ------------  ------------  ------------  ------------  ------------
            aten::rand         0.04%      55.495us        10.36%      13.735ms       6.868ms             2
           aten::empty         0.04%      54.121us         0.04%      54.121us      27.061us             2
        aten::uniform_        10.27%      13.626ms        10.27%      13.626ms       6.813ms             2
              aten::mm        89.64%     118.892ms        89.64%     118.896ms     118.896ms             1
    aten::resolve_conj         0.00%       3.697us         0.00%       3.697us       1.232us             3
----------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 132.631ms
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157913
Approved by: https://github.com/sraikund16
This commit is contained in:
George Wigley 2025-07-10 22:44:29 +00:00 committed by PyTorch MergeBot
parent 83700b4488
commit 9bf41633d7
2 changed files with 85 additions and 8 deletions

View File

@ -985,6 +985,50 @@ class TestProfiler(TestCase):
)
self.assertIn("Total MFLOPs", profiler_output)
def test_override_time_units(self):
    """Check that ``table(time_unit=...)`` renders every time in the requested unit.

    For each supported unit the table must contain values suffixed with that
    unit, must not contain values suffixed with either of the other units, and
    must contain the average and total CPU time of every averaged event
    formatted with three decimals in that unit.
    """
    # Divisors that convert the profiler's raw time values into each unit;
    # the "us" case uses the raw value unchanged (division by 1.0).
    us_per_unit = {
        "s": 1000.0 * 1000.0,
        "ms": 1000.0,
        "us": 1.0,
    }

    model = torch.nn.Sequential(
        nn.Conv2d(16, 33, 18),
        nn.ReLU(),
        nn.Linear(243, 243),
        nn.ReLU(),
    )
    inputs = torch.randn(40, 16, 18, 260)
    with _profile() as prof:
        model(inputs)

    for unit, divisor in us_per_unit.items():
        # subTest keeps checking the remaining units after one fails and
        # labels any failure with the unit that produced it.
        with self.subTest(time_unit=unit):
            profiler_output = prof.key_averages().table(time_unit=unit)

            # Only the requested suffix may follow a three-decimal value.
            for suffix in us_per_unit:
                pattern = r".*(\.[0-9]{3}" + suffix + r").*"
                if suffix == unit:
                    self.assertRegex(profiler_output, pattern)
                else:
                    self.assertNotRegex(profiler_output, pattern)

            for event in prof.key_averages():
                cpu_time_str = f"{event.cpu_time / divisor:.3f}{unit}"
                cpu_time_total_str = f"{event.cpu_time_total / divisor:.3f}{unit}"
                # assertIn reports both operands on failure, unlike
                # assertTrue(x in y), which only reports "False is not true".
                self.assertIn(cpu_time_str, profiler_output)
                self.assertIn(cpu_time_total_str, profiler_output)
@patch.dict(os.environ, {"KINETO_USE_DAEMON": "1"})
@patch.dict(os.environ, {"KINETO_DAEMON_INIT_DELAY_S": "1"})
def test_kineto_profiler_api(self):

View File

@ -173,6 +173,7 @@ class EventList(list):
max_shapes_column_width=80,
header=None,
top_level_events_only=False,
time_unit=None,
):
"""Print an EventList as a nicely formatted table.
@ -189,6 +190,8 @@ class EventList(list):
display events at top level like top-level invocation of python
`lstm`, python `add` or other functions, nested events like low-level
cpu/cuda/xpu ops events are omitted for profiler result readability.
time_unit(str, optional): A time unit to be used for all values in the
table. Valid options are: ``s``, ``ms`` and ``us``.
Returns:
A string containing the table.
@ -204,6 +207,7 @@ class EventList(list):
profile_memory=self._profile_memory,
with_flops=self._with_flops,
top_level_events_only=top_level_events_only,
time_unit=time_unit,
)
def export_chrome_trace(self, path):
@ -832,6 +836,7 @@ def _build_table(
with_flops=False,
profile_memory=False,
top_level_events_only=False,
time_unit=None,
):
"""Print a summary of events (which can be a list of FunctionEvent or FunctionEventAvg)."""
if len(events) == 0:
@ -1039,6 +1044,18 @@ def _build_table(
path = "..." + path[3:]
return path
def override_time_unit(time_us, default_str, time_unit):
    """Format ``time_us`` in the requested unit, or fall back to ``default_str``.

    Args:
        time_us: Numeric time value. It is printed unchanged for ``"us"``,
            divided by 1e3 for ``"ms"`` and by 1e6 for ``"s"``.
        default_str: Already-formatted string returned when ``time_unit`` is
            not one of the recognised units (including ``None``).
        time_unit: ``"s"``, ``"ms"`` or ``"us"``; any other value selects
            ``default_str``.

    Returns:
        The converted value with three decimal places followed by the unit
        suffix, or ``default_str`` unchanged.
    """
    if time_unit == "us":
        return f"{time_us:.3f}us"
    if time_unit == "ms":
        return f"{time_us / 1000.0:.3f}ms"
    if time_unit == "s":
        return f"{time_us / (1000.0 * 1000.0):.3f}s"
    # Unrecognised or missing unit: keep the caller's auto-scaled string.
    return default_str
event_limit = 0
for evt in events:
if event_limit == row_limit:
@ -1072,11 +1089,17 @@ def _build_table(
row_values += [
# Self CPU total %, 0 for async events.
evt.self_cpu_percent,
evt.self_cpu_time_total_str, # Self CPU total
override_time_unit(
evt.self_cpu_time_total, evt.self_cpu_time_total_str, time_unit
), # Self CPU total
# CPU total %, 0 for async events.
evt.total_cpu_percent,
evt.cpu_time_total_str, # CPU total
evt.cpu_time_str, # CPU time avg
override_time_unit(
evt.cpu_time_total, evt.cpu_time_total_str, time_unit
), # CPU total
override_time_unit(
evt.cpu_time, evt.cpu_time_str, time_unit
), # CPU time avg
]
if has_device_time:
evt.total_device_percent = _format_time_share(
@ -1084,11 +1107,19 @@ def _build_table(
)
row_values.extend(
[
evt.self_device_time_total_str,
override_time_unit(
evt.self_device_time_total,
evt.self_device_time_total_str,
time_unit,
),
# device time total %
evt.total_device_percent,
evt.device_time_total_str,
evt.device_time_str, # device time avg
override_time_unit(
evt.device_time_total, evt.device_time_total_str, time_unit
),
override_time_unit(
evt.device_time, evt.device_time_str, time_unit
), # device time avg
]
)
if profile_memory:
@ -1141,10 +1172,12 @@ def _build_table(
append(row_format.format(*empty_headers))
append(header_sep)
append(f"Self CPU time total: {_format_time(sum_self_cpu_time_total)}")
append(
f"Self CPU time total: {override_time_unit(sum_self_cpu_time_total, _format_time(sum_self_cpu_time_total), time_unit)}"
)
if has_device_time:
append(
f"Self {use_device.upper() if use_device is not None else 'None'} "
f"time total: {_format_time(sum_self_device_time_total)}"
f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}"
)
return "".join(result)