mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Make a plugin that serves data for the audio dashboard.
Subsequent changes will make TensorBoard use this audio plugin instead of the previous handlers for audio-related data. PiperOrigin-RevId: 157673132
This commit is contained in:
parent
24623653b2
commit
25bb504ccd
|
|
@ -379,6 +379,7 @@ filegroup(
|
|||
"//tensorflow/tensorboard/demo:all_files",
|
||||
"//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:all_files",
|
||||
"//tensorflow/tensorboard/plugins:all_files",
|
||||
"//tensorflow/tensorboard/plugins/audio:all_files",
|
||||
"//tensorflow/tensorboard/plugins/histograms:all_files",
|
||||
"//tensorflow/tensorboard/plugins/images:all_files",
|
||||
"//tensorflow/tensorboard/plugins/projector:all_files",
|
||||
|
|
|
|||
|
|
@ -229,6 +229,7 @@ add_python_module("tensorflow/tensorboard")
|
|||
add_python_module("tensorflow/tensorboard/backend")
|
||||
add_python_module("tensorflow/tensorboard/backend/event_processing")
|
||||
add_python_module("tensorflow/tensorboard/plugins")
|
||||
add_python_module("tensorflow/tensorboard/plugins/audio")
|
||||
add_python_module("tensorflow/tensorboard/plugins/histograms")
|
||||
add_python_module("tensorflow/tensorboard/plugins/images")
|
||||
add_python_module("tensorflow/tensorboard/plugins/projector")
|
||||
|
|
|
|||
|
|
@ -206,10 +206,11 @@ if (tensorflow_BUILD_PYTHON_TESTS)
|
|||
# Broken TensorBoard tests due to different paths in windows
|
||||
"${tensorflow_source_dir}/tensorflow/tensorboard/backend/application_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/tensorboard/lib/python/http_util_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/audio/audio_plugin_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/images/images_plugin_test.py"
|
||||
# Broken tensorboard test due to cmake issues.
|
||||
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/debugger/plugin_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/images/images_plugin_test.py"
|
||||
# tensor_forest tests (also note that we exclude the hybrid tests for now)
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order.
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order.
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ py_binary(
|
|||
deps = [
|
||||
"//tensorflow/tensorboard/backend:application",
|
||||
"//tensorflow/tensorboard/backend/event_processing:event_file_inspector",
|
||||
"//tensorflow/tensorboard/plugins/audio:audio_plugin",
|
||||
"//tensorflow/tensorboard/plugins/histograms:histograms_plugin",
|
||||
"//tensorflow/tensorboard/plugins/images:images_plugin",
|
||||
"//tensorflow/tensorboard/plugins/projector:projector_plugin",
|
||||
|
|
|
|||
48
tensorflow/tensorboard/plugins/audio/BUILD
Normal file
48
tensorflow/tensorboard/plugins/audio/BUILD
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
# Description:
|
||||
# TensorBoard plugin for audio
|
||||
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
|
||||
py_library(
|
||||
name = "audio_plugin",
|
||||
srcs = ["audio_plugin.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/tensorboard/backend:http_util",
|
||||
"//tensorflow/tensorboard/backend/event_processing:event_accumulator",
|
||||
"//tensorflow/tensorboard/plugins:base_plugin",
|
||||
"@org_pocoo_werkzeug//:werkzeug",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "audio_plugin_test",
|
||||
size = "small",
|
||||
srcs = ["audio_plugin_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":audio_plugin",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/tensorboard/backend:application",
|
||||
"//tensorflow/tensorboard/backend/event_processing:event_multiplexer",
|
||||
"@org_pocoo_werkzeug//:werkzeug",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(["**"]),
|
||||
visibility = ["//tensorflow:__pkg__"],
|
||||
)
|
||||
135
tensorflow/tensorboard/plugins/audio/audio_plugin.py
Normal file
135
tensorflow/tensorboard/plugins/audio/audio_plugin.py
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""The TensorBoard Audio plugin."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from six.moves import urllib
|
||||
from werkzeug import wrappers
|
||||
|
||||
from tensorflow.tensorboard.backend import http_util
|
||||
from tensorflow.tensorboard.backend.event_processing import event_accumulator
|
||||
from tensorflow.tensorboard.plugins import base_plugin
|
||||
|
||||
_PLUGIN_PREFIX_ROUTE = event_accumulator.AUDIO
|
||||
|
||||
|
||||
class AudioPlugin(base_plugin.TBPlugin):
|
||||
"""Audio Plugin for TensorBoard."""
|
||||
|
||||
plugin_name = _PLUGIN_PREFIX_ROUTE
|
||||
|
||||
def get_plugin_apps(self, multiplexer, unused_logdir):
|
||||
self._multiplexer = multiplexer
|
||||
return {
|
||||
'/audio': self._serve_audio_metadata,
|
||||
'/individualAudio': self._serve_individual_audio,
|
||||
'/tags': self._serve_tags,
|
||||
}
|
||||
|
||||
def is_active(self):
|
||||
"""The audio plugin is active iff any run has at least one relevant tag."""
|
||||
return any(self.index_impl().values())
|
||||
|
||||
def _index_impl(self):
|
||||
return {
|
||||
run_name: run_data[event_accumulator.AUDIO]
|
||||
for (run_name, run_data) in self._multiplexer.Runs().items()
|
||||
if event_accumulator.AUDIO in run_data
|
||||
}
|
||||
|
||||
@wrappers.Request.application
|
||||
def _serve_audio_metadata(self, request):
|
||||
"""Given a tag and list of runs, serve a list of metadata for audio.
|
||||
|
||||
Note that the audio themselves are not sent; instead, we respond with URLs
|
||||
to the audio. The frontend should treat these URLs as opaque and should not
|
||||
try to parse information about them or generate them itself, as the format
|
||||
may change.
|
||||
|
||||
Args:
|
||||
request: A werkzeug.wrappers.Request object.
|
||||
|
||||
Returns:
|
||||
A werkzeug.Response application.
|
||||
"""
|
||||
tag = request.args.get('tag')
|
||||
run = request.args.get('run')
|
||||
|
||||
audio_list = self._multiplexer.Audio(run, tag)
|
||||
response = self._audio_response_for_run(audio_list, run, tag)
|
||||
return http_util.Respond(request, response, 'application/json')
|
||||
|
||||
def _audio_response_for_run(self, run_audio, run, tag):
|
||||
"""Builds a JSON-serializable object with information about run_audio.
|
||||
|
||||
Args:
|
||||
run_audio: A list of event_accumulator.AudioValueEvent objects.
|
||||
run: The name of the run.
|
||||
tag: The name of the tag the audio entries all belong to.
|
||||
|
||||
Returns:
|
||||
A list of dictionaries containing the wall time, step, URL, width, and
|
||||
height for each audio entry.
|
||||
"""
|
||||
response = []
|
||||
for index, run_audio_clip in enumerate(run_audio):
|
||||
response.append({
|
||||
'wall_time': run_audio_clip.wall_time,
|
||||
'step': run_audio_clip.step,
|
||||
'content_type': run_audio_clip.content_type,
|
||||
'query': self._query_for_individual_audio(run, tag, index)
|
||||
})
|
||||
return response
|
||||
|
||||
def _query_for_individual_audio(self, run, tag, index):
|
||||
"""Builds a URL for accessing the specified audio.
|
||||
|
||||
This should be kept in sync with _serve_audio_metadata. Note that the URL is
|
||||
*not* guaranteed to always return the same audio, since audio may be
|
||||
unloaded from the reservoir as new audio entries come in.
|
||||
|
||||
Args:
|
||||
run: The name of the run.
|
||||
tag: The tag.
|
||||
index: The index of the audio entry. Negative values are OK.
|
||||
|
||||
Returns:
|
||||
A string representation of a URL that will load the index-th sampled audio
|
||||
in the given run with the given tag.
|
||||
"""
|
||||
query_string = urllib.parse.urlencode({
|
||||
'run': run,
|
||||
'tag': tag,
|
||||
'index': index
|
||||
})
|
||||
return query_string
|
||||
|
||||
@wrappers.Request.application
|
||||
def _serve_individual_audio(self, request):
|
||||
"""Serves an individual audio entry."""
|
||||
tag = request.args.get('tag')
|
||||
run = request.args.get('run')
|
||||
index = int(request.args.get('index'))
|
||||
audio = self._multiplexer.Audio(run, tag)[index]
|
||||
return http_util.Respond(
|
||||
request, audio.encoded_audio_string, audio.content_type)
|
||||
|
||||
@wrappers.Request.application
|
||||
def _serve_tags(self, request):
|
||||
index = self._index_impl()
|
||||
return http_util.Respond(request, index, 'application/json')
|
||||
157
tensorflow/tensorboard/plugins/audio/audio_plugin_test.py
Normal file
157
tensorflow/tensorboard/plugins/audio/audio_plugin_test.py
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests the Tensorboard audio plugin."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import numpy
|
||||
from six.moves import urllib
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
import tensorflow as tf
|
||||
from werkzeug import test as werkzeug_test
|
||||
from werkzeug import wrappers
|
||||
|
||||
from tensorflow.tensorboard.backend import application
|
||||
from tensorflow.tensorboard.backend.event_processing import event_multiplexer
|
||||
from tensorflow.tensorboard.plugins.audio import audio_plugin
|
||||
|
||||
|
||||
class AudioPluginTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.log_dir = tempfile.mkdtemp()
|
||||
|
||||
# We use numpy.random to generate audio. We seed to avoid non-determinism
|
||||
# in this test.
|
||||
numpy.random.seed(42)
|
||||
|
||||
# Create audio summaries for run foo.
|
||||
tf.reset_default_graph()
|
||||
sess = tf.Session()
|
||||
placeholder = tf.placeholder(tf.float32)
|
||||
tf.summary.audio(name="baz", tensor=placeholder, sample_rate=44100)
|
||||
merged_summary_op = tf.summary.merge_all()
|
||||
foo_directory = os.path.join(self.log_dir, "foo")
|
||||
writer = tf.summary.FileWriter(foo_directory)
|
||||
writer.add_graph(sess.graph)
|
||||
for step in xrange(2):
|
||||
# The floats (sample data) range from -1 to 1.
|
||||
writer.add_summary(sess.run(merged_summary_op, feed_dict={
|
||||
placeholder: numpy.random.rand(42, 22050) * 2 - 1
|
||||
}), global_step=step)
|
||||
writer.close()
|
||||
|
||||
# Create audio summaries for run bar.
|
||||
tf.reset_default_graph()
|
||||
sess = tf.Session()
|
||||
placeholder = tf.placeholder(tf.float32)
|
||||
tf.summary.audio(name="quux", tensor=placeholder, sample_rate=44100)
|
||||
merged_summary_op = tf.summary.merge_all()
|
||||
bar_directory = os.path.join(self.log_dir, "bar")
|
||||
writer = tf.summary.FileWriter(bar_directory)
|
||||
writer.add_graph(sess.graph)
|
||||
for step in xrange(2):
|
||||
# The floats (sample data) range from -1 to 1.
|
||||
writer.add_summary(sess.run(merged_summary_op, feed_dict={
|
||||
placeholder: numpy.random.rand(42, 11025) * 2 - 1
|
||||
}), global_step=step)
|
||||
writer.close()
|
||||
|
||||
# Start a server with the plugin.
|
||||
multiplexer = event_multiplexer.EventMultiplexer({
|
||||
"foo": foo_directory,
|
||||
"bar": bar_directory,
|
||||
})
|
||||
plugin = audio_plugin.AudioPlugin()
|
||||
wsgi_app = application.TensorBoardWSGIApp(
|
||||
self.log_dir, [plugin], multiplexer, reload_interval=0)
|
||||
self.server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse)
|
||||
self.routes = plugin.get_plugin_apps(multiplexer, self.log_dir)
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.log_dir, ignore_errors=True)
|
||||
|
||||
def _DeserializeResponse(self, byte_content):
|
||||
"""Deserializes byte content that is a JSON encoding.
|
||||
|
||||
Args:
|
||||
byte_content: The byte content of a response.
|
||||
|
||||
Returns:
|
||||
The deserialized python object decoded from JSON.
|
||||
"""
|
||||
return json.loads(byte_content.decode("utf-8"))
|
||||
|
||||
def testRoutesProvided(self):
|
||||
"""Tests that the plugin offers the correct routes."""
|
||||
self.assertIsInstance(self.routes["/audio"], collections.Callable)
|
||||
self.assertIsInstance(self.routes["/individualAudio"], collections.Callable)
|
||||
self.assertIsInstance(self.routes["/tags"], collections.Callable)
|
||||
|
||||
def testAudioRoute(self):
|
||||
"""Tests that the /audio routes returns with the correct data."""
|
||||
response = self.server.get(
|
||||
"/data/plugin/audio/audio?run=foo&tag=baz/audio/0")
|
||||
self.assertEqual(200, response.status_code)
|
||||
|
||||
# Verify that the correct entries are returned.
|
||||
entries = self._DeserializeResponse(response.get_data())
|
||||
self.assertEqual(2, len(entries))
|
||||
|
||||
# Verify that the 1st entry is correct.
|
||||
entry = entries[0]
|
||||
self.assertEqual(0, entry["step"])
|
||||
parsed_query = urllib.parse.parse_qs(entry["query"])
|
||||
self.assertListEqual(["0"], parsed_query["index"])
|
||||
self.assertListEqual(["foo"], parsed_query["run"])
|
||||
self.assertListEqual(["baz/audio/0"], parsed_query["tag"])
|
||||
|
||||
# Verify that the 2nd entry is correct.
|
||||
entry = entries[1]
|
||||
self.assertEqual(1, entry["step"])
|
||||
parsed_query = urllib.parse.parse_qs(entry["query"])
|
||||
self.assertListEqual(["1"], parsed_query["index"])
|
||||
self.assertListEqual(["foo"], parsed_query["run"])
|
||||
self.assertListEqual(["baz/audio/0"], parsed_query["tag"])
|
||||
|
||||
def testIndividualAudioRoute(self):
|
||||
"""Tests fetching an individual audio."""
|
||||
response = self.server.get(
|
||||
"/data/plugin/audio/individualAudio?run=bar&tag=quux/audio/0&index=0")
|
||||
self.assertEqual(200, response.status_code)
|
||||
self.assertEqual("audio/wav", response.headers.get("content-type"))
|
||||
|
||||
def testRunsRoute(self):
|
||||
"""Tests that the /runs route offers the correct run to tag mapping."""
|
||||
response = self.server.get("/data/plugin/audio/tags")
|
||||
self.assertEqual(200, response.status_code)
|
||||
run_to_tags = self._DeserializeResponse(response.get_data())
|
||||
self.assertItemsEqual(("foo", "bar"), run_to_tags.keys())
|
||||
self.assertItemsEqual(
|
||||
["baz/audio/0", "baz/audio/1", "baz/audio/2"], run_to_tags["foo"])
|
||||
self.assertItemsEqual(
|
||||
["quux/audio/0", "quux/audio/1", "quux/audio/2"], run_to_tags["bar"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
Loading…
Reference in New Issue
Block a user