mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
fft: Fix invalid shape error for complex-to-real transforms (#73012)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/72910
`last_dim_size` is the expected output size for the
Hermitian-compressed dimension and must be > 0. The confusingly named
`ld` represents the input's last dim size which is calculated as
`last_dim_size / 2 + 1` so could never be 0.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73012
Reviewed By: ngimel
Differential Revision: D34387147
Pulled By: mruberry
fbshipit-source-id: 6b410088efe2a9e117a5c6d8beefda370363dbb0
(cherry picked from commit f8d771ed36)
This commit is contained in:
parent
16e2f5d291
commit
9ea6db4aca
|
|
@ -446,6 +446,7 @@ ShapeAndDims canonicalize_fft_c2r_shape_and_dim_args(
|
||||||
auto desc = canonicalize_fft_shape_and_dim_args(self, s, dims);
|
auto desc = canonicalize_fft_shape_and_dim_args(self, s, dims);
|
||||||
TORCH_CHECK(desc.shape.size() > 0, fname, " must transform at least one axis");
|
TORCH_CHECK(desc.shape.size() > 0, fname, " must transform at least one axis");
|
||||||
|
|
||||||
|
// Expected output size of the hermitian-symmetric dimension
|
||||||
last_dim_size = [&] {
|
last_dim_size = [&] {
|
||||||
// Fixup default shape handling in the last dimension,
|
// Fixup default shape handling in the last dimension,
|
||||||
if (!s.has_value() || (s->back() == -1)) {
|
if (!s.has_value() || (s->back() == -1)) {
|
||||||
|
|
@ -454,9 +455,10 @@ ShapeAndDims canonicalize_fft_c2r_shape_and_dim_args(
|
||||||
}
|
}
|
||||||
return desc.shape.back();
|
return desc.shape.back();
|
||||||
}();
|
}();
|
||||||
auto ld = last_dim_size / 2 + 1;
|
TORCH_CHECK(last_dim_size >= 1, "Invalid number of data points (", last_dim_size, ") specified");
|
||||||
desc.shape.back() = ld;
|
|
||||||
TORCH_CHECK(ld >= 1, "Invalid number of data points (", last_dim_size, ") specified");
|
// Expected input size of the complex-hermitian data
|
||||||
|
desc.shape.back() = last_dim_size / 2 + 1;
|
||||||
return desc;
|
return desc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -296,6 +296,16 @@ class TestFFT(TestCase):
|
||||||
with self.assertRaisesRegex(RuntimeError, match):
|
with self.assertRaisesRegex(RuntimeError, match):
|
||||||
op(t)
|
op(t)
|
||||||
|
|
||||||
|
@onlyNativeDeviceTypes
|
||||||
|
def test_empty_ifft(self, device):
|
||||||
|
t = torch.empty(2, 1, device=device, dtype=torch.complex64)
|
||||||
|
match = r"Invalid number of data points \([-\d]*\) specified"
|
||||||
|
|
||||||
|
for f in [torch.fft.irfft, torch.fft.irfft2, torch.fft.irfftn,
|
||||||
|
torch.fft.hfft, torch.fft.hfft2, torch.fft.hfftn]:
|
||||||
|
with self.assertRaisesRegex(RuntimeError, match):
|
||||||
|
f(t)
|
||||||
|
|
||||||
@onlyNativeDeviceTypes
|
@onlyNativeDeviceTypes
|
||||||
def test_fft_invalid_dtypes(self, device):
|
def test_fft_invalid_dtypes(self, device):
|
||||||
t = torch.randn(64, device=device, dtype=torch.complex128)
|
t = torch.randn(64, device=device, dtype=torch.complex128)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user