torch/monitor: TensorboardEventHandler (#71658)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71658
This adds the beginnings of a TensorboardEventHandler which will log stats to Tensorboard.
Test Plan: buck test //caffe2/test:monitor
Reviewed By: edward-io
Differential Revision: D33719954
fbshipit-source-id: e9847c1319255ce0d9cf2d85d8b54b7a3c681bd2
(cherry picked from commit 5c8520a6ba)
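For context, a minimal usage sketch of the handler added here (hedged: the log directory "runs/monitor" and the stat name "my_stat" are illustrative, and TensorBoard must be installed). Registering the handler writes each flushed Stat window to the SummaryWriter as scalars, which the new test below verifies.

    # Sketch only: illustrative names, assumes torch.utils.tensorboard is available.
    from torch.utils.tensorboard import SummaryWriter
    from torch.monitor import (
        Aggregation,
        FixedCountStat,
        TensorboardEventHandler,
        register_event_handler,
        unregister_event_handler,
    )

    writer = SummaryWriter("runs/monitor")  # illustrative log directory
    handle = register_event_handler(TensorboardEventHandler(writer))

    # Flushes every 2 samples; each flush is written as the scalars
    # "my_stat.sum" and "my_stat.count".
    stat = FixedCountStat("my_stat", (Aggregation.SUM, Aggregation.COUNT), 2)
    for i in range(10):
        stat.add(float(i))

    unregister_event_handler(handle)
    writer.close()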
This commit is contained in:
parent d4d0ab71b3
commit 7aa4a1f63e
@@ -24,7 +24,6 @@ API Reference
-------------

.. automodule:: torch.monitor
    :members:

.. autoclass:: torch.monitor.Aggregation
    :members:
@@ -55,3 +54,7 @@ API Reference

.. autofunction:: torch.monitor.register_event_handler

.. autofunction:: torch.monitor.unregister_event_handler

.. autoclass:: torch.monitor.TensorboardEventHandler
    :members:
    :special-members: __init__
@@ -5,6 +5,7 @@ from torch.testing._internal.common_utils import (
)

from datetime import timedelta, datetime
import tempfile
import time

from torch.monitor import (
@@ -16,6 +17,7 @@ from torch.monitor import (
    register_event_handler,
    unregister_event_handler,
    Stat,
    TensorboardEventHandler,
)

class TestMonitor(TestCase):
@@ -98,6 +100,60 @@ class TestMonitor(TestCase):
            log_event(e)
        self.assertEqual(len(events), 2)


class TestMonitorTensorboard(TestCase):
    def setUp(self):
        global SummaryWriter, event_multiplexer
        try:
            from torch.utils.tensorboard import SummaryWriter
            from tensorboard.backend.event_processing import (
                plugin_event_multiplexer as event_multiplexer,
            )
        except ImportError:
            return self.skipTest("Skip the test since TensorBoard is not installed")
        self.temp_dirs = []

    def create_summary_writer(self):
        temp_dir = tempfile.TemporaryDirectory()  # noqa: P201
        self.temp_dirs.append(temp_dir)
        return SummaryWriter(temp_dir.name)

    def tearDown(self):
        # Remove directories created by SummaryWriter
        for temp_dir in self.temp_dirs:
            temp_dir.cleanup()

    def test_event_handler(self):
        with self.create_summary_writer() as w:
            handle = register_event_handler(TensorboardEventHandler(w))

            s = FixedCountStat(
                "asdf",
                (Aggregation.SUM, Aggregation.COUNT),
                2,
            )
            for i in range(10):
                s.add(i)
            self.assertEqual(s.count, 0)

        unregister_event_handler(handle)

        mul = event_multiplexer.EventMultiplexer()
        mul.AddRunsFromDirectory(self.temp_dirs[-1].name)
        mul.Reload()
        scalar_dict = mul.PluginRunToTagToContent("scalars")
        raw_result = {
            tag: mul.Tensors(run, tag)
            for run, run_dict in scalar_dict.items()
            for tag in run_dict
        }
        scalars = {
            tag: [e.tensor_proto.float_val[0] for e in events]
            for tag, events in raw_result.items()
        }
        self.assertEqual(scalars, {
            "asdf.sum": [1, 5, 9, 13, 17],
            "asdf.count": [2, 2, 2, 2, 2],
        })


if __name__ == '__main__':
    run_tests()
@@ -1 +1,33 @@
from torch._C._monitor import *  # noqa: F403

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from torch.utils.tensorboard import SummaryWriter


STAT_EVENT = "torch.monitor.Stat"

class TensorboardEventHandler:
    """
    TensorboardEventHandler is an event handler that will write known events to
    the provided SummaryWriter.

    This currently only supports ``torch.monitor.Stat`` events which are logged
    as scalars.

    >>> from torch.utils.tensorboard import SummaryWriter
    >>> from torch.monitor import TensorboardEventHandler, register_event_handler
    >>> writer = SummaryWriter("log_dir")
    >>> register_event_handler(TensorboardEventHandler(writer))
    """
    def __init__(self, writer: "SummaryWriter") -> None:
        """
        Constructs the ``TensorboardEventHandler``.
        """
        self._writer = writer

    def __call__(self, event: Event) -> None:
        if event.name == STAT_EVENT:
            for k, v in event.data.items():
                self._writer.add_scalar(k, v, walltime=event.timestamp.timestamp())
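As a follow-up note on the design: the handler dispatches purely on ``event.name``, so any event logged under the ``torch.monitor.Stat`` name is written, whether it came from a Stat window flush or from a direct ``log_event`` call. A hedged sketch (the keyword-style ``Event`` constructor is assumed from the existing tests; names and values are illustrative):

    # Sketch only: manually logging a Stat-shaped event through the handler.
    from datetime import datetime

    from torch.utils.tensorboard import SummaryWriter
    from torch.monitor import (
        Event,
        TensorboardEventHandler,
        log_event,
        register_event_handler,
    )

    writer = SummaryWriter("runs/monitor")  # illustrative log directory
    register_event_handler(TensorboardEventHandler(writer))

    event = Event(
        name="torch.monitor.Stat",  # matches STAT_EVENT, so the handler writes it
        timestamp=datetime.now(),   # becomes the scalar's walltime
        data={"my_stat.sum": 10.0, "my_stat.count": 2},  # illustrative entries
    )
    log_event(event)  # handler calls writer.add_scalar for each data entry
    writer.close()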