mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix FFT documentation examples and run doctests in the test suite (#60304)
Summary: Fixes https://github.com/pytorch/pytorch/issues/59514 Pull Request resolved: https://github.com/pytorch/pytorch/pull/60304 Reviewed By: anjali411 Differential Revision: D29253980 Pulled By: mruberry fbshipit-source-id: 0654f00197e5fae338aa8edf0b61ef5692cdaa7e
This commit is contained in:
parent
5921b5480a
commit
5d476f5b95
|
|
@ -4,6 +4,8 @@ import math
|
|||
from contextlib import contextmanager
|
||||
from itertools import product
|
||||
import itertools
|
||||
import doctest
|
||||
import inspect
|
||||
|
||||
from torch.testing._internal.common_utils import \
|
||||
(TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL)
|
||||
|
|
@ -1218,7 +1220,55 @@ class TestFFT(TestCase):
|
|||
torch.istft(x.to(device), n_fft=100, window=window)
|
||||
|
||||
|
||||
class FFTDocTestFinder:
|
||||
'''The default doctest finder doesn't like that function.__module__ doesn't
|
||||
match torch.fft. It assumes the functions are leaked imports.
|
||||
'''
|
||||
def __init__(self):
|
||||
self.parser = doctest.DocTestParser()
|
||||
|
||||
def find(self, obj, name=None, module=None, globs=None, extraglobs=None):
|
||||
doctests = []
|
||||
|
||||
modname = name if name is not None else obj.__name__
|
||||
globs = dict() if globs is None else globs
|
||||
|
||||
for fname in obj.__all__:
|
||||
func = getattr(obj, fname)
|
||||
if inspect.isroutine(func):
|
||||
qualname = modname + '.' + fname
|
||||
docstring = inspect.getdoc(func)
|
||||
if docstring is None:
|
||||
continue
|
||||
|
||||
examples = self.parser.get_doctest(
|
||||
docstring, globs=globs, name=fname, filename=None, lineno=None)
|
||||
doctests.append(examples)
|
||||
|
||||
return doctests
|
||||
|
||||
|
||||
class TestFFTDocExamples(TestCase):
|
||||
pass
|
||||
|
||||
def generate_doc_test(doc_test):
|
||||
def test(self, device):
|
||||
self.assertEqual(device, 'cpu')
|
||||
runner = doctest.DocTestRunner()
|
||||
runner.run(doc_test)
|
||||
|
||||
if runner.failures != 0:
|
||||
runner.summarize()
|
||||
self.fail('Doctest failed')
|
||||
|
||||
setattr(TestFFTDocExamples, 'test_' + doc_test.name, skipCPUIfNoMkl(test))
|
||||
|
||||
for doc_test in FFTDocTestFinder().find(torch.fft, globs=dict(torch=torch)):
|
||||
generate_doc_test(doc_test)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestFFT, globals())
|
||||
instantiate_device_type_tests(TestFFTDocExamples, globals(), only_for='cpu')
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ Example:
|
|||
>>> torch.fft.fft(t)
|
||||
tensor([ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j])
|
||||
|
||||
>>> t = tensor([0.+1.j, 2.+3.j, 4.+5.j, 6.+7.j])
|
||||
>>> t = torch.tensor([0.+1.j, 2.+3.j, 4.+5.j, 6.+7.j])
|
||||
>>> torch.fft.fft(t)
|
||||
tensor([12.+16.j, -8.+0.j, -4.-4.j, 0.-8.j])
|
||||
""".format(**common_args))
|
||||
|
|
@ -147,7 +147,7 @@ Example:
|
|||
here is equivalent to two one-dimensional :func:`~torch.fft.fft` calls:
|
||||
|
||||
>>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1)
|
||||
>>> torch.allclose(fft2, two_ffts)
|
||||
>>> torch.testing.assert_close(fft2, two_ffts, check_stride=False)
|
||||
|
||||
""".format(**common_args))
|
||||
|
||||
|
|
@ -193,7 +193,7 @@ Example:
|
|||
here is equivalent to two one-dimensional :func:`~torch.fft.ifft` calls:
|
||||
|
||||
>>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1)
|
||||
>>> torch.allclose(ifft2, two_iffts)
|
||||
>>> torch.testing.assert_close(ifft2, two_iffts, check_stride=False)
|
||||
|
||||
""".format(**common_args))
|
||||
|
||||
|
|
@ -247,7 +247,7 @@ Example:
|
|||
here is equivalent to two one-dimensional :func:`~torch.fft.fft` calls:
|
||||
|
||||
>>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1)
|
||||
>>> torch.allclose(fftn, two_ffts)
|
||||
>>> torch.testing.assert_close(fftn, two_ffts, check_stride=False)
|
||||
|
||||
""".format(**common_args))
|
||||
|
||||
|
|
@ -292,7 +292,7 @@ Example:
|
|||
here is equivalent to two one-dimensional :func:`~torch.fft.ifft` calls:
|
||||
|
||||
>>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1)
|
||||
>>> torch.allclose(ifftn, two_iffts)
|
||||
>>> torch.testing.assert_close(ifftn, two_iffts, check_stride=False)
|
||||
|
||||
""".format(**common_args))
|
||||
|
||||
|
|
@ -393,23 +393,24 @@ Keyword args:
|
|||
|
||||
Example:
|
||||
|
||||
>>> t = torch.arange(5)
|
||||
>>> t = torch.linspace(0, 1, 5)
|
||||
>>> t
|
||||
tensor([0, 1, 2, 3, 4])
|
||||
tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
|
||||
>>> T = torch.fft.rfft(t)
|
||||
>>> T
|
||||
tensor([10.0000+0.0000j, -2.5000+3.4410j, -2.5000+0.8123j])
|
||||
tensor([ 2.5000+0.0000j, -0.6250+0.8602j, -0.6250+0.2031j])
|
||||
|
||||
Without specifying the output length to :func:`~torch.fft.irfft`, the output
|
||||
will not round-trip properly because the input is odd-length:
|
||||
|
||||
>>> torch.fft.irfft(T)
|
||||
tensor([0.6250, 1.4045, 3.1250, 4.8455])
|
||||
tensor([0.1562, 0.3511, 0.7812, 1.2114])
|
||||
|
||||
So, it is recommended to always pass the signal length :attr:`n`:
|
||||
|
||||
>>> torch.fft.irfft(T, t.numel())
|
||||
tensor([0.0000, 1.0000, 2.0000, 3.0000, 4.0000])
|
||||
>>> roundtrip = torch.fft.irfft(T, t.numel())
|
||||
>>> torch.testing.assert_close(roundtrip, t, check_stride=False)
|
||||
|
||||
""".format(**common_args))
|
||||
|
||||
rfft2 = _add_docstr(_fft.fft_rfft2, r"""
|
||||
|
|
@ -461,15 +462,14 @@ Example:
|
|||
elements up to the Nyquist frequency.
|
||||
|
||||
>>> fft2 = torch.fft.fft2(t)
|
||||
>>> torch.allclose(fft2[..., :6], rfft2)
|
||||
True
|
||||
>>> torch.testing.assert_close(fft2[..., :6], rfft2, check_stride=False)
|
||||
|
||||
The discrete Fourier transform is separable, so :func:`~torch.fft.rfft2`
|
||||
here is equivalent to a combination of :func:`~torch.fft.fft` and
|
||||
:func:`~torch.fft.rfft`:
|
||||
|
||||
>>> two_ffts = torch.fft.fft(torch.fft.rfft(t, dim=1), dim=0)
|
||||
>>> torch.allclose(rfft2, two_ffts)
|
||||
>>> torch.testing.assert_close(rfft2, two_ffts, check_stride=False)
|
||||
|
||||
""".format(**common_args))
|
||||
|
||||
|
|
@ -535,15 +535,14 @@ Example:
|
|||
dimension:
|
||||
|
||||
>>> torch.fft.irfft2(T).size()
|
||||
torch.Size([10, 10])
|
||||
torch.Size([10, 8])
|
||||
|
||||
So, it is recommended to always pass the signal shape :attr:`s`.
|
||||
|
||||
>>> roundtrip = torch.fft.irfft2(T, t.size())
|
||||
>>> roundtrip.size()
|
||||
torch.Size([10, 9])
|
||||
>>> torch.allclose(roundtrip, t)
|
||||
True
|
||||
>>> torch.testing.assert_close(roundtrip, t, check_stride=False)
|
||||
|
||||
""".format(**common_args))
|
||||
|
||||
|
|
@ -596,15 +595,14 @@ Example:
|
|||
elements up to the Nyquist frequency.
|
||||
|
||||
>>> fftn = torch.fft.fftn(t)
|
||||
>>> torch.allclose(fftn[..., :6], rfftn)
|
||||
True
|
||||
>>> torch.testing.assert_close(fftn[..., :6], rfftn, check_stride=False)
|
||||
|
||||
The discrete Fourier transform is separable, so :func:`~torch.fft.rfftn`
|
||||
here is equivalent to a combination of :func:`~torch.fft.fft` and
|
||||
:func:`~torch.fft.rfft`:
|
||||
|
||||
>>> two_ffts = torch.fft.fft(torch.fft.rfft(t, dim=1), dim=0)
|
||||
>>> torch.allclose(rfftn, two_ffts)
|
||||
>>> torch.testing.assert_close(rfftn, two_ffts, check_stride=False)
|
||||
|
||||
""".format(**common_args))
|
||||
|
||||
|
|
@ -669,15 +667,14 @@ Example:
|
|||
dimension:
|
||||
|
||||
>>> torch.fft.irfftn(T).size()
|
||||
torch.Size([10, 10])
|
||||
torch.Size([10, 8])
|
||||
|
||||
So, it is recommended to always pass the signal shape :attr:`s`.
|
||||
|
||||
>>> roundtrip = torch.fft.irfftn(T, t.size())
|
||||
>>> roundtrip.size()
|
||||
torch.Size([10, 9])
|
||||
>>> torch.allclose(roundtrip, t)
|
||||
True
|
||||
>>> torch.testing.assert_close(roundtrip, t, check_stride=False)
|
||||
|
||||
""".format(**common_args))
|
||||
|
||||
|
|
@ -741,26 +738,26 @@ Example:
|
|||
Taking a real-valued frequency signal and bringing it into the time domain
|
||||
gives Hermitian symmetric output:
|
||||
|
||||
>>> t = torch.arange(5)
|
||||
>>> t = torch.linspace(0, 1, 5)
|
||||
>>> t
|
||||
tensor([0, 1, 2, 3, 4])
|
||||
tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
|
||||
>>> T = torch.fft.ifft(t)
|
||||
>>> T
|
||||
tensor([ 2.0000+-0.0000j, -0.5000-0.6882j, -0.5000-0.1625j, -0.5000+0.1625j,
|
||||
-0.5000+0.6882j])
|
||||
tensor([ 0.5000-0.0000j, -0.1250-0.1720j, -0.1250-0.0406j, -0.1250+0.0406j,
|
||||
-0.1250+0.1720j])
|
||||
|
||||
Note that ``T[1] == T[-1].conj()`` and ``T[2] == T[-2].conj()`` is
|
||||
redundant. We can thus compute the forward transform without considering
|
||||
negative frequencies:
|
||||
|
||||
>>> torch.fft.hfft(T[:3], n=5)
|
||||
tensor([0., 1., 2., 3., 4.])
|
||||
tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
|
||||
|
||||
Like with :func:`~torch.fft.irfft`, the output length must be given in order
|
||||
to recover an even length output:
|
||||
|
||||
>>> torch.fft.hfft(T[:3])
|
||||
tensor([0.5000, 1.1236, 2.5000, 3.8764])
|
||||
tensor([0.1250, 0.2809, 0.6250, 0.9691])
|
||||
""".format(**common_args))
|
||||
|
||||
ihfft = _add_docstr(_fft.fft_ihfft, r"""
|
||||
|
|
@ -802,13 +799,13 @@ Example:
|
|||
>>> t
|
||||
tensor([0, 1, 2, 3, 4])
|
||||
>>> torch.fft.ihfft(t)
|
||||
tensor([ 2.0000+-0.0000j, -0.5000-0.6882j, -0.5000-0.1625j])
|
||||
tensor([ 2.0000-0.0000j, -0.5000-0.6882j, -0.5000-0.1625j])
|
||||
|
||||
Compare against the full output from :func:`~torch.fft.ifft`:
|
||||
|
||||
>>> torch.fft.ifft(t)
|
||||
tensor([ 2.0000+-0.0000j, -0.5000-0.6882j, -0.5000-0.1625j, -0.5000+0.1625j,
|
||||
-0.5000+0.6882j])
|
||||
tensor([ 2.0000-0.0000j, -0.5000-0.6882j, -0.5000-0.1625j, -0.5000+0.1625j,
|
||||
-0.5000+0.6882j])
|
||||
""".format(**common_args))
|
||||
|
||||
fftfreq = _add_docstr(_fft.fft_fftfreq, r"""
|
||||
|
|
@ -891,10 +888,10 @@ Keyword Args:
|
|||
Example:
|
||||
|
||||
>>> torch.fft.rfftfreq(5)
|
||||
tensor([ 0.0000, 0.2000, 0.4000])
|
||||
tensor([0.0000, 0.2000, 0.4000])
|
||||
|
||||
>>> torch.fft.rfftfreq(4)
|
||||
tensor([ 0.0000, 0.2500, 0.5000])
|
||||
tensor([0.0000, 0.2500, 0.5000])
|
||||
|
||||
Compared to the output from :func:`~torch.fft.fftfreq`, we see that the
|
||||
Nyquist frequency at ``f[2]`` has changed sign:
|
||||
|
|
@ -980,8 +977,7 @@ Example:
|
|||
data, can be performed by applying the inverse shifts in reverse order:
|
||||
|
||||
>>> x_centered_2 = torch.fft.fftshift(torch.fft.ifft(torch.fft.ifftshift(fft_centered)))
|
||||
>>> torch.allclose(x_centered.to(torch.complex64), x_centered_2)
|
||||
True
|
||||
>>> torch.testing.assert_close(x_centered.to(torch.complex64), x_centered_2, check_stride=False)
|
||||
|
||||
|
||||
""")
|
||||
|
|
@ -1007,8 +1003,8 @@ Example:
|
|||
A round-trip through :func:`~torch.fft.fftshift` and
|
||||
:func:`~torch.fft.ifftshift` gives the same result:
|
||||
|
||||
>>> shifted = torch.fftshift(f)
|
||||
>>> torch.ifftshift(shifted)
|
||||
>>> shifted = torch.fft.fftshift(f)
|
||||
>>> torch.fft.ifftshift(shifted)
|
||||
tensor([ 0.0000, 0.2000, 0.4000, -0.4000, -0.2000])
|
||||
|
||||
""")
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user