Fix FFT documentation examples and run doctests in the test suite (#60304)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/59514

Pull Request resolved: https://github.com/pytorch/pytorch/pull/60304

Reviewed By: anjali411

Differential Revision: D29253980

Pulled By: mruberry

fbshipit-source-id: 0654f00197e5fae338aa8edf0b61ef5692cdaa7e
This commit is contained in:
Peter Bell 2021-06-21 20:45:03 -07:00 committed by Facebook GitHub Bot
parent 5921b5480a
commit 5d476f5b95
2 changed files with 84 additions and 38 deletions

View File

@ -4,6 +4,8 @@ import math
from contextlib import contextmanager
from itertools import product
import itertools
import doctest
import inspect
from torch.testing._internal.common_utils import \
(TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL)
@ -1218,7 +1220,55 @@ class TestFFT(TestCase):
torch.istft(x.to(device), n_fft=100, window=window)
class FFTDocTestFinder:
'''The default doctest finder doesn't like that function.__module__ doesn't
match torch.fft. It assumes the functions are leaked imports.
'''
def __init__(self):
self.parser = doctest.DocTestParser()
def find(self, obj, name=None, module=None, globs=None, extraglobs=None):
doctests = []
modname = name if name is not None else obj.__name__
globs = dict() if globs is None else globs
for fname in obj.__all__:
func = getattr(obj, fname)
if inspect.isroutine(func):
qualname = modname + '.' + fname
docstring = inspect.getdoc(func)
if docstring is None:
continue
examples = self.parser.get_doctest(
docstring, globs=globs, name=fname, filename=None, lineno=None)
doctests.append(examples)
return doctests
class TestFFTDocExamples(TestCase):
pass
def generate_doc_test(doc_test):
def test(self, device):
self.assertEqual(device, 'cpu')
runner = doctest.DocTestRunner()
runner.run(doc_test)
if runner.failures != 0:
runner.summarize()
self.fail('Doctest failed')
setattr(TestFFTDocExamples, 'test_' + doc_test.name, skipCPUIfNoMkl(test))
for doc_test in FFTDocTestFinder().find(torch.fft, globs=dict(torch=torch)):
generate_doc_test(doc_test)
instantiate_device_type_tests(TestFFT, globals())
instantiate_device_type_tests(TestFFTDocExamples, globals(), only_for='cpu')
if __name__ == '__main__':
run_tests()

View File

@ -58,7 +58,7 @@ Example:
>>> torch.fft.fft(t)
tensor([ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j])
>>> t = tensor([0.+1.j, 2.+3.j, 4.+5.j, 6.+7.j])
>>> t = torch.tensor([0.+1.j, 2.+3.j, 4.+5.j, 6.+7.j])
>>> torch.fft.fft(t)
tensor([12.+16.j, -8.+0.j, -4.-4.j, 0.-8.j])
""".format(**common_args))
@ -147,7 +147,7 @@ Example:
here is equivalent to two one-dimensional :func:`~torch.fft.fft` calls:
>>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1)
>>> torch.allclose(fft2, two_ffts)
>>> torch.testing.assert_close(fft2, two_ffts, check_stride=False)
""".format(**common_args))
@ -193,7 +193,7 @@ Example:
here is equivalent to two one-dimensional :func:`~torch.fft.ifft` calls:
>>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1)
>>> torch.allclose(ifft2, two_iffts)
>>> torch.testing.assert_close(ifft2, two_iffts, check_stride=False)
""".format(**common_args))
@ -247,7 +247,7 @@ Example:
here is equivalent to two one-dimensional :func:`~torch.fft.fft` calls:
>>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1)
>>> torch.allclose(fftn, two_ffts)
>>> torch.testing.assert_close(fftn, two_ffts, check_stride=False)
""".format(**common_args))
@ -292,7 +292,7 @@ Example:
here is equivalent to two one-dimensional :func:`~torch.fft.ifft` calls:
>>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1)
>>> torch.allclose(ifftn, two_iffts)
>>> torch.testing.assert_close(ifftn, two_iffts, check_stride=False)
""".format(**common_args))
@ -393,23 +393,24 @@ Keyword args:
Example:
>>> t = torch.arange(5)
>>> t = torch.linspace(0, 1, 5)
>>> t
tensor([0, 1, 2, 3, 4])
tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
>>> T = torch.fft.rfft(t)
>>> T
tensor([10.0000+0.0000j, -2.5000+3.4410j, -2.5000+0.8123j])
tensor([ 2.5000+0.0000j, -0.6250+0.8602j, -0.6250+0.2031j])
Without specifying the output length to :func:`~torch.fft.irfft`, the output
will not round-trip properly because the input is odd-length:
>>> torch.fft.irfft(T)
tensor([0.6250, 1.4045, 3.1250, 4.8455])
tensor([0.1562, 0.3511, 0.7812, 1.2114])
So, it is recommended to always pass the signal length :attr:`n`:
>>> torch.fft.irfft(T, t.numel())
tensor([0.0000, 1.0000, 2.0000, 3.0000, 4.0000])
>>> roundtrip = torch.fft.irfft(T, t.numel())
>>> torch.testing.assert_close(roundtrip, t, check_stride=False)
""".format(**common_args))
rfft2 = _add_docstr(_fft.fft_rfft2, r"""
@ -461,15 +462,14 @@ Example:
elements up to the Nyquist frequency.
>>> fft2 = torch.fft.fft2(t)
>>> torch.allclose(fft2[..., :6], rfft2)
True
>>> torch.testing.assert_close(fft2[..., :6], rfft2, check_stride=False)
The discrete Fourier transform is separable, so :func:`~torch.fft.rfft2`
here is equivalent to a combination of :func:`~torch.fft.fft` and
:func:`~torch.fft.rfft`:
>>> two_ffts = torch.fft.fft(torch.fft.rfft(t, dim=1), dim=0)
>>> torch.allclose(rfft2, two_ffts)
>>> torch.testing.assert_close(rfft2, two_ffts, check_stride=False)
""".format(**common_args))
@ -535,15 +535,14 @@ Example:
dimension:
>>> torch.fft.irfft2(T).size()
torch.Size([10, 10])
torch.Size([10, 8])
So, it is recommended to always pass the signal shape :attr:`s`.
>>> roundtrip = torch.fft.irfft2(T, t.size())
>>> roundtrip.size()
torch.Size([10, 9])
>>> torch.allclose(roundtrip, t)
True
>>> torch.testing.assert_close(roundtrip, t, check_stride=False)
""".format(**common_args))
@ -596,15 +595,14 @@ Example:
elements up to the Nyquist frequency.
>>> fftn = torch.fft.fftn(t)
>>> torch.allclose(fftn[..., :6], rfftn)
True
>>> torch.testing.assert_close(fftn[..., :6], rfftn, check_stride=False)
The discrete Fourier transform is separable, so :func:`~torch.fft.rfftn`
here is equivalent to a combination of :func:`~torch.fft.fft` and
:func:`~torch.fft.rfft`:
>>> two_ffts = torch.fft.fft(torch.fft.rfft(t, dim=1), dim=0)
>>> torch.allclose(rfftn, two_ffts)
>>> torch.testing.assert_close(rfftn, two_ffts, check_stride=False)
""".format(**common_args))
@ -669,15 +667,14 @@ Example:
dimension:
>>> torch.fft.irfftn(T).size()
torch.Size([10, 10])
torch.Size([10, 8])
So, it is recommended to always pass the signal shape :attr:`s`.
>>> roundtrip = torch.fft.irfftn(T, t.size())
>>> roundtrip.size()
torch.Size([10, 9])
>>> torch.allclose(roundtrip, t)
True
>>> torch.testing.assert_close(roundtrip, t, check_stride=False)
""".format(**common_args))
@ -741,26 +738,26 @@ Example:
Taking a real-valued frequency signal and bringing it into the time domain
gives Hermitian symmetric output:
>>> t = torch.arange(5)
>>> t = torch.linspace(0, 1, 5)
>>> t
tensor([0, 1, 2, 3, 4])
tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
>>> T = torch.fft.ifft(t)
>>> T
tensor([ 2.0000+-0.0000j, -0.5000-0.6882j, -0.5000-0.1625j, -0.5000+0.1625j,
-0.5000+0.6882j])
tensor([ 0.5000-0.0000j, -0.1250-0.1720j, -0.1250-0.0406j, -0.1250+0.0406j,
-0.1250+0.1720j])
Note that ``T[1] == T[-1].conj()`` and ``T[2] == T[-2].conj()`` is
redundant. We can thus compute the forward transform without considering
negative frequencies:
>>> torch.fft.hfft(T[:3], n=5)
tensor([0., 1., 2., 3., 4.])
tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
Like with :func:`~torch.fft.irfft`, the output length must be given in order
to recover an even length output:
>>> torch.fft.hfft(T[:3])
tensor([0.5000, 1.1236, 2.5000, 3.8764])
tensor([0.1250, 0.2809, 0.6250, 0.9691])
""".format(**common_args))
ihfft = _add_docstr(_fft.fft_ihfft, r"""
@ -802,13 +799,13 @@ Example:
>>> t
tensor([0, 1, 2, 3, 4])
>>> torch.fft.ihfft(t)
tensor([ 2.0000+-0.0000j, -0.5000-0.6882j, -0.5000-0.1625j])
tensor([ 2.0000-0.0000j, -0.5000-0.6882j, -0.5000-0.1625j])
Compare against the full output from :func:`~torch.fft.ifft`:
>>> torch.fft.ifft(t)
tensor([ 2.0000+-0.0000j, -0.5000-0.6882j, -0.5000-0.1625j, -0.5000+0.1625j,
-0.5000+0.6882j])
tensor([ 2.0000-0.0000j, -0.5000-0.6882j, -0.5000-0.1625j, -0.5000+0.1625j,
-0.5000+0.6882j])
""".format(**common_args))
fftfreq = _add_docstr(_fft.fft_fftfreq, r"""
@ -891,10 +888,10 @@ Keyword Args:
Example:
>>> torch.fft.rfftfreq(5)
tensor([ 0.0000, 0.2000, 0.4000])
tensor([0.0000, 0.2000, 0.4000])
>>> torch.fft.rfftfreq(4)
tensor([ 0.0000, 0.2500, 0.5000])
tensor([0.0000, 0.2500, 0.5000])
Compared to the output from :func:`~torch.fft.fftfreq`, we see that the
Nyquist frequency at ``f[2]`` has changed sign:
@ -980,8 +977,7 @@ Example:
data, can be performed by applying the inverse shifts in reverse order:
>>> x_centered_2 = torch.fft.fftshift(torch.fft.ifft(torch.fft.ifftshift(fft_centered)))
>>> torch.allclose(x_centered.to(torch.complex64), x_centered_2)
True
>>> torch.testing.assert_close(x_centered.to(torch.complex64), x_centered_2, check_stride=False)
""")
@ -1007,8 +1003,8 @@ Example:
A round-trip through :func:`~torch.fft.fftshift` and
:func:`~torch.fft.ifftshift` gives the same result:
>>> shifted = torch.fftshift(f)
>>> torch.ifftshift(shifted)
>>> shifted = torch.fft.fftshift(f)
>>> torch.fft.ifftshift(shifted)
tensor([ 0.0000, 0.2000, 0.4000, -0.4000, -0.2000])
""")