Make a plugin that serves data for the audio dashboard.

Subsequent changes will make TensorBoard use this audio plugin instead of the previous handlers for audio-related data.

PiperOrigin-RevId: 157673132
This commit is contained in:
A. Unique TensorFlower 2017-05-31 21:13:46 -07:00 committed by TensorFlower Gardener
parent 24623653b2
commit 25bb504ccd
7 changed files with 345 additions and 1 deletions

View File

@ -379,6 +379,7 @@ filegroup(
"//tensorflow/tensorboard/demo:all_files",
"//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:all_files",
"//tensorflow/tensorboard/plugins:all_files",
"//tensorflow/tensorboard/plugins/audio:all_files",
"//tensorflow/tensorboard/plugins/histograms:all_files",
"//tensorflow/tensorboard/plugins/images:all_files",
"//tensorflow/tensorboard/plugins/projector:all_files",

View File

@ -229,6 +229,7 @@ add_python_module("tensorflow/tensorboard")
add_python_module("tensorflow/tensorboard/backend")
add_python_module("tensorflow/tensorboard/backend/event_processing")
add_python_module("tensorflow/tensorboard/plugins")
add_python_module("tensorflow/tensorboard/plugins/audio")
add_python_module("tensorflow/tensorboard/plugins/histograms")
add_python_module("tensorflow/tensorboard/plugins/images")
add_python_module("tensorflow/tensorboard/plugins/projector")

View File

@ -206,10 +206,11 @@ if (tensorflow_BUILD_PYTHON_TESTS)
# Broken TensorBoard tests due to different paths in windows
"${tensorflow_source_dir}/tensorflow/tensorboard/backend/application_test.py"
"${tensorflow_source_dir}/tensorflow/tensorboard/lib/python/http_util_test.py"
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/audio/audio_plugin_test.py"
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/images/images_plugin_test.py"
# Broken tensorboard test due to cmake issues.
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/debugger/plugin_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py"
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/images/images_plugin_test.py"
# tensor_forest tests (also note that we exclude the hybrid tests for now)
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order.
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order.

View File

@ -13,6 +13,7 @@ py_binary(
deps = [
"//tensorflow/tensorboard/backend:application",
"//tensorflow/tensorboard/backend/event_processing:event_file_inspector",
"//tensorflow/tensorboard/plugins/audio:audio_plugin",
"//tensorflow/tensorboard/plugins/histograms:histograms_plugin",
"//tensorflow/tensorboard/plugins/images:images_plugin",
"//tensorflow/tensorboard/plugins/projector:projector_plugin",

View File

@ -0,0 +1,48 @@
# Description:
# TensorBoard plugin for audio
package(default_visibility = ["//tensorflow:internal"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test")
py_library(
name = "audio_plugin",
srcs = ["audio_plugin.py"],
srcs_version = "PY2AND3",
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/tensorboard/backend:http_util",
"//tensorflow/tensorboard/backend/event_processing:event_accumulator",
"//tensorflow/tensorboard/plugins:base_plugin",
"@org_pocoo_werkzeug//:werkzeug",
"@six_archive//:six",
],
)
py_test(
name = "audio_plugin_test",
size = "small",
srcs = ["audio_plugin_test.py"],
srcs_version = "PY2AND3",
deps = [
":audio_plugin",
"//tensorflow:tensorflow_py",
"//tensorflow/tensorboard/backend:application",
"//tensorflow/tensorboard/backend/event_processing:event_multiplexer",
"@org_pocoo_werkzeug//:werkzeug",
"@six_archive//:six",
],
)
filegroup(
name = "all_files",
srcs = glob(["**"]),
visibility = ["//tensorflow:__pkg__"],
)

View File

@ -0,0 +1,135 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TensorBoard Audio plugin."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import urllib
from werkzeug import wrappers
from tensorflow.tensorboard.backend import http_util
from tensorflow.tensorboard.backend.event_processing import event_accumulator
from tensorflow.tensorboard.plugins import base_plugin
_PLUGIN_PREFIX_ROUTE = event_accumulator.AUDIO
class AudioPlugin(base_plugin.TBPlugin):
"""Audio Plugin for TensorBoard."""
plugin_name = _PLUGIN_PREFIX_ROUTE
def get_plugin_apps(self, multiplexer, unused_logdir):
self._multiplexer = multiplexer
return {
'/audio': self._serve_audio_metadata,
'/individualAudio': self._serve_individual_audio,
'/tags': self._serve_tags,
}
def is_active(self):
"""The audio plugin is active iff any run has at least one relevant tag."""
return any(self.index_impl().values())
def _index_impl(self):
return {
run_name: run_data[event_accumulator.AUDIO]
for (run_name, run_data) in self._multiplexer.Runs().items()
if event_accumulator.AUDIO in run_data
}
@wrappers.Request.application
def _serve_audio_metadata(self, request):
"""Given a tag and list of runs, serve a list of metadata for audio.
Note that the audio themselves are not sent; instead, we respond with URLs
to the audio. The frontend should treat these URLs as opaque and should not
try to parse information about them or generate them itself, as the format
may change.
Args:
request: A werkzeug.wrappers.Request object.
Returns:
A werkzeug.Response application.
"""
tag = request.args.get('tag')
run = request.args.get('run')
audio_list = self._multiplexer.Audio(run, tag)
response = self._audio_response_for_run(audio_list, run, tag)
return http_util.Respond(request, response, 'application/json')
def _audio_response_for_run(self, run_audio, run, tag):
"""Builds a JSON-serializable object with information about run_audio.
Args:
run_audio: A list of event_accumulator.AudioValueEvent objects.
run: The name of the run.
tag: The name of the tag the audio entries all belong to.
Returns:
A list of dictionaries containing the wall time, step, URL, width, and
height for each audio entry.
"""
response = []
for index, run_audio_clip in enumerate(run_audio):
response.append({
'wall_time': run_audio_clip.wall_time,
'step': run_audio_clip.step,
'content_type': run_audio_clip.content_type,
'query': self._query_for_individual_audio(run, tag, index)
})
return response
def _query_for_individual_audio(self, run, tag, index):
"""Builds a URL for accessing the specified audio.
This should be kept in sync with _serve_audio_metadata. Note that the URL is
*not* guaranteed to always return the same audio, since audio may be
unloaded from the reservoir as new audio entries come in.
Args:
run: The name of the run.
tag: The tag.
index: The index of the audio entry. Negative values are OK.
Returns:
A string representation of a URL that will load the index-th sampled audio
in the given run with the given tag.
"""
query_string = urllib.parse.urlencode({
'run': run,
'tag': tag,
'index': index
})
return query_string
@wrappers.Request.application
def _serve_individual_audio(self, request):
"""Serves an individual audio entry."""
tag = request.args.get('tag')
run = request.args.get('run')
index = int(request.args.get('index'))
audio = self._multiplexer.Audio(run, tag)[index]
return http_util.Respond(
request, audio.encoded_audio_string, audio.content_type)
@wrappers.Request.application
def _serve_tags(self, request):
index = self._index_impl()
return http_util.Respond(request, index, 'application/json')

View File

@ -0,0 +1,157 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests the Tensorboard audio plugin."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import json
import os
import shutil
import tempfile
import numpy
from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from werkzeug import test as werkzeug_test
from werkzeug import wrappers
from tensorflow.tensorboard.backend import application
from tensorflow.tensorboard.backend.event_processing import event_multiplexer
from tensorflow.tensorboard.plugins.audio import audio_plugin
class AudioPluginTest(tf.test.TestCase):
def setUp(self):
self.log_dir = tempfile.mkdtemp()
# We use numpy.random to generate audio. We seed to avoid non-determinism
# in this test.
numpy.random.seed(42)
# Create audio summaries for run foo.
tf.reset_default_graph()
sess = tf.Session()
placeholder = tf.placeholder(tf.float32)
tf.summary.audio(name="baz", tensor=placeholder, sample_rate=44100)
merged_summary_op = tf.summary.merge_all()
foo_directory = os.path.join(self.log_dir, "foo")
writer = tf.summary.FileWriter(foo_directory)
writer.add_graph(sess.graph)
for step in xrange(2):
# The floats (sample data) range from -1 to 1.
writer.add_summary(sess.run(merged_summary_op, feed_dict={
placeholder: numpy.random.rand(42, 22050) * 2 - 1
}), global_step=step)
writer.close()
# Create audio summaries for run bar.
tf.reset_default_graph()
sess = tf.Session()
placeholder = tf.placeholder(tf.float32)
tf.summary.audio(name="quux", tensor=placeholder, sample_rate=44100)
merged_summary_op = tf.summary.merge_all()
bar_directory = os.path.join(self.log_dir, "bar")
writer = tf.summary.FileWriter(bar_directory)
writer.add_graph(sess.graph)
for step in xrange(2):
# The floats (sample data) range from -1 to 1.
writer.add_summary(sess.run(merged_summary_op, feed_dict={
placeholder: numpy.random.rand(42, 11025) * 2 - 1
}), global_step=step)
writer.close()
# Start a server with the plugin.
multiplexer = event_multiplexer.EventMultiplexer({
"foo": foo_directory,
"bar": bar_directory,
})
plugin = audio_plugin.AudioPlugin()
wsgi_app = application.TensorBoardWSGIApp(
self.log_dir, [plugin], multiplexer, reload_interval=0)
self.server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse)
self.routes = plugin.get_plugin_apps(multiplexer, self.log_dir)
def tearDown(self):
shutil.rmtree(self.log_dir, ignore_errors=True)
def _DeserializeResponse(self, byte_content):
"""Deserializes byte content that is a JSON encoding.
Args:
byte_content: The byte content of a response.
Returns:
The deserialized python object decoded from JSON.
"""
return json.loads(byte_content.decode("utf-8"))
def testRoutesProvided(self):
"""Tests that the plugin offers the correct routes."""
self.assertIsInstance(self.routes["/audio"], collections.Callable)
self.assertIsInstance(self.routes["/individualAudio"], collections.Callable)
self.assertIsInstance(self.routes["/tags"], collections.Callable)
def testAudioRoute(self):
"""Tests that the /audio routes returns with the correct data."""
response = self.server.get(
"/data/plugin/audio/audio?run=foo&tag=baz/audio/0")
self.assertEqual(200, response.status_code)
# Verify that the correct entries are returned.
entries = self._DeserializeResponse(response.get_data())
self.assertEqual(2, len(entries))
# Verify that the 1st entry is correct.
entry = entries[0]
self.assertEqual(0, entry["step"])
parsed_query = urllib.parse.parse_qs(entry["query"])
self.assertListEqual(["0"], parsed_query["index"])
self.assertListEqual(["foo"], parsed_query["run"])
self.assertListEqual(["baz/audio/0"], parsed_query["tag"])
# Verify that the 2nd entry is correct.
entry = entries[1]
self.assertEqual(1, entry["step"])
parsed_query = urllib.parse.parse_qs(entry["query"])
self.assertListEqual(["1"], parsed_query["index"])
self.assertListEqual(["foo"], parsed_query["run"])
self.assertListEqual(["baz/audio/0"], parsed_query["tag"])
def testIndividualAudioRoute(self):
"""Tests fetching an individual audio."""
response = self.server.get(
"/data/plugin/audio/individualAudio?run=bar&tag=quux/audio/0&index=0")
self.assertEqual(200, response.status_code)
self.assertEqual("audio/wav", response.headers.get("content-type"))
def testRunsRoute(self):
"""Tests that the /runs route offers the correct run to tag mapping."""
response = self.server.get("/data/plugin/audio/tags")
self.assertEqual(200, response.status_code)
run_to_tags = self._DeserializeResponse(response.get_data())
self.assertItemsEqual(("foo", "bar"), run_to_tags.keys())
self.assertItemsEqual(
["baz/audio/0", "baz/audio/1", "baz/audio/2"], run_to_tags["foo"])
self.assertItemsEqual(
["quux/audio/0", "quux/audio/1", "quux/audio/2"], run_to_tags["bar"])
if __name__ == "__main__":
tf.test.main()