mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
[tf-signal] Use tf.spectral.dct in mfccs_from_log_mel_spectrograms instead of a private implementation.
PiperOrigin-RevId: 170943986
This commit is contained in:
parent
b959da92f9
commit
add6d2d03c
|
|
@ -18,75 +18,12 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import importlib
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
from tensorflow.contrib.signal.python.ops import mfcc_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import spectral_ops_test_util
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging
|
||||
|
||||
|
||||
# TODO(rjryan): Add scipy.fftpack to the TensorFlow build.
|
||||
def try_import(name): # pylint: disable=invalid-name
|
||||
module = None
|
||||
try:
|
||||
module = importlib.import_module(name)
|
||||
except ImportError as e:
|
||||
tf_logging.warning("Could not import %s: %s" % (name, str(e)))
|
||||
return module
|
||||
|
||||
|
||||
fftpack = try_import("scipy.fftpack")
|
||||
|
||||
|
||||
class DCTTest(test.TestCase):
|
||||
|
||||
def _np_dct2(self, signals, norm=None):
|
||||
"""Computes the DCT-II manually with NumPy."""
|
||||
# X_k = sum_{n=0}^{N-1} x_n * cos(\frac{pi}{N} * (n + 0.5) * k) k=0,...,N-1
|
||||
dct_size = signals.shape[-1]
|
||||
dct = np.zeros_like(signals)
|
||||
for k in range(dct_size):
|
||||
phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size)
|
||||
dct[..., k] = np.sum(signals * phi, axis=-1)
|
||||
# SciPy's `dct` has a scaling factor of 2.0 which we follow.
|
||||
# https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src
|
||||
if norm == "ortho":
|
||||
# The orthogonal scaling includes a factor of 0.5 which we combine with
|
||||
# the overall scaling of 2.0 to cancel.
|
||||
dct[..., 0] *= np.sqrt(1.0 / dct_size)
|
||||
dct[..., 1:] *= np.sqrt(2.0 / dct_size)
|
||||
else:
|
||||
dct *= 2.0
|
||||
return dct
|
||||
|
||||
def test_compare_to_numpy(self):
|
||||
"""Compare dct against a manual DCT-II implementation."""
|
||||
with spectral_ops_test_util.fft_kernel_label_map():
|
||||
with self.test_session(use_gpu=True):
|
||||
for size in range(1, 23):
|
||||
signals = np.random.rand(size).astype(np.float32)
|
||||
actual_dct = mfcc_ops._dct2_1d(signals).eval()
|
||||
expected_dct = self._np_dct2(signals)
|
||||
self.assertAllClose(expected_dct, actual_dct, atol=5e-4, rtol=5e-4)
|
||||
|
||||
def test_compare_to_fftpack(self):
|
||||
"""Compare dct against scipy.fftpack.dct."""
|
||||
if not fftpack:
|
||||
return
|
||||
with spectral_ops_test_util.fft_kernel_label_map():
|
||||
with self.test_session(use_gpu=True):
|
||||
for size in range(1, 23):
|
||||
signal = np.random.rand(size).astype(np.float32)
|
||||
actual_dct = mfcc_ops._dct2_1d(signal).eval()
|
||||
expected_dct = fftpack.dct(signal, type=2)
|
||||
self.assertAllClose(expected_dct, actual_dct, atol=5e-4, rtol=5e-4)
|
||||
|
||||
|
||||
# TODO(rjryan): We have no open source tests for MFCCs at the moment. Internally
|
||||
|
|
|
|||
|
|
@ -18,8 +18,6 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
|
|
@ -27,35 +25,6 @@ from tensorflow.python.ops import math_ops
|
|||
from tensorflow.python.ops import spectral_ops
|
||||
|
||||
|
||||
# TODO(rjryan): Remove once tf.spectral.dct exists.
|
||||
def _dct2_1d(signals, name=None):
|
||||
"""Computes the type II 1D Discrete Cosine Transform (DCT) of `signals`.
|
||||
|
||||
Args:
|
||||
signals: A `[..., samples]` `float32` `Tensor` containing the signals to
|
||||
take the DCT of.
|
||||
name: An optional name for the operation.
|
||||
|
||||
Returns:
|
||||
A `[..., samples]` `float32` `Tensor` containing the DCT of `signals`.
|
||||
|
||||
"""
|
||||
with ops.name_scope(name, 'dct', [signals]):
|
||||
# We use the FFT to compute the DCT and TensorFlow only supports float32 for
|
||||
# FFTs at the moment.
|
||||
signals = ops.convert_to_tensor(signals, dtype=dtypes.float32)
|
||||
|
||||
axis_dim = signals.shape[-1].value or array_ops.shape(signals)[-1]
|
||||
axis_dim_float = math_ops.to_float(axis_dim)
|
||||
scale = 2.0 * math_ops.exp(math_ops.complex(
|
||||
0.0, -math.pi * math_ops.range(axis_dim_float) /
|
||||
(2.0 * axis_dim_float)))
|
||||
|
||||
rfft = spectral_ops.rfft(signals, fft_length=[2 * axis_dim])[..., :axis_dim]
|
||||
dct2 = math_ops.real(rfft * scale)
|
||||
return dct2
|
||||
|
||||
|
||||
def mfccs_from_log_mel_spectrograms(log_mel_spectrograms, name=None):
|
||||
"""Computes [MFCCs][mfcc] of `log_mel_spectrograms`.
|
||||
|
||||
|
|
@ -134,4 +103,6 @@ def mfccs_from_log_mel_spectrograms(log_mel_spectrograms, name=None):
|
|||
log_mel_spectrograms)
|
||||
else:
|
||||
num_mel_bins = array_ops.shape(log_mel_spectrograms)[-1]
|
||||
return _dct2_1d(log_mel_spectrograms) * math_ops.rsqrt(num_mel_bins * 2.0)
|
||||
|
||||
dct2 = spectral_ops.dct(log_mel_spectrograms)
|
||||
return dct2 * math_ops.rsqrt(num_mel_bins * 2.0)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user