mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Make a plugin that serves data for the audio dashboard.
Subsequent changes will make TensorBoard use this audio plugin instead of the previous handlers for audio-related data. PiperOrigin-RevId: 157673132
This commit is contained in:
parent
24623653b2
commit
25bb504ccd
|
|
@ -379,6 +379,7 @@ filegroup(
|
||||||
"//tensorflow/tensorboard/demo:all_files",
|
"//tensorflow/tensorboard/demo:all_files",
|
||||||
"//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:all_files",
|
"//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:all_files",
|
||||||
"//tensorflow/tensorboard/plugins:all_files",
|
"//tensorflow/tensorboard/plugins:all_files",
|
||||||
|
"//tensorflow/tensorboard/plugins/audio:all_files",
|
||||||
"//tensorflow/tensorboard/plugins/histograms:all_files",
|
"//tensorflow/tensorboard/plugins/histograms:all_files",
|
||||||
"//tensorflow/tensorboard/plugins/images:all_files",
|
"//tensorflow/tensorboard/plugins/images:all_files",
|
||||||
"//tensorflow/tensorboard/plugins/projector:all_files",
|
"//tensorflow/tensorboard/plugins/projector:all_files",
|
||||||
|
|
|
||||||
|
|
@ -229,6 +229,7 @@ add_python_module("tensorflow/tensorboard")
|
||||||
add_python_module("tensorflow/tensorboard/backend")
|
add_python_module("tensorflow/tensorboard/backend")
|
||||||
add_python_module("tensorflow/tensorboard/backend/event_processing")
|
add_python_module("tensorflow/tensorboard/backend/event_processing")
|
||||||
add_python_module("tensorflow/tensorboard/plugins")
|
add_python_module("tensorflow/tensorboard/plugins")
|
||||||
|
add_python_module("tensorflow/tensorboard/plugins/audio")
|
||||||
add_python_module("tensorflow/tensorboard/plugins/histograms")
|
add_python_module("tensorflow/tensorboard/plugins/histograms")
|
||||||
add_python_module("tensorflow/tensorboard/plugins/images")
|
add_python_module("tensorflow/tensorboard/plugins/images")
|
||||||
add_python_module("tensorflow/tensorboard/plugins/projector")
|
add_python_module("tensorflow/tensorboard/plugins/projector")
|
||||||
|
|
|
||||||
|
|
@ -206,10 +206,11 @@ if (tensorflow_BUILD_PYTHON_TESTS)
|
||||||
# Broken TensorBoard tests due to different paths in windows
|
# Broken TensorBoard tests due to different paths in windows
|
||||||
"${tensorflow_source_dir}/tensorflow/tensorboard/backend/application_test.py"
|
"${tensorflow_source_dir}/tensorflow/tensorboard/backend/application_test.py"
|
||||||
"${tensorflow_source_dir}/tensorflow/tensorboard/lib/python/http_util_test.py"
|
"${tensorflow_source_dir}/tensorflow/tensorboard/lib/python/http_util_test.py"
|
||||||
|
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/audio/audio_plugin_test.py"
|
||||||
|
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/images/images_plugin_test.py"
|
||||||
# Broken tensorboard test due to cmake issues.
|
# Broken tensorboard test due to cmake issues.
|
||||||
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/debugger/plugin_test.py"
|
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/debugger/plugin_test.py"
|
||||||
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py"
|
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py"
|
||||||
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/images/images_plugin_test.py"
|
|
||||||
# tensor_forest tests (also note that we exclude the hybrid tests for now)
|
# tensor_forest tests (also note that we exclude the hybrid tests for now)
|
||||||
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order.
|
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order.
|
||||||
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order.
|
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order.
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ py_binary(
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/tensorboard/backend:application",
|
"//tensorflow/tensorboard/backend:application",
|
||||||
"//tensorflow/tensorboard/backend/event_processing:event_file_inspector",
|
"//tensorflow/tensorboard/backend/event_processing:event_file_inspector",
|
||||||
|
"//tensorflow/tensorboard/plugins/audio:audio_plugin",
|
||||||
"//tensorflow/tensorboard/plugins/histograms:histograms_plugin",
|
"//tensorflow/tensorboard/plugins/histograms:histograms_plugin",
|
||||||
"//tensorflow/tensorboard/plugins/images:images_plugin",
|
"//tensorflow/tensorboard/plugins/images:images_plugin",
|
||||||
"//tensorflow/tensorboard/plugins/projector:projector_plugin",
|
"//tensorflow/tensorboard/plugins/projector:projector_plugin",
|
||||||
|
|
|
||||||
48
tensorflow/tensorboard/plugins/audio/BUILD
Normal file
48
tensorflow/tensorboard/plugins/audio/BUILD
Normal file
|
|
@ -0,0 +1,48 @@
|
||||||
|
# Description:
|
||||||
|
# TensorBoard plugin for audio
|
||||||
|
|
||||||
|
package(default_visibility = ["//tensorflow:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
|
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "audio_plugin",
|
||||||
|
srcs = ["audio_plugin.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow:internal",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
|
"//tensorflow/tensorboard/backend:http_util",
|
||||||
|
"//tensorflow/tensorboard/backend/event_processing:event_accumulator",
|
||||||
|
"//tensorflow/tensorboard/plugins:base_plugin",
|
||||||
|
"@org_pocoo_werkzeug//:werkzeug",
|
||||||
|
"@six_archive//:six",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "audio_plugin_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["audio_plugin_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":audio_plugin",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
|
"//tensorflow/tensorboard/backend:application",
|
||||||
|
"//tensorflow/tensorboard/backend/event_processing:event_multiplexer",
|
||||||
|
"@org_pocoo_werkzeug//:werkzeug",
|
||||||
|
"@six_archive//:six",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "all_files",
|
||||||
|
srcs = glob(["**"]),
|
||||||
|
visibility = ["//tensorflow:__pkg__"],
|
||||||
|
)
|
||||||
135
tensorflow/tensorboard/plugins/audio/audio_plugin.py
Normal file
135
tensorflow/tensorboard/plugins/audio/audio_plugin.py
Normal file
|
|
@ -0,0 +1,135 @@
|
||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""The TensorBoard Audio plugin."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from six.moves import urllib
|
||||||
|
from werkzeug import wrappers
|
||||||
|
|
||||||
|
from tensorflow.tensorboard.backend import http_util
|
||||||
|
from tensorflow.tensorboard.backend.event_processing import event_accumulator
|
||||||
|
from tensorflow.tensorboard.plugins import base_plugin
|
||||||
|
|
||||||
|
_PLUGIN_PREFIX_ROUTE = event_accumulator.AUDIO
|
||||||
|
|
||||||
|
|
||||||
|
class AudioPlugin(base_plugin.TBPlugin):
|
||||||
|
"""Audio Plugin for TensorBoard."""
|
||||||
|
|
||||||
|
plugin_name = _PLUGIN_PREFIX_ROUTE
|
||||||
|
|
||||||
|
def get_plugin_apps(self, multiplexer, unused_logdir):
|
||||||
|
self._multiplexer = multiplexer
|
||||||
|
return {
|
||||||
|
'/audio': self._serve_audio_metadata,
|
||||||
|
'/individualAudio': self._serve_individual_audio,
|
||||||
|
'/tags': self._serve_tags,
|
||||||
|
}
|
||||||
|
|
||||||
|
def is_active(self):
|
||||||
|
"""The audio plugin is active iff any run has at least one relevant tag."""
|
||||||
|
return any(self.index_impl().values())
|
||||||
|
|
||||||
|
def _index_impl(self):
|
||||||
|
return {
|
||||||
|
run_name: run_data[event_accumulator.AUDIO]
|
||||||
|
for (run_name, run_data) in self._multiplexer.Runs().items()
|
||||||
|
if event_accumulator.AUDIO in run_data
|
||||||
|
}
|
||||||
|
|
||||||
|
@wrappers.Request.application
|
||||||
|
def _serve_audio_metadata(self, request):
|
||||||
|
"""Given a tag and list of runs, serve a list of metadata for audio.
|
||||||
|
|
||||||
|
Note that the audio themselves are not sent; instead, we respond with URLs
|
||||||
|
to the audio. The frontend should treat these URLs as opaque and should not
|
||||||
|
try to parse information about them or generate them itself, as the format
|
||||||
|
may change.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: A werkzeug.wrappers.Request object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A werkzeug.Response application.
|
||||||
|
"""
|
||||||
|
tag = request.args.get('tag')
|
||||||
|
run = request.args.get('run')
|
||||||
|
|
||||||
|
audio_list = self._multiplexer.Audio(run, tag)
|
||||||
|
response = self._audio_response_for_run(audio_list, run, tag)
|
||||||
|
return http_util.Respond(request, response, 'application/json')
|
||||||
|
|
||||||
|
def _audio_response_for_run(self, run_audio, run, tag):
|
||||||
|
"""Builds a JSON-serializable object with information about run_audio.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_audio: A list of event_accumulator.AudioValueEvent objects.
|
||||||
|
run: The name of the run.
|
||||||
|
tag: The name of the tag the audio entries all belong to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of dictionaries containing the wall time, step, URL, width, and
|
||||||
|
height for each audio entry.
|
||||||
|
"""
|
||||||
|
response = []
|
||||||
|
for index, run_audio_clip in enumerate(run_audio):
|
||||||
|
response.append({
|
||||||
|
'wall_time': run_audio_clip.wall_time,
|
||||||
|
'step': run_audio_clip.step,
|
||||||
|
'content_type': run_audio_clip.content_type,
|
||||||
|
'query': self._query_for_individual_audio(run, tag, index)
|
||||||
|
})
|
||||||
|
return response
|
||||||
|
|
||||||
|
def _query_for_individual_audio(self, run, tag, index):
|
||||||
|
"""Builds a URL for accessing the specified audio.
|
||||||
|
|
||||||
|
This should be kept in sync with _serve_audio_metadata. Note that the URL is
|
||||||
|
*not* guaranteed to always return the same audio, since audio may be
|
||||||
|
unloaded from the reservoir as new audio entries come in.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run: The name of the run.
|
||||||
|
tag: The tag.
|
||||||
|
index: The index of the audio entry. Negative values are OK.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A string representation of a URL that will load the index-th sampled audio
|
||||||
|
in the given run with the given tag.
|
||||||
|
"""
|
||||||
|
query_string = urllib.parse.urlencode({
|
||||||
|
'run': run,
|
||||||
|
'tag': tag,
|
||||||
|
'index': index
|
||||||
|
})
|
||||||
|
return query_string
|
||||||
|
|
||||||
|
@wrappers.Request.application
|
||||||
|
def _serve_individual_audio(self, request):
|
||||||
|
"""Serves an individual audio entry."""
|
||||||
|
tag = request.args.get('tag')
|
||||||
|
run = request.args.get('run')
|
||||||
|
index = int(request.args.get('index'))
|
||||||
|
audio = self._multiplexer.Audio(run, tag)[index]
|
||||||
|
return http_util.Respond(
|
||||||
|
request, audio.encoded_audio_string, audio.content_type)
|
||||||
|
|
||||||
|
@wrappers.Request.application
|
||||||
|
def _serve_tags(self, request):
|
||||||
|
index = self._index_impl()
|
||||||
|
return http_util.Respond(request, index, 'application/json')
|
||||||
157
tensorflow/tensorboard/plugins/audio/audio_plugin_test.py
Normal file
157
tensorflow/tensorboard/plugins/audio/audio_plugin_test.py
Normal file
|
|
@ -0,0 +1,157 @@
|
||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests the Tensorboard audio plugin."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import numpy
|
||||||
|
from six.moves import urllib
|
||||||
|
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||||
|
import tensorflow as tf
|
||||||
|
from werkzeug import test as werkzeug_test
|
||||||
|
from werkzeug import wrappers
|
||||||
|
|
||||||
|
from tensorflow.tensorboard.backend import application
|
||||||
|
from tensorflow.tensorboard.backend.event_processing import event_multiplexer
|
||||||
|
from tensorflow.tensorboard.plugins.audio import audio_plugin
|
||||||
|
|
||||||
|
|
||||||
|
class AudioPluginTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.log_dir = tempfile.mkdtemp()
|
||||||
|
|
||||||
|
# We use numpy.random to generate audio. We seed to avoid non-determinism
|
||||||
|
# in this test.
|
||||||
|
numpy.random.seed(42)
|
||||||
|
|
||||||
|
# Create audio summaries for run foo.
|
||||||
|
tf.reset_default_graph()
|
||||||
|
sess = tf.Session()
|
||||||
|
placeholder = tf.placeholder(tf.float32)
|
||||||
|
tf.summary.audio(name="baz", tensor=placeholder, sample_rate=44100)
|
||||||
|
merged_summary_op = tf.summary.merge_all()
|
||||||
|
foo_directory = os.path.join(self.log_dir, "foo")
|
||||||
|
writer = tf.summary.FileWriter(foo_directory)
|
||||||
|
writer.add_graph(sess.graph)
|
||||||
|
for step in xrange(2):
|
||||||
|
# The floats (sample data) range from -1 to 1.
|
||||||
|
writer.add_summary(sess.run(merged_summary_op, feed_dict={
|
||||||
|
placeholder: numpy.random.rand(42, 22050) * 2 - 1
|
||||||
|
}), global_step=step)
|
||||||
|
writer.close()
|
||||||
|
|
||||||
|
# Create audio summaries for run bar.
|
||||||
|
tf.reset_default_graph()
|
||||||
|
sess = tf.Session()
|
||||||
|
placeholder = tf.placeholder(tf.float32)
|
||||||
|
tf.summary.audio(name="quux", tensor=placeholder, sample_rate=44100)
|
||||||
|
merged_summary_op = tf.summary.merge_all()
|
||||||
|
bar_directory = os.path.join(self.log_dir, "bar")
|
||||||
|
writer = tf.summary.FileWriter(bar_directory)
|
||||||
|
writer.add_graph(sess.graph)
|
||||||
|
for step in xrange(2):
|
||||||
|
# The floats (sample data) range from -1 to 1.
|
||||||
|
writer.add_summary(sess.run(merged_summary_op, feed_dict={
|
||||||
|
placeholder: numpy.random.rand(42, 11025) * 2 - 1
|
||||||
|
}), global_step=step)
|
||||||
|
writer.close()
|
||||||
|
|
||||||
|
# Start a server with the plugin.
|
||||||
|
multiplexer = event_multiplexer.EventMultiplexer({
|
||||||
|
"foo": foo_directory,
|
||||||
|
"bar": bar_directory,
|
||||||
|
})
|
||||||
|
plugin = audio_plugin.AudioPlugin()
|
||||||
|
wsgi_app = application.TensorBoardWSGIApp(
|
||||||
|
self.log_dir, [plugin], multiplexer, reload_interval=0)
|
||||||
|
self.server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse)
|
||||||
|
self.routes = plugin.get_plugin_apps(multiplexer, self.log_dir)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
shutil.rmtree(self.log_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
def _DeserializeResponse(self, byte_content):
|
||||||
|
"""Deserializes byte content that is a JSON encoding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
byte_content: The byte content of a response.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The deserialized python object decoded from JSON.
|
||||||
|
"""
|
||||||
|
return json.loads(byte_content.decode("utf-8"))
|
||||||
|
|
||||||
|
def testRoutesProvided(self):
|
||||||
|
"""Tests that the plugin offers the correct routes."""
|
||||||
|
self.assertIsInstance(self.routes["/audio"], collections.Callable)
|
||||||
|
self.assertIsInstance(self.routes["/individualAudio"], collections.Callable)
|
||||||
|
self.assertIsInstance(self.routes["/tags"], collections.Callable)
|
||||||
|
|
||||||
|
def testAudioRoute(self):
|
||||||
|
"""Tests that the /audio routes returns with the correct data."""
|
||||||
|
response = self.server.get(
|
||||||
|
"/data/plugin/audio/audio?run=foo&tag=baz/audio/0")
|
||||||
|
self.assertEqual(200, response.status_code)
|
||||||
|
|
||||||
|
# Verify that the correct entries are returned.
|
||||||
|
entries = self._DeserializeResponse(response.get_data())
|
||||||
|
self.assertEqual(2, len(entries))
|
||||||
|
|
||||||
|
# Verify that the 1st entry is correct.
|
||||||
|
entry = entries[0]
|
||||||
|
self.assertEqual(0, entry["step"])
|
||||||
|
parsed_query = urllib.parse.parse_qs(entry["query"])
|
||||||
|
self.assertListEqual(["0"], parsed_query["index"])
|
||||||
|
self.assertListEqual(["foo"], parsed_query["run"])
|
||||||
|
self.assertListEqual(["baz/audio/0"], parsed_query["tag"])
|
||||||
|
|
||||||
|
# Verify that the 2nd entry is correct.
|
||||||
|
entry = entries[1]
|
||||||
|
self.assertEqual(1, entry["step"])
|
||||||
|
parsed_query = urllib.parse.parse_qs(entry["query"])
|
||||||
|
self.assertListEqual(["1"], parsed_query["index"])
|
||||||
|
self.assertListEqual(["foo"], parsed_query["run"])
|
||||||
|
self.assertListEqual(["baz/audio/0"], parsed_query["tag"])
|
||||||
|
|
||||||
|
def testIndividualAudioRoute(self):
|
||||||
|
"""Tests fetching an individual audio."""
|
||||||
|
response = self.server.get(
|
||||||
|
"/data/plugin/audio/individualAudio?run=bar&tag=quux/audio/0&index=0")
|
||||||
|
self.assertEqual(200, response.status_code)
|
||||||
|
self.assertEqual("audio/wav", response.headers.get("content-type"))
|
||||||
|
|
||||||
|
def testRunsRoute(self):
|
||||||
|
"""Tests that the /runs route offers the correct run to tag mapping."""
|
||||||
|
response = self.server.get("/data/plugin/audio/tags")
|
||||||
|
self.assertEqual(200, response.status_code)
|
||||||
|
run_to_tags = self._DeserializeResponse(response.get_data())
|
||||||
|
self.assertItemsEqual(("foo", "bar"), run_to_tags.keys())
|
||||||
|
self.assertItemsEqual(
|
||||||
|
["baz/audio/0", "baz/audio/1", "baz/audio/2"], run_to_tags["foo"])
|
||||||
|
self.assertItemsEqual(
|
||||||
|
["quux/audio/0", "quux/audio/1", "quux/audio/2"], run_to_tags["bar"])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.test.main()
|
||||||
Loading…
Reference in New Issue
Block a user