[tf-signal] Use tf.spectral.dct in mfccs_from_log_mel_spectrograms instead of a private implementation.

PiperOrigin-RevId: 170943986
This commit is contained in:
RJ Ryan 2017-10-03 17:50:55 -07:00 committed by TensorFlower Gardener
parent b959da92f9
commit add6d2d03c
2 changed files with 3 additions and 95 deletions

View File

@ -18,75 +18,12 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import importlib
import numpy as np
from tensorflow.contrib.signal.python.ops import mfcc_ops from tensorflow.contrib.signal.python.ops import mfcc_ops
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops from tensorflow.python.ops import random_ops
from tensorflow.python.ops import spectral_ops_test_util from tensorflow.python.ops import spectral_ops_test_util
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
# TODO(rjryan): Add scipy.fftpack to the TensorFlow build.
def try_import(name): # pylint: disable=invalid-name
module = None
try:
module = importlib.import_module(name)
except ImportError as e:
tf_logging.warning("Could not import %s: %s" % (name, str(e)))
return module
fftpack = try_import("scipy.fftpack")
class DCTTest(test.TestCase):
def _np_dct2(self, signals, norm=None):
"""Computes the DCT-II manually with NumPy."""
# X_k = sum_{n=0}^{N-1} x_n * cos(\frac{pi}{N} * (n + 0.5) * k) k=0,...,N-1
dct_size = signals.shape[-1]
dct = np.zeros_like(signals)
for k in range(dct_size):
phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size)
dct[..., k] = np.sum(signals * phi, axis=-1)
# SciPy's `dct` has a scaling factor of 2.0 which we follow.
# https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src
if norm == "ortho":
# The orthogonal scaling includes a factor of 0.5 which we combine with
# the overall scaling of 2.0 to cancel.
dct[..., 0] *= np.sqrt(1.0 / dct_size)
dct[..., 1:] *= np.sqrt(2.0 / dct_size)
else:
dct *= 2.0
return dct
def test_compare_to_numpy(self):
"""Compare dct against a manual DCT-II implementation."""
with spectral_ops_test_util.fft_kernel_label_map():
with self.test_session(use_gpu=True):
for size in range(1, 23):
signals = np.random.rand(size).astype(np.float32)
actual_dct = mfcc_ops._dct2_1d(signals).eval()
expected_dct = self._np_dct2(signals)
self.assertAllClose(expected_dct, actual_dct, atol=5e-4, rtol=5e-4)
def test_compare_to_fftpack(self):
"""Compare dct against scipy.fftpack.dct."""
if not fftpack:
return
with spectral_ops_test_util.fft_kernel_label_map():
with self.test_session(use_gpu=True):
for size in range(1, 23):
signal = np.random.rand(size).astype(np.float32)
actual_dct = mfcc_ops._dct2_1d(signal).eval()
expected_dct = fftpack.dct(signal, type=2)
self.assertAllClose(expected_dct, actual_dct, atol=5e-4, rtol=5e-4)
# TODO(rjryan): We have no open source tests for MFCCs at the moment. Internally # TODO(rjryan): We have no open source tests for MFCCs at the moment. Internally

View File

@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import math
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
@ -27,35 +25,6 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import spectral_ops from tensorflow.python.ops import spectral_ops
# TODO(rjryan): Remove once tf.spectral.dct exists.
def _dct2_1d(signals, name=None):
"""Computes the type II 1D Discrete Cosine Transform (DCT) of `signals`.
Args:
signals: A `[..., samples]` `float32` `Tensor` containing the signals to
take the DCT of.
name: An optional name for the operation.
Returns:
A `[..., samples]` `float32` `Tensor` containing the DCT of `signals`.
"""
with ops.name_scope(name, 'dct', [signals]):
# We use the FFT to compute the DCT and TensorFlow only supports float32 for
# FFTs at the moment.
signals = ops.convert_to_tensor(signals, dtype=dtypes.float32)
axis_dim = signals.shape[-1].value or array_ops.shape(signals)[-1]
axis_dim_float = math_ops.to_float(axis_dim)
scale = 2.0 * math_ops.exp(math_ops.complex(
0.0, -math.pi * math_ops.range(axis_dim_float) /
(2.0 * axis_dim_float)))
rfft = spectral_ops.rfft(signals, fft_length=[2 * axis_dim])[..., :axis_dim]
dct2 = math_ops.real(rfft * scale)
return dct2
def mfccs_from_log_mel_spectrograms(log_mel_spectrograms, name=None): def mfccs_from_log_mel_spectrograms(log_mel_spectrograms, name=None):
"""Computes [MFCCs][mfcc] of `log_mel_spectrograms`. """Computes [MFCCs][mfcc] of `log_mel_spectrograms`.
@ -134,4 +103,6 @@ def mfccs_from_log_mel_spectrograms(log_mel_spectrograms, name=None):
log_mel_spectrograms) log_mel_spectrograms)
else: else:
num_mel_bins = array_ops.shape(log_mel_spectrograms)[-1] num_mel_bins = array_ops.shape(log_mel_spectrograms)[-1]
return _dct2_1d(log_mel_spectrograms) * math_ops.rsqrt(num_mel_bins * 2.0)
dct2 = spectral_ops.dct(log_mel_spectrograms)
return dct2 * math_ops.rsqrt(num_mel_bins * 2.0)