mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Fixed input shape for freezing audio graphs
PiperOrigin-RevId: 165649546
This commit is contained in:
parent
9b9e5989d2
commit
a3c4e980e0
|
|
@ -90,9 +90,14 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
|
||||||
spectrogram,
|
spectrogram,
|
||||||
decoded_sample_data.sample_rate,
|
decoded_sample_data.sample_rate,
|
||||||
dct_coefficient_count=dct_coefficient_count)
|
dct_coefficient_count=dct_coefficient_count)
|
||||||
|
fingerprint_frequency_size = model_settings['dct_coefficient_count']
|
||||||
|
fingerprint_time_size = model_settings['spectrogram_length']
|
||||||
|
reshaped_input = tf.reshape(fingerprint_input, [
|
||||||
|
-1, fingerprint_time_size * fingerprint_frequency_size
|
||||||
|
])
|
||||||
|
|
||||||
logits = models.create_model(
|
logits = models.create_model(
|
||||||
fingerprint_input, model_settings, model_architecture, is_training=False)
|
reshaped_input, model_settings, model_architecture, is_training=False)
|
||||||
|
|
||||||
# Create an output to use for inference.
|
# Create an output to use for inference.
|
||||||
tf.nn.softmax(logits, name='labels_softmax')
|
tf.nn.softmax(logits, name='labels_softmax')
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user