mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Implements SVDF model for keyword spotting tutorial.
PiperOrigin-RevId: 165725938
parent
aaabf6b902
commit
80bd004cdc
|
|
@ -526,14 +526,19 @@ The default model used for this script is pretty large, taking over 800 million
|
||||||
FLOPs for each inference and using 940,000 weight parameters. This runs at
|
FLOPs for each inference and using 940,000 weight parameters. This runs at
|
||||||
usable speeds on desktop machines or modern phones, but it involves too many
|
usable speeds on desktop machines or modern phones, but it involves too many
|
||||||
calculations to run at interactive speeds on devices with more limited
|
calculations to run at interactive speeds on devices with more limited
|
||||||
resources. To support these use cases, there's an alternative model available,
|
resources. To support these use cases, there are a couple of alternatives
|
||||||
based on the 'cnn-one-fstride4' architecture described in the [Convolutional
|
available:
|
||||||
|
|
||||||
|
|
||||||
|
**low_latency_conv**
|
||||||
|
Based on the 'cnn-one-fstride4' topology described in the [Convolutional
|
||||||
Neural Networks for Small-footprint Keyword Spotting
|
Neural Networks for Small-footprint Keyword Spotting
|
||||||
paper](http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf).
|
paper](http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf).
|
||||||
The number of weight parameters is about the same, but it only needs 11 million
|
The accuracy is slightly lower than 'conv' but the number of weight parameters
|
||||||
FLOPs to run one prediction, making it much faster.
|
is about the same, and it only needs 11 million FLOPs to run one prediction,
|
||||||
|
making it much faster.
|
||||||
|
|
||||||
To use this model, you can specify `--model_architecture=low_latency_conv` on
|
To use this model, you specify `--model_architecture=low_latency_conv` on
|
||||||
the command line. You'll also need to update the training rates and the number
|
the command line. You'll also need to update the training rates and the number
|
||||||
of steps, so the full command will look like:
|
of steps, so the full command will look like:
|
||||||
|
|
||||||
|
|
@ -547,6 +552,42 @@ python tensorflow/examples/speech_commands/train \
|
||||||
This asks the script to train with a learning rate of 0.01 for 20,000 steps, and
|
This asks the script to train with a learning rate of 0.01 for 20,000 steps, and
|
||||||
then do a fine-tuning pass of 6,000 steps with a 10x smaller rate.
|
then do a fine-tuning pass of 6,000 steps with a 10x smaller rate.
|
||||||
|
|
||||||
|
**low_latency_svdf**
|
||||||
|
Based on the topology presented in the [Compressing Deep Neural Networks using a
|
||||||
|
Rank-Constrained Topology paper](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/43813.pdf).
|
||||||
|
The accuracy is also lower than 'conv' but it only uses about 750 thousand
|
||||||
|
parameters, and most significantly, it allows for an optimized execution at
|
||||||
|
test time (i.e. when you will actually use it in your application), resulting
|
||||||
|
in 750 thousand FLOPs.
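
As a rough sanity check on the "about 750 thousand parameters" figure, the sketch below adds up the weights that the model code in this commit creates. It assumes a 16 kHz sample rate, 1000 ms clips, the 30 ms window size, a 10 ms window stride, 40 DCT coefficients and 12 output labels. The function names and these default values are illustrative assumptions, not part of the script.

```python
def spectrogram_length(sample_rate=16000, clip_duration_ms=1000,
                       window_size_ms=30.0, window_stride_ms=10.0):
    """Number of spectrogram frames per clip (assumed framing arithmetic)."""
    desired_samples = int(sample_rate * clip_duration_ms / 1000)
    window_size_samples = int(sample_rate * window_size_ms / 1000)
    window_stride_samples = int(sample_rate * window_stride_ms / 1000)
    length_minus_window = desired_samples - window_size_samples
    if length_minus_window < 0:
        return 0
    return 1 + int(length_minus_window / window_stride_samples)


def svdf_param_count(dct_coefficient_count=40, time_size=None, rank=2,
                     num_units=1280, fc_channels=256, label_count=12):
    """Adds up the trainable weights of the low_latency_svdf topology."""
    if time_size is None:
        time_size = spectrogram_length()
    num_filters = rank * num_units
    # SVDF layer: frequency filters, time filters, and one bias per unit.
    svdf = (dct_coefficient_count * num_filters
            + num_filters * time_size
            + num_units)
    # Two hidden fully-connected layers, then the output layer.
    first_fc = num_units * fc_channels + fc_channels
    second_fc = fc_channels * fc_channels + fc_channels
    final_fc = fc_channels * label_count + label_count
    return svdf + first_fc + second_fc + final_fc


print(spectrogram_length())
print(svdf_param_count())
```

Under these assumptions the total comes to 751,372 weights, which is consistent with the figure quoted above. The non-trainable runtime memory is not counted.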
|
||||||
|
|
||||||
|
To use this model, you specify `--model_architecture=low_latency_svdf` on
|
||||||
|
the command line, and update the training rates and the number
|
||||||
|
of steps, so the full command will look like:
|
||||||
|
|
||||||
|
```
|
||||||
|
python tensorflow/examples/speech_commands/train \
|
||||||
|
--model_architecture=low_latency_svdf \
|
||||||
|
--how_many_training_steps=100000,35000 \
|
||||||
|
--learning_rate=0.01,0.005
|
||||||
|
```
|
||||||
|
|
||||||
|
Note that despite requiring a larger number of steps than the previous two
|
||||||
|
topologies, the reduced number of computations means that training should take
|
||||||
|
about the same time, and at the end reach an accuracy of around 85%.
|
||||||
|
You can also further tune the topology fairly easily for computation and
|
||||||
|
accuracy by changing these parameters in the SVDF layer:
|
||||||
|
|
||||||
|
* rank - The rank of the approximation (higher typically better, but results in
|
||||||
|
more computation).
|
||||||
|
* num_units - Similar to other layer types, specifies the number of nodes in
|
||||||
|
the layer (more nodes typically give better quality, at the cost of more computation).
|
||||||
|
|
||||||
|
Regarding runtime, since the layer allows optimizations by caching some of the
|
||||||
|
internal neural network activations, you need to make sure to use a consistent
|
||||||
|
stride (e.g. 'clip_stride_ms' flag) both when you freeze the graph, and when
|
||||||
|
executing the model in streaming mode (e.g. test_streaming_accuracy.cc).
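
To illustrate why the stride has to match, here is a minimal pure-Python sketch of the caching idea. On each streaming call only the frames that arrived since the previous call are processed, and their activations replace the oldest entries in a fixed-length memory. The helper names `num_new_frames` and `shift_memory` are hypothetical, and the memory is reduced to a single flat list for one filter.

```python
def num_new_frames(clip_stride_ms, window_stride_ms):
    """Frames that arrive between two consecutive recognition calls."""
    return int(clip_stride_ms / window_stride_ms)


def shift_memory(memory, new_activations):
    """Drops the oldest activations and appends the newest ones.

    The memory keeps its length, so the time filter always sees a full
    window of activations.
    """
    count = len(new_activations)
    return memory[count:] + list(new_activations)


# With a 30 ms clip stride and a 10 ms window stride, each call adds 3 frames.
frames_per_call = num_new_frames(30, 10)
memory = [0.0, 0.1, 0.2, 0.3, 0.4]
memory = shift_memory(memory, [0.5, 0.6, 0.7][:frames_per_call])
print(frames_per_call, memory)
```

If the graph were frozen with one stride and then fed clips advanced by a different stride, the number of cached frames dropped per call would no longer match the audio that actually went by, and the memory would fall out of alignment with the input.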
|
||||||
|
|
||||||
|
**Other parameters to customize**
|
||||||
If you want to experiment with customizing models, a good place to start is by
|
If you want to experiment with customizing models, a good place to start is by
|
||||||
tweaking the spectrogram creation parameters. This has the effect of altering
|
tweaking the spectrogram creation parameters. This has the effect of altering
|
||||||
the size of the input image to the model, and the creation code in
|
the size of the input image to the model, and the creation code in
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,7 @@ FLAGS = None
|
||||||
|
|
||||||
|
|
||||||
def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
|
def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
|
||||||
window_size_ms, window_stride_ms,
|
clip_stride_ms, window_size_ms, window_stride_ms,
|
||||||
dct_coefficient_count, model_architecture):
|
dct_coefficient_count, model_architecture):
|
||||||
"""Creates an audio model with the nodes needed for inference.
|
"""Creates an audio model with the nodes needed for inference.
|
||||||
|
|
||||||
|
|
@ -64,6 +64,7 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
|
||||||
wanted_words: Comma-separated list of the words we're trying to recognize.
|
wanted_words: Comma-separated list of the words we're trying to recognize.
|
||||||
sample_rate: How many samples per second are in the input audio files.
|
sample_rate: How many samples per second are in the input audio files.
|
||||||
clip_duration_ms: How many samples to analyze for the audio pattern.
|
clip_duration_ms: How many samples to analyze for the audio pattern.
|
||||||
|
clip_stride_ms: How often to run recognition. Useful for models with cache.
|
||||||
window_size_ms: Time slice duration to estimate frequencies from.
|
window_size_ms: Time slice duration to estimate frequencies from.
|
||||||
window_stride_ms: How far apart time slices should be.
|
window_stride_ms: How far apart time slices should be.
|
||||||
dct_coefficient_count: Number of frequency bands to analyze.
|
dct_coefficient_count: Number of frequency bands to analyze.
|
||||||
|
|
@ -74,6 +75,7 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
|
||||||
model_settings = models.prepare_model_settings(
|
model_settings = models.prepare_model_settings(
|
||||||
len(words_list), sample_rate, clip_duration_ms, window_size_ms,
|
len(words_list), sample_rate, clip_duration_ms, window_size_ms,
|
||||||
window_stride_ms, dct_coefficient_count)
|
window_stride_ms, dct_coefficient_count)
|
||||||
|
runtime_settings = {'clip_stride_ms': clip_stride_ms}
|
||||||
|
|
||||||
wav_data_placeholder = tf.placeholder(tf.string, [], name='wav_data')
|
wav_data_placeholder = tf.placeholder(tf.string, [], name='wav_data')
|
||||||
decoded_sample_data = contrib_audio.decode_wav(
|
decoded_sample_data = contrib_audio.decode_wav(
|
||||||
|
|
@ -97,7 +99,8 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
|
||||||
])
|
])
|
||||||
|
|
||||||
logits = models.create_model(
|
logits = models.create_model(
|
||||||
reshaped_input, model_settings, model_architecture, is_training=False)
|
reshaped_input, model_settings, model_architecture, is_training=False,
|
||||||
|
runtime_settings=runtime_settings)
|
||||||
|
|
||||||
# Create an output to use for inference.
|
# Create an output to use for inference.
|
||||||
tf.nn.softmax(logits, name='labels_softmax')
|
tf.nn.softmax(logits, name='labels_softmax')
|
||||||
|
|
@ -108,9 +111,9 @@ def main(_):
|
||||||
# Create the model and load its weights.
|
# Create the model and load its weights.
|
||||||
sess = tf.InteractiveSession()
|
sess = tf.InteractiveSession()
|
||||||
create_inference_graph(FLAGS.wanted_words, FLAGS.sample_rate,
|
create_inference_graph(FLAGS.wanted_words, FLAGS.sample_rate,
|
||||||
FLAGS.clip_duration_ms, FLAGS.window_size_ms,
|
FLAGS.clip_duration_ms, FLAGS.clip_stride_ms,
|
||||||
FLAGS.window_stride_ms, FLAGS.dct_coefficient_count,
|
FLAGS.window_size_ms, FLAGS.window_stride_ms,
|
||||||
FLAGS.model_architecture)
|
FLAGS.dct_coefficient_count, FLAGS.model_architecture)
|
||||||
models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
|
models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
|
||||||
|
|
||||||
# Turn all the variables into inline constants inside the graph and save it.
|
# Turn all the variables into inline constants inside the graph and save it.
|
||||||
|
|
@ -136,10 +139,15 @@ if __name__ == '__main__':
|
||||||
type=int,
|
type=int,
|
||||||
default=1000,
|
default=1000,
|
||||||
help='Expected duration in milliseconds of the wavs',)
|
help='Expected duration in milliseconds of the wavs',)
|
||||||
|
parser.add_argument(
|
||||||
|
'--clip_stride_ms',
|
||||||
|
type=int,
|
||||||
|
default=30,
|
||||||
|
help='How often to run recognition. Useful for models with cache.',)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--window_size_ms',
|
'--window_size_ms',
|
||||||
type=float,
|
type=float,
|
||||||
default=20.0,
|
default=30.0,
|
||||||
help='How long each spectrogram timeslice is',)
|
help='How long each spectrogram timeslice is',)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--window_stride_ms',
|
'--window_stride_ms',
|
||||||
|
|
|
||||||
|
|
@ -26,8 +26,8 @@ class FreezeTest(test.TestCase):
|
||||||
|
|
||||||
def testCreateInferenceGraph(self):
|
def testCreateInferenceGraph(self):
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
freeze.create_inference_graph('a,b,c,d', 16000, 1000.0, 20.0, 10.0, 40,
|
freeze.create_inference_graph('a,b,c,d', 16000, 1000.0, 30.0, 30.0, 10.0,
|
||||||
'conv')
|
40, 'conv')
|
||||||
self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0'))
|
self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0'))
|
||||||
self.assertIsNotNone(
|
self.assertIsNotNone(
|
||||||
sess.graph.get_tensor_by_name('decoded_sample_data:0'))
|
sess.graph.get_tensor_by_name('decoded_sample_data:0'))
|
||||||
|
|
|
||||||
|
|
@ -234,7 +234,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--window_size_ms',
|
'--window_size_ms',
|
||||||
type=float,
|
type=float,
|
||||||
default=20.0,
|
default=30.0,
|
||||||
help='How long each spectrogram timeslice is',)
|
help='How long each spectrogram timeslice is',)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--window_stride_ms',
|
'--window_stride_ms',
|
||||||
|
|
|
||||||
|
|
@ -62,7 +62,7 @@ def prepare_model_settings(label_count, sample_rate, clip_duration_ms,
|
||||||
|
|
||||||
|
|
||||||
def create_model(fingerprint_input, model_settings, model_architecture,
|
def create_model(fingerprint_input, model_settings, model_architecture,
|
||||||
is_training):
|
is_training, runtime_settings=None):
|
||||||
"""Builds a model of the requested architecture compatible with the settings.
|
"""Builds a model of the requested architecture compatible with the settings.
|
||||||
|
|
||||||
There are many possible ways of deriving predictions from a spectrogram
|
There are many possible ways of deriving predictions from a spectrogram
|
||||||
|
|
@ -86,6 +86,7 @@ def create_model(fingerprint_input, model_settings, model_architecture,
|
||||||
model_settings: Dictionary of information about the model.
|
model_settings: Dictionary of information about the model.
|
||||||
model_architecture: String specifying which kind of model to create.
|
model_architecture: String specifying which kind of model to create.
|
||||||
is_training: Whether the model is going to be used for training.
|
is_training: Whether the model is going to be used for training.
|
||||||
|
runtime_settings: Dictionary of information about the runtime.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
TensorFlow node outputting logits results, and optionally a dropout
|
TensorFlow node outputting logits results, and optionally a dropout
|
||||||
|
|
@ -102,10 +103,13 @@ def create_model(fingerprint_input, model_settings, model_architecture,
|
||||||
elif model_architecture == 'low_latency_conv':
|
elif model_architecture == 'low_latency_conv':
|
||||||
return create_low_latency_conv_model(fingerprint_input, model_settings,
|
return create_low_latency_conv_model(fingerprint_input, model_settings,
|
||||||
is_training)
|
is_training)
|
||||||
|
elif model_architecture == 'low_latency_svdf':
|
||||||
|
return create_low_latency_svdf_model(fingerprint_input, model_settings,
|
||||||
|
is_training, runtime_settings)
|
||||||
else:
|
else:
|
||||||
raise Exception('model_architecture argument "' + model_architecture +
|
raise Exception('model_architecture argument "' + model_architecture +
|
||||||
'" not recognized, should be one of "single_fc", "conv",' +
|
'" not recognized, should be one of "single_fc", "conv",' +
|
||||||
' or "low_latency_conv"')
|
' "low_latency_conv", or "low_latency_svdf"')
|
||||||
|
|
||||||
|
|
||||||
def load_variables_from_checkpoint(sess, start_checkpoint):
|
def load_variables_from_checkpoint(sess, start_checkpoint):
|
||||||
|
|
@ -376,3 +380,187 @@ def create_low_latency_conv_model(fingerprint_input, model_settings,
|
||||||
return final_fc, dropout_prob
|
return final_fc, dropout_prob
|
||||||
else:
|
else:
|
||||||
return final_fc
|
return final_fc
|
||||||
|
|
||||||
|
|
||||||
|
def create_low_latency_svdf_model(fingerprint_input, model_settings,
|
||||||
|
is_training, runtime_settings):
|
||||||
|
"""Builds an SVDF model with low compute requirements.
|
||||||
|
|
||||||
|
This is based on the topology presented in the 'Compressing Deep Neural
|
||||||
|
Networks using a Rank-Constrained Topology' paper:
|
||||||
|
https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/43813.pdf
|
||||||
|
|
||||||
|
Here's the layout of the graph:
|
||||||
|
|
||||||
|
(fingerprint_input)
|
||||||
|
v
|
||||||
|
[SVDF]<-(weights)
|
||||||
|
v
|
||||||
|
[BiasAdd]<-(bias)
|
||||||
|
v
|
||||||
|
[Relu]
|
||||||
|
v
|
||||||
|
[MatMul]<-(weights)
|
||||||
|
v
|
||||||
|
[BiasAdd]<-(bias)
|
||||||
|
v
|
||||||
|
[MatMul]<-(weights)
|
||||||
|
v
|
||||||
|
[BiasAdd]<-(bias)
|
||||||
|
v
|
||||||
|
[MatMul]<-(weights)
|
||||||
|
v
|
||||||
|
[BiasAdd]<-(bias)
|
||||||
|
v
|
||||||
|
|
||||||
|
This model produces lower recognition accuracy than the 'conv' model above,
|
||||||
|
but requires fewer weight parameters and significantly fewer computations.
|
||||||
|
|
||||||
|
During training, dropout nodes are introduced after the relu, controlled by a
|
||||||
|
placeholder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fingerprint_input: TensorFlow node that will output audio feature vectors.
|
||||||
|
The node is expected to produce a 2D Tensor of shape:
|
||||||
|
[batch, model_settings['dct_coefficient_count'] *
|
||||||
|
model_settings['spectrogram_length']]
|
||||||
|
with the features corresponding to the same time slot arranged contiguously,
|
||||||
|
and the oldest slot at index [:, 0], and newest at [:, -1].
|
||||||
|
model_settings: Dictionary of information about the model.
|
||||||
|
is_training: Whether the model is going to be used for training.
|
||||||
|
runtime_settings: Dictionary of information about the runtime.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TensorFlow node outputting logits results, and optionally a dropout
|
||||||
|
placeholder.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the inputs tensor is incorrectly shaped.
|
||||||
|
"""
|
||||||
|
if is_training:
|
||||||
|
dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
|
||||||
|
|
||||||
|
input_frequency_size = model_settings['dct_coefficient_count']
|
||||||
|
input_time_size = model_settings['spectrogram_length']
|
||||||
|
|
||||||
|
# Validation.
|
||||||
|
input_shape = fingerprint_input.get_shape()
|
||||||
|
if len(input_shape) != 2:
|
||||||
|
raise ValueError('Inputs to `SVDF` should have rank == 2.')
|
||||||
|
if input_shape[-1].value is None:
|
||||||
|
raise ValueError('The last dimension of the inputs to `SVDF` '
|
||||||
|
'should be defined. Found `None`.')
|
||||||
|
if input_shape[-1].value % input_frequency_size != 0:
|
||||||
|
raise ValueError('Inputs feature dimension %d must be a multiple of '
|
||||||
|
'frame size %d' % (fingerprint_input.shape[-1].value,
|
||||||
|
input_frequency_size))
|
||||||
|
|
||||||
|
# Set number of units (i.e. nodes) and rank.
|
||||||
|
rank = 2
|
||||||
|
num_units = 1280
|
||||||
|
# Number of filters: pairs of feature and time filters.
|
||||||
|
num_filters = rank * num_units
|
||||||
|
# Create the runtime memory: [num_filters, batch, input_time_size]
|
||||||
|
batch = 1
|
||||||
|
memory = tf.Variable(tf.zeros([num_filters, batch, input_time_size]),
|
||||||
|
trainable=False, name='runtime-memory')
|
||||||
|
# Determine the number of new frames in the input, such that we only operate
|
||||||
|
# on those. For training we do not use the memory, and thus use all frames
|
||||||
|
# provided in the input.
|
||||||
|
# new_fingerprint_input: [batch, num_new_frames*input_frequency_size]
|
||||||
|
if is_training:
|
||||||
|
num_new_frames = input_time_size
|
||||||
|
else:
|
||||||
|
window_stride_ms = int(model_settings['window_stride_samples'] * 1000 /
|
||||||
|
model_settings['sample_rate'])
|
||||||
|
num_new_frames = tf.cond(
|
||||||
|
tf.equal(tf.count_nonzero(memory), 0),
|
||||||
|
lambda: input_time_size,
|
||||||
|
lambda: int(runtime_settings['clip_stride_ms'] / window_stride_ms))
|
||||||
|
new_fingerprint_input = fingerprint_input[
|
||||||
|
:, -num_new_frames*input_frequency_size:]
|
||||||
|
# Expand to add input channels dimension.
|
||||||
|
new_fingerprint_input = tf.expand_dims(new_fingerprint_input, 2)
|
||||||
|
|
||||||
|
# Create the frequency filters.
|
||||||
|
weights_frequency = tf.Variable(
|
||||||
|
tf.truncated_normal([input_frequency_size, num_filters], stddev=0.01))
|
||||||
|
# Expand to add input channels dimensions.
|
||||||
|
# weights_frequency: [input_frequency_size, 1, num_filters]
|
||||||
|
weights_frequency = tf.expand_dims(weights_frequency, 1)
|
||||||
|
# Convolve the 1D feature filters sliding over the time dimension.
|
||||||
|
# activations_time: [batch, num_new_frames, num_filters]
|
||||||
|
activations_time = tf.nn.conv1d(
|
||||||
|
new_fingerprint_input, weights_frequency, input_frequency_size, 'VALID')
|
||||||
|
# Rearrange such that we can perform the batched matmul.
|
||||||
|
# activations_time: [num_filters, batch, num_new_frames]
|
||||||
|
activations_time = tf.transpose(activations_time, perm=[2, 0, 1])
|
||||||
|
|
||||||
|
# Runtime memory optimization.
|
||||||
|
if not is_training:
|
||||||
|
# We need to drop the activations corresponding to the oldest frames, and
|
||||||
|
# then add those corresponding to the new frames.
|
||||||
|
new_memory = memory[:, :, num_new_frames:]
|
||||||
|
new_memory = tf.concat([new_memory, activations_time], 2)
|
||||||
|
tf.assign(memory, new_memory)
|
||||||
|
activations_time = new_memory
|
||||||
|
|
||||||
|
# Create the time filters.
|
||||||
|
weights_time = tf.Variable(
|
||||||
|
tf.truncated_normal([num_filters, input_time_size], stddev=0.01))
|
||||||
|
# Apply the time filter on the outputs of the feature filters.
|
||||||
|
# weights_time: [num_filters, input_time_size, 1]
|
||||||
|
# outputs: [num_filters, batch, 1]
|
||||||
|
weights_time = tf.expand_dims(weights_time, 2)
|
||||||
|
outputs = tf.matmul(activations_time, weights_time)
|
||||||
|
# Split num_units and rank into separate dimensions (the remaining
|
||||||
|
# dimension is input_shape[0], i.e. the batch size). This also squeezes
|
||||||
|
# the last dimension, since it's not used.
|
||||||
|
# [num_filters, batch, 1] => [num_units, rank, batch]
|
||||||
|
outputs = tf.reshape(outputs, [num_units, rank, -1])
|
||||||
|
# Sum the rank outputs per unit => [num_units, batch].
|
||||||
|
units_output = tf.reduce_sum(outputs, axis=1)
|
||||||
|
# Transpose to shape [batch, num_units]
|
||||||
|
units_output = tf.transpose(units_output)
|
||||||
|
|
||||||
|
# Apply bias.
|
||||||
|
bias = tf.Variable(tf.zeros([num_units]))
|
||||||
|
first_bias = tf.nn.bias_add(units_output, bias)
|
||||||
|
|
||||||
|
# Relu.
|
||||||
|
first_relu = tf.nn.relu(first_bias)
|
||||||
|
|
||||||
|
if is_training:
|
||||||
|
first_dropout = tf.nn.dropout(first_relu, dropout_prob)
|
||||||
|
else:
|
||||||
|
first_dropout = first_relu
|
||||||
|
|
||||||
|
first_fc_output_channels = 256
|
||||||
|
first_fc_weights = tf.Variable(
|
||||||
|
tf.truncated_normal([num_units, first_fc_output_channels], stddev=0.01))
|
||||||
|
first_fc_bias = tf.Variable(tf.zeros([first_fc_output_channels]))
|
||||||
|
first_fc = tf.matmul(first_dropout, first_fc_weights) + first_fc_bias
|
||||||
|
if is_training:
|
||||||
|
second_fc_input = tf.nn.dropout(first_fc, dropout_prob)
|
||||||
|
else:
|
||||||
|
second_fc_input = first_fc
|
||||||
|
second_fc_output_channels = 256
|
||||||
|
second_fc_weights = tf.Variable(
|
||||||
|
tf.truncated_normal(
|
||||||
|
[first_fc_output_channels, second_fc_output_channels], stddev=0.01))
|
||||||
|
second_fc_bias = tf.Variable(tf.zeros([second_fc_output_channels]))
|
||||||
|
second_fc = tf.matmul(second_fc_input, second_fc_weights) + second_fc_bias
|
||||||
|
if is_training:
|
||||||
|
final_fc_input = tf.nn.dropout(second_fc, dropout_prob)
|
||||||
|
else:
|
||||||
|
final_fc_input = second_fc
|
||||||
|
label_count = model_settings['label_count']
|
||||||
|
final_fc_weights = tf.Variable(
|
||||||
|
tf.truncated_normal(
|
||||||
|
[second_fc_output_channels, label_count], stddev=0.01))
|
||||||
|
final_fc_bias = tf.Variable(tf.zeros([label_count]))
|
||||||
|
final_fc = tf.matmul(final_fc_input, final_fc_weights) + final_fc_bias
|
||||||
|
if is_training:
|
||||||
|
return final_fc, dropout_prob
|
||||||
|
else:
|
||||||
|
return final_fc
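
The core SVDF computation in the function above can be hard to follow through the reshapes and transposes. The pure-Python sketch below restates it for a single example without TensorFlow. Each filter applies its feature weights to every frame, then its time weights across frames, and the outputs of the `rank` filters that belong to one unit are summed. The function name `svdf_forward` and the list-based layout are illustrative assumptions. The bias, ReLU, dropout and the runtime memory are deliberately left out.

```python
def svdf_forward(frames, weights_frequency, weights_time, rank):
    """Rank-constrained SVDF layer for one example.

    frames:            [time][frequency] input features.
    weights_frequency: [frequency][num_filters] feature filters.
    weights_time:      [num_filters][time] time filters.
    rank:              number of filters summed into each unit.
    Returns a list with one value per unit (num_filters // rank).
    """
    num_filters = len(weights_time)
    num_units = num_filters // rank
    filter_outputs = []
    for f in range(num_filters):
        # Feature filter: one activation per time frame.
        activations = [
            sum(x * weights_frequency[k][f] for k, x in enumerate(frame))
            for frame in frames
        ]
        # Time filter: weighted sum of those activations over time.
        filter_outputs.append(
            sum(a * w for a, w in zip(activations, weights_time[f])))
    # Filters u*rank .. u*rank + rank - 1 belong to unit u, matching the
    # reshape to [num_units, rank, batch] followed by a sum over rank.
    return [
        sum(filter_outputs[u * rank:(u + 1) * rank])
        for u in range(num_units)
    ]


frames = [[1, 2], [3, 4]]             # 2 time frames, 2 frequency bins
weights_frequency = [[1, 0], [0, 1]]  # 2 filters
weights_time = [[1, 1], [1, -1]]
print(svdf_forward(frames, weights_frequency, weights_time, rank=2))
```

In this toy case the first filter yields 1 + 3 = 4 and the second yields 2 - 4 = -2, so the single rank-2 unit outputs 2.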
|
||||||
|
|
|
||||||
|
|
@ -139,7 +139,7 @@ int main(int argc, char* argv[]) {
|
||||||
string input_rate_name = "decoded_sample_data:1";
|
string input_rate_name = "decoded_sample_data:1";
|
||||||
string output_name = "labels_softmax";
|
string output_name = "labels_softmax";
|
||||||
int32 clip_duration_ms = 1000;
|
int32 clip_duration_ms = 1000;
|
||||||
int32 sample_stride_ms = 30;
|
int32 clip_stride_ms = 30;
|
||||||
int32 average_window_ms = 500;
|
int32 average_window_ms = 500;
|
||||||
int32 time_tolerance_ms = 750;
|
int32 time_tolerance_ms = 750;
|
||||||
int32 suppression_ms = 1500;
|
int32 suppression_ms = 1500;
|
||||||
|
|
@ -165,8 +165,7 @@ int main(int argc, char* argv[]) {
|
||||||
"maximum gap allowed between a recognition and ground truth"),
|
"maximum gap allowed between a recognition and ground truth"),
|
||||||
Flag("suppression_ms", &suppression_ms,
|
Flag("suppression_ms", &suppression_ms,
|
||||||
"how long to ignore others for after a recognition"),
|
"how long to ignore others for after a recognition"),
|
||||||
Flag("sample_stride_ms", &sample_stride_ms,
|
Flag("clip_stride_ms", &clip_stride_ms, "how often to run recognition"),
|
||||||
"how often to run recognition"),
|
|
||||||
Flag("detection_threshold", &detection_threshold,
|
Flag("detection_threshold", &detection_threshold,
|
||||||
"what score is required to trigger detection of a word"),
|
"what score is required to trigger detection of a word"),
|
||||||
Flag("verbose", &verbose, "whether to log extra debugging information"),
|
Flag("verbose", &verbose, "whether to log extra debugging information"),
|
||||||
|
|
@ -232,7 +231,7 @@ int main(int argc, char* argv[]) {
|
||||||
}
|
}
|
||||||
|
|
||||||
const int64 clip_duration_samples = (clip_duration_ms * sample_rate) / 1000;
|
const int64 clip_duration_samples = (clip_duration_ms * sample_rate) / 1000;
|
||||||
const int64 sample_stride_samples = (sample_stride_ms * sample_rate) / 1000;
|
const int64 sample_stride_samples = (clip_stride_ms * sample_rate) / 1000;
|
||||||
Tensor audio_data_tensor(tensorflow::DT_FLOAT,
|
Tensor audio_data_tensor(tensorflow::DT_FLOAT,
|
||||||
tensorflow::TensorShape({clip_duration_samples, 1}));
|
tensorflow::TensorShape({clip_duration_samples, 1}));
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -355,7 +355,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--window_size_ms',
|
'--window_size_ms',
|
||||||
type=float,
|
type=float,
|
||||||
default=20.0,
|
default=30.0,
|
||||||
help='How long each spectrogram timeslice is',)
|
help='How long each spectrogram timeslice is',)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--window_stride_ms',
|
'--window_stride_ms',
|
||||||
|
|
|
||||||