fft: Fix invalid shape error for complex-to-real transforms (#73012)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/72910

`last_dim_size` is the expected output size for the
Hermitian-compressed dimension and must be > 0. The confusingly named
`ld` represents the input's last dim size which is calculated as
`last_dim_size / 2 + 1` so could never be 0.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73012

Reviewed By: ngimel

Differential Revision: D34387147

Pulled By: mruberry

fbshipit-source-id: 6b410088efe2a9e117a5c6d8beefda370363dbb0
(cherry picked from commit f8d771ed36)
This commit is contained in:
Peter Bell 2022-02-22 19:05:54 -08:00 committed by PyTorch MergeBot
parent 16e2f5d291
commit 9ea6db4aca
2 changed files with 15 additions and 3 deletions

View File

@ -446,6 +446,7 @@ ShapeAndDims canonicalize_fft_c2r_shape_and_dim_args(
auto desc = canonicalize_fft_shape_and_dim_args(self, s, dims);
TORCH_CHECK(desc.shape.size() > 0, fname, " must transform at least one axis");
// Expected output size of the hermitian-symmetric dimension
last_dim_size = [&] {
// Fixup default shape handling in the last dimension,
if (!s.has_value() || (s->back() == -1)) {
@ -454,9 +455,10 @@ ShapeAndDims canonicalize_fft_c2r_shape_and_dim_args(
}
return desc.shape.back();
}();
auto ld = last_dim_size / 2 + 1;
desc.shape.back() = ld;
TORCH_CHECK(ld >= 1, "Invalid number of data points (", last_dim_size, ") specified");
TORCH_CHECK(last_dim_size >= 1, "Invalid number of data points (", last_dim_size, ") specified");
// Expected input size of the complex-hermitian data
desc.shape.back() = last_dim_size / 2 + 1;
return desc;
}

View File

@ -296,6 +296,16 @@ class TestFFT(TestCase):
with self.assertRaisesRegex(RuntimeError, match):
op(t)
@onlyNativeDeviceTypes
def test_empty_ifft(self, device):
t = torch.empty(2, 1, device=device, dtype=torch.complex64)
match = r"Invalid number of data points \([-\d]*\) specified"
for f in [torch.fft.irfft, torch.fft.irfft2, torch.fft.irfftn,
torch.fft.hfft, torch.fft.hfft2, torch.fft.hfftn]:
with self.assertRaisesRegex(RuntimeError, match):
f(t)
@onlyNativeDeviceTypes
def test_fft_invalid_dtypes(self, device):
t = torch.randn(64, device=device, dtype=torch.complex128)