torch/monitor: TensorboardEventHandler (#71658)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71658

This adds the beginnings of a TensorboardEventHandler which will log stats to Tensorboard.

Test Plan: buck test //caffe2/test:monitor

Reviewed By: edward-io

Differential Revision: D33719954

fbshipit-source-id: e9847c1319255ce0d9cf2d85d8b54b7a3c681bd2
(cherry picked from commit 5c8520a6ba)
This commit is contained in:
Tristan Rice 2022-01-26 23:39:20 -08:00 committed by PyTorch MergeBot
parent d4d0ab71b3
commit 7aa4a1f63e
3 changed files with 92 additions and 1 deletion

View File

@@ -24,7 +24,6 @@ API Reference
-------------
.. automodule:: torch.monitor
:members:
.. autoclass:: torch.monitor.Aggregation
:members:
@@ -55,3 +54,7 @@ API Reference
.. autofunction:: torch.monitor.register_event_handler
.. autofunction:: torch.monitor.unregister_event_handler
.. autoclass:: torch.monitor.TensorboardEventHandler
:members:
:special-members: __init__

View File

@@ -5,6 +5,7 @@ from torch.testing._internal.common_utils import (
)
from datetime import timedelta, datetime
import tempfile
import time
from torch.monitor import (
@@ -16,6 +17,7 @@ from torch.monitor import (
register_event_handler,
unregister_event_handler,
Stat,
TensorboardEventHandler,
)
class TestMonitor(TestCase):
@@ -98,6 +100,60 @@ class TestMonitor(TestCase):
log_event(e)
self.assertEqual(len(events), 2)
class TestMonitorTensorboard(TestCase):
    """End-to-end test of TensorboardEventHandler against a real SummaryWriter."""

    def setUp(self):
        global SummaryWriter, event_multiplexer
        try:
            from torch.utils.tensorboard import SummaryWriter
            from tensorboard.backend.event_processing import (
                plugin_event_multiplexer as event_multiplexer,
            )
        except ImportError:
            return self.skipTest("Skip the test since TensorBoard is not installed")
        self.temp_dirs = []

    def create_summary_writer(self):
        # Track the directory so tearDown can remove it after the writer has
        # flushed its event files into it.
        log_dir = tempfile.TemporaryDirectory()  # noqa: P201
        self.temp_dirs.append(log_dir)
        return SummaryWriter(log_dir.name)

    def tearDown(self):
        # Remove directories created by SummaryWriter
        while self.temp_dirs:
            self.temp_dirs.pop().cleanup()

    def test_event_handler(self):
        with self.create_summary_writer() as writer:
            handle = register_event_handler(TensorboardEventHandler(writer))
            stat = FixedCountStat(
                "asdf",
                (Aggregation.SUM, Aggregation.COUNT),
                2,
            )
            # The window size is 2, so 10 add() calls complete 5 windows —
            # the expected output below has 5 logged points per aggregation.
            for sample in range(10):
                stat.add(sample)
            self.assertEqual(stat.count, 0)
        unregister_event_handler(handle)

        # Read the scalars back out of the event files the writer produced.
        multiplexer = event_multiplexer.EventMultiplexer()
        multiplexer.AddRunsFromDirectory(self.temp_dirs[-1].name)
        multiplexer.Reload()
        scalars = {}
        for run, tag_dict in multiplexer.PluginRunToTagToContent("scalars").items():
            for tag in tag_dict:
                scalars[tag] = [
                    event.tensor_proto.float_val[0]
                    for event in multiplexer.Tensors(run, tag)
                ]
        self.assertEqual(scalars, {
            "asdf.sum": [1, 5, 9, 13, 17],
            "asdf.count": [2, 2, 2, 2, 2],
        })
# Allow running this test file directly (e.g. `python test_monitor.py`).
if __name__ == '__main__':
    run_tests()

View File

@@ -1 +1,33 @@
from torch._C._monitor import * # noqa: F403
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from torch.utils.tensorboard import SummaryWriter
# Name carried by events emitted for ``torch.monitor.Stat`` updates; used
# below to filter events down to the ones logged as scalars.
STAT_EVENT = "torch.monitor.Stat"
class TensorboardEventHandler:
    """
    An event handler that forwards known ``torch.monitor`` events to the
    provided SummaryWriter.

    Currently only ``torch.monitor.Stat`` events are supported; each of their
    aggregated values is written as a scalar.

    >>> from torch.utils.tensorboard import SummaryWriter
    >>> from torch.monitor import TensorboardEventHandler, register_event_handler
    >>> writer = SummaryWriter("log_dir")
    >>> register_event_handler(TensorboardEventHandler(writer))
    """

    def __init__(self, writer: "SummaryWriter") -> None:
        """
        Constructs the ``TensorboardEventHandler`` wrapping ``writer``.
        """
        self._writer = writer

    def __call__(self, event: Event) -> None:
        # Ignore anything that is not a Stat event; no other event type has
        # a scalar representation here.
        if event.name != STAT_EVENT:
            return
        # One scalar per aggregation (e.g. "name.sum"), all stamped with the
        # event's wall-clock timestamp.
        walltime = event.timestamp.timestamp()
        for tag, value in event.data.items():
            self._writer.add_scalar(tag, value, walltime=walltime)