mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Allow Custom Time Unit When Printing Profiler Table (#157913)
## Overview
This PR adds a kwarg to the `table()` method of the profiler, allowing users to specify a single time unit to be used for all results in the profiling table. The available options are: `s`, `ms` and `us`. If an invalid unit or no unit is provided, a time unit is selected based on the size of each value (the current default behaviour).
## Testing
A unit test has been added to verify this works correctly.
## Documentation
I couldn't find any documentation specific to the `table()` function beyond its doc strings, which have been updated accordingly.
## Example Output
```
import torch
from torch.profiler import profile
with profile() as prof:
res = torch.mm(torch.rand(1024, 1024), torch.rand(1024, 1024))
print(prof.key_averages().table(time_unit="s"))
print(prof.key_averages().table(time_unit="ms"))
print(prof.key_averages().table(time_unit="us"))
print(prof.key_averages().table())
```
```
---------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
---------------------- ------------ ------------ ------------ ------------ ------------ ------------
aten::rand 0.04% 0.000s 10.36% 0.014s 0.007s 2
aten::empty 0.04% 0.000s 0.04% 0.000s 0.000s 2
aten::uniform_ 10.27% 0.014s 10.27% 0.014s 0.007s 2
aten::mm 89.64% 0.119s 89.64% 0.119s 0.119s 1
aten::resolve_conj 0.00% 0.000s 0.00% 0.000s 0.000s 3
---------------------- ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 0.133s
---------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
---------------------- ------------ ------------ ------------ ------------ ------------ ------------
aten::rand 0.04% 0.055ms 10.36% 13.735ms 6.868ms 2
aten::empty 0.04% 0.054ms 0.04% 0.054ms 0.027ms 2
aten::uniform_ 10.27% 13.626ms 10.27% 13.626ms 6.813ms 2
aten::mm 89.64% 118.892ms 89.64% 118.896ms 118.896ms 1
aten::resolve_conj 0.00% 0.004ms 0.00% 0.004ms 0.001ms 3
---------------------- ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 132.631ms
---------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
---------------------- ------------ ------------ ------------ ------------ ------------ ------------
aten::rand 0.04% 55.495us 10.36% 13735.202us 6867.601us 2
aten::empty 0.04% 54.121us 0.04% 54.121us 27.061us 2
aten::uniform_ 10.27% 13625.586us 10.27% 13625.586us 6812.793us 2
aten::mm 89.64% 118892.284us 89.64% 118895.981us 118895.981us 1
aten::resolve_conj 0.00% 3.697us 0.00% 3.697us 1.232us 3
---------------------- ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 132631.183us
---------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
---------------------- ------------ ------------ ------------ ------------ ------------ ------------
aten::rand 0.04% 55.495us 10.36% 13.735ms 6.868ms 2
aten::empty 0.04% 54.121us 0.04% 54.121us 27.061us 2
aten::uniform_ 10.27% 13.626ms 10.27% 13.626ms 6.813ms 2
aten::mm 89.64% 118.892ms 89.64% 118.896ms 118.896ms 1
aten::resolve_conj 0.00% 3.697us 0.00% 3.697us 1.232us 3
---------------------- ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 132.631ms
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157913
Approved by: https://github.com/sraikund16
This commit is contained in:
parent
83700b4488
commit
9bf41633d7
|
|
@ -985,6 +985,50 @@ class TestProfiler(TestCase):
|
|||
)
|
||||
self.assertIn("Total MFLOPs", profiler_output)
|
||||
|
||||
def test_override_time_units(self):
    """Verify that table(time_unit=...) forces every formatted time into
    the requested unit, and that no other unit suffix leaks into the output.

    The same checks are run for each supported unit ("s", "ms", "us"):
      1. regex check: values formatted as `<digits>.<3 digits><unit>` appear
         only for the requested unit;
      2. exact-value check: each event's cpu_time / cpu_time_total, converted
         to the requested unit, appears verbatim in the table.
    """
    # Microseconds-per-unit divisors; profiler event times are stored in us.
    US_PER_UNIT = {"s": 1000.0 * 1000.0, "ms": 1000.0, "us": 1.0}

    model = torch.nn.Sequential(
        nn.Conv2d(16, 33, 18),
        nn.ReLU(),
        nn.Linear(243, 243),
        nn.ReLU(),
    )
    inputs = torch.randn(40, 16, 18, 260)
    with _profile() as prof:
        model(inputs)

    for unit, divisor in US_PER_UNIT.items():
        profiler_output = prof.key_averages().table(time_unit=unit)

        # Only the requested unit's suffix may follow a formatted time.
        for candidate in ("s", "ms", "us"):
            pattern = r".*(\.[0-9]{3}%s).*" % candidate
            if candidate == unit:
                self.assertRegex(profiler_output, pattern)
            else:
                self.assertNotRegex(profiler_output, pattern)

        # Each event's times, converted to the requested unit, must appear
        # verbatim in the table. assertIn gives a useful diff on failure.
        for event in prof.key_averages():
            self.assertIn(f"{event.cpu_time / divisor:.3f}{unit}", profiler_output)
            self.assertIn(
                f"{event.cpu_time_total / divisor:.3f}{unit}", profiler_output
            )
|
||||
|
||||
@patch.dict(os.environ, {"KINETO_USE_DAEMON": "1"})
|
||||
@patch.dict(os.environ, {"KINETO_DAEMON_INIT_DELAY_S": "1"})
|
||||
def test_kineto_profiler_api(self):
|
||||
|
|
|
|||
|
|
@ -173,6 +173,7 @@ class EventList(list):
|
|||
max_shapes_column_width=80,
|
||||
header=None,
|
||||
top_level_events_only=False,
|
||||
time_unit=None,
|
||||
):
|
||||
"""Print an EventList as a nicely formatted table.
|
||||
|
||||
|
|
@ -189,6 +190,8 @@ class EventList(list):
|
|||
display events at top level like top-level invocation of python
|
||||
`lstm`, python `add` or other functions, nested events like low-level
|
||||
cpu/cuda/xpu ops events are omitted for profiler result readability.
|
||||
time_unit(str, optional): A time unit to be used for all values in the
|
||||
table. Valid options are: ``s``, ``ms`` and ``us``.
|
||||
|
||||
Returns:
|
||||
A string containing the table.
|
||||
|
|
@ -204,6 +207,7 @@ class EventList(list):
|
|||
profile_memory=self._profile_memory,
|
||||
with_flops=self._with_flops,
|
||||
top_level_events_only=top_level_events_only,
|
||||
time_unit=time_unit,
|
||||
)
|
||||
|
||||
def export_chrome_trace(self, path):
|
||||
|
|
@ -832,6 +836,7 @@ def _build_table(
|
|||
with_flops=False,
|
||||
profile_memory=False,
|
||||
top_level_events_only=False,
|
||||
time_unit=None,
|
||||
):
|
||||
"""Print a summary of events (which can be a list of FunctionEvent or FunctionEventAvg)."""
|
||||
if len(events) == 0:
|
||||
|
|
@ -1039,6 +1044,18 @@ def _build_table(
|
|||
path = "..." + path[3:]
|
||||
return path
|
||||
|
||||
def override_time_unit(time_us, default_str, time_unit):
    """Format ``time_us`` (a duration in microseconds) using ``time_unit``.

    Supported units are "s", "ms" and "us"; any other value (including
    None) yields ``default_str`` unchanged, i.e. the auto-scaled string
    the caller already computed.
    """
    # Unit suffix -> microseconds per unit.
    _DIVISOR_BY_UNIT = {"s": 1000.0 * 1000.0, "ms": 1000.0, "us": 1.0}
    divisor = _DIVISOR_BY_UNIT.get(time_unit)
    if divisor is None:
        return default_str
    return f"{time_us / divisor:.3f}{time_unit}"
|
||||
|
||||
event_limit = 0
|
||||
for evt in events:
|
||||
if event_limit == row_limit:
|
||||
|
|
@ -1072,11 +1089,17 @@ def _build_table(
|
|||
row_values += [
|
||||
# Self CPU total %, 0 for async events.
|
||||
evt.self_cpu_percent,
|
||||
evt.self_cpu_time_total_str, # Self CPU total
|
||||
override_time_unit(
|
||||
evt.self_cpu_time_total, evt.self_cpu_time_total_str, time_unit
|
||||
), # Self CPU total
|
||||
# CPU total %, 0 for async events.
|
||||
evt.total_cpu_percent,
|
||||
evt.cpu_time_total_str, # CPU total
|
||||
evt.cpu_time_str, # CPU time avg
|
||||
override_time_unit(
|
||||
evt.cpu_time_total, evt.cpu_time_total_str, time_unit
|
||||
), # CPU total
|
||||
override_time_unit(
|
||||
evt.cpu_time, evt.cpu_time_str, time_unit
|
||||
), # CPU time avg
|
||||
]
|
||||
if has_device_time:
|
||||
evt.total_device_percent = _format_time_share(
|
||||
|
|
@ -1084,11 +1107,19 @@ def _build_table(
|
|||
)
|
||||
row_values.extend(
|
||||
[
|
||||
evt.self_device_time_total_str,
|
||||
override_time_unit(
|
||||
evt.self_device_time_total,
|
||||
evt.self_device_time_total_str,
|
||||
time_unit,
|
||||
),
|
||||
# device time total %
|
||||
evt.total_device_percent,
|
||||
evt.device_time_total_str,
|
||||
evt.device_time_str, # device time avg
|
||||
override_time_unit(
|
||||
evt.device_time_total, evt.device_time_total_str, time_unit
|
||||
),
|
||||
override_time_unit(
|
||||
evt.device_time, evt.device_time_str, time_unit
|
||||
), # device time avg
|
||||
]
|
||||
)
|
||||
if profile_memory:
|
||||
|
|
@ -1141,10 +1172,12 @@ def _build_table(
|
|||
append(row_format.format(*empty_headers))
|
||||
|
||||
append(header_sep)
|
||||
append(f"Self CPU time total: {_format_time(sum_self_cpu_time_total)}")
|
||||
append(
|
||||
f"Self CPU time total: {override_time_unit(sum_self_cpu_time_total, _format_time(sum_self_cpu_time_total), time_unit)}"
|
||||
)
|
||||
if has_device_time:
|
||||
append(
|
||||
f"Self {use_device.upper() if use_device is not None else 'None'} "
|
||||
f"time total: {_format_time(sum_self_device_time_total)}"
|
||||
f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}"
|
||||
)
|
||||
return "".join(result)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user