Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45164
This PR implements `fft2`, `ifft2`, `rfft2` and `irfft2`. These are the last functions required for `torch.fft` to match `numpy.fft`. If you look at either NumPy or SciPy you'll see that the 2-dimensional variants are identical to `*fftn` in every way, except for the default value of `axes`. In fact you can even use `fft2` to do general n-dimensional transforms.
Test Plan: Imported from OSS
Reviewed By: ngimel
Differential Revision: D24363639
Pulled By: mruberry
fbshipit-source-id: 95191b51a0f0b8e8e301b2c20672ed4304d02a57
Summary:
Ref https://github.com/pytorch/pytorch/issues/42175
As already noted in the `torch.fft` `gradcheck` tests, `gradcheck` isn't fully working for complex types yet and the function inputs need to be real. A similar workaround for `gradgradcheck` works, viewing the complex outputs as real before returning them makes `gradgradcheck` pass.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46004
Reviewed By: ngimel
Differential Revision: D24187000
Pulled By: mruberry
fbshipit-source-id: 33c2986b07bac282dff1bd4f2109beb70e47bf79
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44550
Part of the `torch.fft` work (gh-42175).
This adds n-dimensional transforms: `fftn`, `ifftn`, `rfftn` and `irfftn`.
This is aiming for correctness first, with the implementation on top of the existing `_fft_with_size` restrictions. I plan to follow up later with a more efficient rewrite that makes `_fft_with_size` work with arbitrary numbers of dimensions.
Test Plan: Imported from OSS
Reviewed By: ngimel
Differential Revision: D23846032
Pulled By: mruberry
fbshipit-source-id: e6950aa8be438ec5cb95fb10bd7b8bc9ffb7d824
Summary:
Ref https://github.com/pytorch/pytorch/issues/42175, fixes https://github.com/pytorch/pytorch/issues/34797
This adds complex support to `torch.stft` and `torch.istft`. Note that there are really two issues with complex here: complex signals, and returning complex tensors.
## Complex signals and windows
`stft` currently assumes all signals are real and uses `rfft` with `onesided=True` by default. Similarly, `istft` always takes a complex fourier series and uses `irfft` to return real signals.
For `stft`, I now allow complex inputs and windows by calling the full `fft` if either are complex. If the user gives `onesided=True` and the signal is complex, then this doesn't work and raises an error instead. For `istft`, there's no way to automatically know what to do when `onesided=False` because that could either be a redundant representation of a real signal or a complex signal. So there, the user needs to pass the argument `return_complex=True` in order to use `ifft` and get a complex result back.
## stft returning complex tensors
The other issue is that `stft` returns a complex result, represented as a `(... X 2)` real tensor. I think ideally we want this to return proper complex tensors but to preserver BC I've had to add a `return_complex` argument to manage this transition. `return_complex` defaults to false for real inputs to preserve BC but defaults to True for complex inputs where there is no BC to consider.
In order to `return_complex` by default everywhere without a sudden BC-breaking change, a simple transition plan could be:
1. introduce `return_complex`, defaulted to false when BC is an issue but giving a warning. (this PR)
2. raise an error in cases where `return_complex` defaults to false, making it a required argument.
3. change `return_complex` default to true in all cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43886
Reviewed By: glaringlee
Differential Revision: D23760174
Pulled By: mruberry
fbshipit-source-id: 2fec4404f5d980ddd6bdd941a63852a555eb9147
Summary:
Per title. ROCm CI doesn't have MKL so this adds a couple missing test annotations.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42701
Reviewed By: ngimel
Differential Revision: D22986273
Pulled By: mruberry
fbshipit-source-id: efa717e2e3771562e9e82d1f914e251918e96f64
Summary:
This PR creates a new namespace, torch.fft (torch::fft) and puts a single function, fft, in it. This function is analogous to is a simplified version of NumPy's [numpy.fft.fft](https://numpy.org/doc/1.18/reference/generated/numpy.fft.fft.html?highlight=fft#numpy.fft.fft) that accepts no optional arguments. It is intended to demonstrate how to add and document functions in the namespace, and is not intended to deprecate the existing torch.fft function.
Adding this namespace was complicated by the existence of the torch.fft function in Python. Creating a torch.fft Python module makes this name ambiguous: does it refer to a function or module? If the JIT didn't exist, a solution to this problem would have been to make torch.fft refer to a callable class that mimicked both the function and module. The JIT, however, cannot understand this pattern. As a workaround it's required to explicitly `import torch.fft` to access the torch.fft.fft function in Python:
```
import torch.fft
t = torch.randn(128, dtype=torch.cdouble)
torch.fft.fft(t)
```
See https://github.com/pytorch/pytorch/issues/42175 for future work. Another possible future PR is to get the JIT to understand torch.fft as a callable class so it need not be imported explicitly to be used.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41911
Reviewed By: glaringlee
Differential Revision: D22941894
Pulled By: mruberry
fbshipit-source-id: c8e0b44cbe90d21e998ca3832cf3a533f28dbe8d
Summary:
In preparation for creating the new torch.fft namespace and NumPy-like fft functions, as well as supporting our goal of refactoring and reducing the size of test_torch.py, this PR creates a test suite for our spectral ops.
The existing spectral op tests from test_torch.py and test_cuda.py are moved to test_spectral_ops.py and updated to run under the device generic test framework.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42157
Reviewed By: albanD
Differential Revision: D22811096
Pulled By: mruberry
fbshipit-source-id: e5c50f0016ea6bb8b093cd6df2dbcef6db9bb6b6