Implements SVDF model for keyword spotting tutorial.

PiperOrigin-RevId: 165725938
This commit is contained in:
A. Unique TensorFlower 2017-08-18 10:59:28 -07:00 committed by TensorFlower Gardener
parent aaabf6b902
commit 80bd004cdc
7 changed files with 257 additions and 21 deletions

View File

@ -526,14 +526,19 @@ The default model used for this script is pretty large, taking over 800 million
FLOPs for each inference and using 940,000 weight parameters. This runs at
usable speeds on desktop machines or modern phones, but it involves too many
calculations to run at interactive speeds on devices with more limited
resources. To support these use cases, there's an alternative model available,
based on the 'cnn-one-fstride4' architecture described in the [Convolutional
resources. To support these use cases, there's a couple of alternatives
available:
**low_latency_conv**
Based on the 'cnn-one-fstride4' topology described in the [Convolutional
Neural Networks for Small-footprint Keyword Spotting
paper](http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf).
The number of weight parameters is about the same, but it only needs 11 million
FLOPs to run one prediction, making it much faster.
The accuracy is slightly lower than 'conv' but the number of weight parameters
is about the same, and it only needs 11 million FLOPs to run one prediction,
making it much faster.
To use this model, you can specify `--model_architecture=low_latency_conv` on
To use this model, you specify `--model_architecture=low_latency_conv` on
the command line. You'll also need to update the training rates and the number
of steps, so the full command will look like:
@ -547,6 +552,42 @@ python tensorflow/examples/speech_commands/train \
This asks the script to train with a learning rate of 0.01 for 20,000 steps, and
then do a fine-tuning pass of 6,000 steps with a 10x smaller rate.
**low_latency_svdf**
Based on the topology presented in the [Compressing Deep Neural Networks using a
Rank-Constrained Topology paper](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/43813.pdf).
The accuracy is also lower than 'conv' but it only uses about 750 thousand
parameters, and most significantly, it allows for an optimized execution at
test time (i.e. when you will actually use it in your application), resulting
in 750 thousand FLOPs.
To use this model, you specify `--model_architecture=low_latency_svdf` on
the command line, and update the training rates and the number
of steps, so the full command will look like:
```
python tensorflow/examples/speech_commands/train \
--model_architecture=low_latency_svdf \
--how_many_training_steps=100000,35000 \
--learning_rate=0.01,0.005
```
Note that despite requiring a larger number of steps than the previous two
topologies, the reduced number of computations means that training should take
about the same time, and at the end reach an accuracy of around 85%.
You can also further tune the topology fairly easily for computation and
accuracy by changing these parameters in the SVDF layer:
* rank - The rank of the approximation (higher typically better, but results in
more computation).
* num_units - Similar to other layer types, specifies the number of nodes in
the layer (more nodes better quality, and more computation).
Regarding runtime, since the layer allows optimizations by caching some of the
internal neural network activations, you need to make sure to use a consistent
stride (e.g. 'clip_stride_ms' flag) both when you freeze the graph, and when
executing the model in streaming mode (e.g. test_streaming_accuracy.cc).
**Other parameters to customize**
If you want to experiment with customizing models, a good place to start is by
tweaking the spectrogram creation parameters. This has the effect of altering
the size of the input image to the model, and the creation code in

View File

@ -53,7 +53,7 @@ FLAGS = None
def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
window_size_ms, window_stride_ms,
clip_stride_ms, window_size_ms, window_stride_ms,
dct_coefficient_count, model_architecture):
"""Creates an audio model with the nodes needed for inference.
@ -64,6 +64,7 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
wanted_words: Comma-separated list of the words we're trying to recognize.
sample_rate: How many samples per second are in the input audio files.
clip_duration_ms: How many samples to analyze for the audio pattern.
clip_stride_ms: How often to run recognition. Useful for models with cache.
window_size_ms: Time slice duration to estimate frequencies from.
window_stride_ms: How far apart time slices should be.
dct_coefficient_count: Number of frequency bands to analyze.
@ -74,6 +75,7 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
model_settings = models.prepare_model_settings(
len(words_list), sample_rate, clip_duration_ms, window_size_ms,
window_stride_ms, dct_coefficient_count)
runtime_settings = {'clip_stride_ms': clip_stride_ms}
wav_data_placeholder = tf.placeholder(tf.string, [], name='wav_data')
decoded_sample_data = contrib_audio.decode_wav(
@ -97,7 +99,8 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
])
logits = models.create_model(
reshaped_input, model_settings, model_architecture, is_training=False)
reshaped_input, model_settings, model_architecture, is_training=False,
runtime_settings=runtime_settings)
# Create an output to use for inference.
tf.nn.softmax(logits, name='labels_softmax')
@ -108,9 +111,9 @@ def main(_):
# Create the model and load its weights.
sess = tf.InteractiveSession()
create_inference_graph(FLAGS.wanted_words, FLAGS.sample_rate,
FLAGS.clip_duration_ms, FLAGS.window_size_ms,
FLAGS.window_stride_ms, FLAGS.dct_coefficient_count,
FLAGS.model_architecture)
FLAGS.clip_duration_ms, FLAGS.clip_stride_ms,
FLAGS.window_size_ms, FLAGS.window_stride_ms,
FLAGS.dct_coefficient_count, FLAGS.model_architecture)
models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
# Turn all the variables into inline constants inside the graph and save it.
@ -136,10 +139,15 @@ if __name__ == '__main__':
type=int,
default=1000,
help='Expected duration in milliseconds of the wavs',)
parser.add_argument(
'--clip_stride_ms',
type=int,
default=30,
help='How often to run recognition. Useful for models with cache.',)
parser.add_argument(
'--window_size_ms',
type=float,
default=20.0,
default=30.0,
help='How long each spectrogram timeslice is',)
parser.add_argument(
'--window_stride_ms',

View File

@ -26,8 +26,8 @@ class FreezeTest(test.TestCase):
def testCreateInferenceGraph(self):
with self.test_session() as sess:
freeze.create_inference_graph('a,b,c,d', 16000, 1000.0, 20.0, 10.0, 40,
'conv')
freeze.create_inference_graph('a,b,c,d', 16000, 1000.0, 30.0, 30.0, 10.0,
40, 'conv')
self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0'))
self.assertIsNotNone(
sess.graph.get_tensor_by_name('decoded_sample_data:0'))

View File

@ -234,7 +234,7 @@ if __name__ == '__main__':
parser.add_argument(
'--window_size_ms',
type=float,
default=20.0,
default=30.0,
help='How long each spectrogram timeslice is',)
parser.add_argument(
'--window_stride_ms',

View File

@ -62,7 +62,7 @@ def prepare_model_settings(label_count, sample_rate, clip_duration_ms,
def create_model(fingerprint_input, model_settings, model_architecture,
is_training):
is_training, runtime_settings=None):
"""Builds a model of the requested architecture compatible with the settings.
There are many possible ways of deriving predictions from a spectrogram
@ -86,6 +86,7 @@ def create_model(fingerprint_input, model_settings, model_architecture,
model_settings: Dictionary of information about the model.
model_architecture: String specifying which kind of model to create.
is_training: Whether the model is going to be used for training.
runtime_settings: Dictionary of information about the runtime.
Returns:
TensorFlow node outputting logits results, and optionally a dropout
@ -102,10 +103,13 @@ def create_model(fingerprint_input, model_settings, model_architecture,
elif model_architecture == 'low_latency_conv':
return create_low_latency_conv_model(fingerprint_input, model_settings,
is_training)
elif model_architecture == 'low_latency_svdf':
return create_low_latency_svdf_model(fingerprint_input, model_settings,
is_training, runtime_settings)
else:
raise Exception('model_architecture argument "' + model_architecture +
'" not recognized, should be one of "single_fc", "conv",' +
' or "low_latency_conv"')
' "low_latency_conv, or "low_latency_svdf"')
def load_variables_from_checkpoint(sess, start_checkpoint):
@ -376,3 +380,187 @@ def create_low_latency_conv_model(fingerprint_input, model_settings,
return final_fc, dropout_prob
else:
return final_fc
def create_low_latency_svdf_model(fingerprint_input, model_settings,
is_training, runtime_settings):
"""Builds an SVDF model with low compute requirements.
This is based in the topology presented in the 'Compressing Deep Neural
Networks using a Rank-Constrained Topology' paper:
https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/43813.pdf
Here's the layout of the graph:
(fingerprint_input)
v
[SVDF]<-(weights)
v
[BiasAdd]<-(bias)
v
[Relu]
v
[MatMul]<-(weights)
v
[BiasAdd]<-(bias)
v
[MatMul]<-(weights)
v
[BiasAdd]<-(bias)
v
[MatMul]<-(weights)
v
[BiasAdd]<-(bias)
v
This model produces lower recognition accuracy than the 'conv' model above,
but requires fewer weight parameters and, significantly fewer computations.
During training, dropout nodes are introduced after the relu, controlled by a
placeholder.
Args:
fingerprint_input: TensorFlow node that will output audio feature vectors.
The node is expected to produce a 2D Tensor of shape:
[batch, model_settings['dct_coefficient_count'] *
model_settings['spectrogram_length']]
with the features corresponding to the same time slot arranged contiguously,
and the oldest slot at index [:, 0], and newest at [:, -1].
model_settings: Dictionary of information about the model.
is_training: Whether the model is going to be used for training.
runtime_settings: Dictionary of information about the runtime.
Returns:
TensorFlow node outputting logits results, and optionally a dropout
placeholder.
Raises:
ValueError: If the inputs tensor is incorrectly shaped.
"""
if is_training:
dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
input_frequency_size = model_settings['dct_coefficient_count']
input_time_size = model_settings['spectrogram_length']
# Validation.
input_shape = fingerprint_input.get_shape()
if len(input_shape) != 2:
raise ValueError('Inputs to `SVDF` should have rank == 2.')
if input_shape[-1].value is None:
raise ValueError('The last dimension of the inputs to `SVDF` '
'should be defined. Found `None`.')
if input_shape[-1].value % input_frequency_size != 0:
raise ValueError('Inputs feature dimension %d must be a multiple of '
'frame size %d', fingerprint_input.shape[-1].value,
input_frequency_size)
# Set number of units (i.e. nodes) and rank.
rank = 2
num_units = 1280
# Number of filters: pairs of feature and time filters.
num_filters = rank * num_units
# Create the runtime memory: [num_filters, batch, input_time_size]
batch = 1
memory = tf.Variable(tf.zeros([num_filters, batch, input_time_size]),
trainable=False, name='runtime-memory')
# Determine the number of new frames in the input, such that we only operate
# on those. For training we do not use the memory, and thus use all frames
# provided in the input.
# new_fingerprint_input: [batch, num_new_frames*input_frequency_size]
if is_training:
num_new_frames = input_time_size
else:
window_stride_ms = int(model_settings['window_stride_samples'] * 1000 /
model_settings['sample_rate'])
num_new_frames = tf.cond(
tf.equal(tf.count_nonzero(memory), 0),
lambda: input_time_size,
lambda: int(runtime_settings['clip_stride_ms'] / window_stride_ms))
new_fingerprint_input = fingerprint_input[
:, -num_new_frames*input_frequency_size:]
# Expand to add input channels dimension.
new_fingerprint_input = tf.expand_dims(new_fingerprint_input, 2)
# Create the frequency filters.
weights_frequency = tf.Variable(
tf.truncated_normal([input_frequency_size, num_filters], stddev=0.01))
# Expand to add input channels dimensions.
# weights_frequency: [input_frequency_size, 1, num_filters]
weights_frequency = tf.expand_dims(weights_frequency, 1)
# Convolve the 1D feature filters sliding over the time dimension.
# activations_time: [batch, num_new_frames, num_filters]
activations_time = tf.nn.conv1d(
new_fingerprint_input, weights_frequency, input_frequency_size, 'VALID')
# Rearrange such that we can perform the batched matmul.
# activations_time: [num_filters, batch, num_new_frames]
activations_time = tf.transpose(activations_time, perm=[2, 0, 1])
# Runtime memory optimization.
if not is_training:
# We need to drop the activations corresponding to the oldest frames, and
# then add those corresponding to the new frames.
new_memory = memory[:, :, num_new_frames:]
new_memory = tf.concat([new_memory, activations_time], 2)
tf.assign(memory, new_memory)
activations_time = new_memory
# Create the time filters.
weights_time = tf.Variable(
tf.truncated_normal([num_filters, input_time_size], stddev=0.01))
# Apply the time filter on the outputs of the feature filters.
# weights_time: [num_filters, input_time_size, 1]
# outputs: [num_filters, batch, 1]
weights_time = tf.expand_dims(weights_time, 2)
outputs = tf.matmul(activations_time, weights_time)
# Split num_units and rank into separate dimensions (the remaining
# dimension is the input_shape[0] -i.e. batch size). This also squeezes
# the last dimension, since it's not used.
# [num_filters, batch, 1] => [num_units, rank, batch]
outputs = tf.reshape(outputs, [num_units, rank, -1])
# Sum the rank outputs per unit => [num_units, batch].
units_output = tf.reduce_sum(outputs, axis=1)
# Transpose to shape [batch, num_units]
units_output = tf.transpose(units_output)
# Appy bias.
bias = tf.Variable(tf.zeros([num_units]))
first_bias = tf.nn.bias_add(units_output, bias)
# Relu.
first_relu = tf.nn.relu(first_bias)
if is_training:
first_dropout = tf.nn.dropout(first_relu, dropout_prob)
else:
first_dropout = first_relu
first_fc_output_channels = 256
first_fc_weights = tf.Variable(
tf.truncated_normal([num_units, first_fc_output_channels], stddev=0.01))
first_fc_bias = tf.Variable(tf.zeros([first_fc_output_channels]))
first_fc = tf.matmul(first_dropout, first_fc_weights) + first_fc_bias
if is_training:
second_fc_input = tf.nn.dropout(first_fc, dropout_prob)
else:
second_fc_input = first_fc
second_fc_output_channels = 256
second_fc_weights = tf.Variable(
tf.truncated_normal(
[first_fc_output_channels, second_fc_output_channels], stddev=0.01))
second_fc_bias = tf.Variable(tf.zeros([second_fc_output_channels]))
second_fc = tf.matmul(second_fc_input, second_fc_weights) + second_fc_bias
if is_training:
final_fc_input = tf.nn.dropout(second_fc, dropout_prob)
else:
final_fc_input = second_fc
label_count = model_settings['label_count']
final_fc_weights = tf.Variable(
tf.truncated_normal(
[second_fc_output_channels, label_count], stddev=0.01))
final_fc_bias = tf.Variable(tf.zeros([label_count]))
final_fc = tf.matmul(final_fc_input, final_fc_weights) + final_fc_bias
if is_training:
return final_fc, dropout_prob
else:
return final_fc

View File

@ -139,7 +139,7 @@ int main(int argc, char* argv[]) {
string input_rate_name = "decoded_sample_data:1";
string output_name = "labels_softmax";
int32 clip_duration_ms = 1000;
int32 sample_stride_ms = 30;
int32 clip_stride_ms = 30;
int32 average_window_ms = 500;
int32 time_tolerance_ms = 750;
int32 suppression_ms = 1500;
@ -165,8 +165,7 @@ int main(int argc, char* argv[]) {
"maximum gap allowed between a recognition and ground truth"),
Flag("suppression_ms", &suppression_ms,
"how long to ignore others for after a recognition"),
Flag("sample_stride_ms", &sample_stride_ms,
"how often to run recognition"),
Flag("clip_stride_ms", &clip_stride_ms, "how often to run recognition"),
Flag("detection_threshold", &detection_threshold,
"what score is required to trigger detection of a word"),
Flag("verbose", &verbose, "whether to log extra debugging information"),
@ -232,7 +231,7 @@ int main(int argc, char* argv[]) {
}
const int64 clip_duration_samples = (clip_duration_ms * sample_rate) / 1000;
const int64 sample_stride_samples = (sample_stride_ms * sample_rate) / 1000;
const int64 sample_stride_samples = (clip_stride_ms * sample_rate) / 1000;
Tensor audio_data_tensor(tensorflow::DT_FLOAT,
tensorflow::TensorShape({clip_duration_samples, 1}));

View File

@ -355,7 +355,7 @@ if __name__ == '__main__':
parser.add_argument(
'--window_size_ms',
type=float,
default=20.0,
default=30.0,
help='How long each spectrogram timeslice is',)
parser.add_argument(
'--window_stride_ms',