[complex32] fft support (cuda only) (#74857)

`half` and `complex32` support for `torch.fft.{fft, fft2, fftn, hfft, hfft2, hfftn, ifft, ifft2, ifftn, ihfft, ihfft2, ihfftn, irfft, irfft2, irfftn, rfft, rfft2, rfftn}`

* Support is added only for `CUDA`, since `cuFFT` supports these precisions.
* `CPU` and `ROCm` still raise an error, as their respective backends don't support these precisions.

For `cuFFT`, the following constraints apply to these precisions:
* Minimum GPU architecture is SM_53
* Sizes are restricted to powers of two only
* Strides on the real part of real-to-complex and complex-to-real transforms are not supported
* More than one GPU is not supported
* Transforms spanning more than 4 billion elements are not supported

Ref: https://docs.nvidia.com/cuda/cufft/#half-precision-transforms
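
For illustration (not part of this PR), a minimal usage sketch under the constraints above, assuming a CUDA device with compute capability SM_53 or newer:

```python
import torch

# Power-of-two signal length is required for half/complex32 inputs.
x = torch.randn(64, device="cuda", dtype=torch.half)

X = torch.fft.fft(x)          # half input promotes to a complex32 result
print(X.dtype)                # torch.complex32

y = torch.fft.ifft(X)         # round trip stays in complex32
print(y.dtype)                # torch.complex32

# A non power-of-two length (e.g. 63) raises a RuntimeError for these dtypes.
```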

TODO:
* [x] Update docs about the restrictions
* [x] Figure out the correct way to detect a `hip` device (it seems `device.is_cuda()` is also true for HIP). (Thanks @peterbell10)

Ref for the second TODO item: e424e7d214/aten/src/ATen/native/SpectralOps.cpp (L31)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74857
Approved by: https://github.com/anjali411, https://github.com/peterbell10
Author: kshitij12345
Date: 2022-05-12 04:28:55 +00:00
Committed by: PyTorch MergeBot
Parent: b825e1d472
Commit: ada65fdd67

8 changed files with 351 additions and 55 deletions


@@ -139,6 +139,14 @@ bool CUDAHooks::hasCuSOLVER() const {
 #endif
 }
 
+bool CUDAHooks::hasROCM() const {
+  // Currently, this is same as `compiledWithMIOpen`.
+  // But in future if there are ROCm builds without MIOpen,
+  // then `hasROCM` should return true while `compiledWithMIOpen`
+  // should return false
+  return AT_ROCM_ENABLED();
+}
+
 #if defined(USE_DIRECT_NVRTC)
 static std::pair<std::unique_ptr<at::DynamicLibrary>, at::cuda::NVRTC*> load_nvrtc() {
   return std::make_pair(nullptr, at::cuda::load_nvrtc());


@@ -27,6 +27,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
   bool hasMAGMA() const override;
   bool hasCuDNN() const override;
   bool hasCuSOLVER() const override;
+  bool hasROCM() const override;
   const at::cuda::NVRTC& nvrtc() const override;
   int64_t current_device() const override;
   bool hasPrimaryContext(int64_t device_index) const override;


@@ -107,6 +107,10 @@ struct TORCH_API CUDAHooksInterface {
     return false;
   }
 
+  virtual bool hasROCM() const {
+    return false;
+  }
+
   virtual const at::cuda::NVRTC& nvrtc() const {
     TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP);
   }


@@ -19,7 +19,7 @@ namespace {
 // * Integers are promoted to the default floating type
 // * If require_complex=True, all types are promoted to complex
 // * Raises an error for half-precision dtypes to allow future support
-ScalarType promote_type_fft(ScalarType type, bool require_complex) {
+ScalarType promote_type_fft(ScalarType type, bool require_complex, Device device) {
   if (at::isComplexType(type)) {
     return type;
   }
@@ -28,7 +28,11 @@ ScalarType promote_type_fft(ScalarType type, bool require_complex) {
     type = c10::typeMetaToScalarType(c10::get_default_dtype());
   }
 
-  TORCH_CHECK(type == kFloat || type == kDouble, "Unsupported dtype ", type);
+  if (device.is_cuda() && !at::detail::getCUDAHooks().hasROCM()) {
+    TORCH_CHECK(type == kHalf || type == kFloat || type == kDouble, "Unsupported dtype ", type);
+  } else {
+    TORCH_CHECK(type == kFloat || type == kDouble, "Unsupported dtype ", type);
+  }
 
   if (!require_complex) {
     return type;
@@ -36,6 +40,7 @@ ScalarType promote_type_fft(ScalarType type, bool require_complex) {
 
   // Promote to complex
   switch (type) {
+  case kHalf: return kComplexHalf;
   case kFloat: return kComplexFloat;
   case kDouble: return kComplexDouble;
   default: TORCH_INTERNAL_ASSERT(false, "Unhandled dtype");
@@ -45,7 +50,7 @@ ScalarType promote_type_fft(ScalarType type, bool require_complex) {
 
 // Promote a tensor's dtype according to promote_type_fft
 Tensor promote_tensor_fft(const Tensor& t, bool require_complex=false) {
   auto cur_type = t.scalar_type();
-  auto new_type = promote_type_fft(cur_type, require_complex);
+  auto new_type = promote_type_fft(cur_type, require_complex, t.device());
   return (cur_type == new_type) ? t : t.to(new_type);
 }


@@ -106,17 +106,17 @@ void _fft_fill_with_conjugate_symmetry_cuda_(
       signal_half_sizes, out_strides, mirror_dims, element_size);
 
   const auto numel = c10::multiply_integers(signal_half_sizes);
-  AT_DISPATCH_COMPLEX_TYPES(dtype, "_fft_fill_with_conjugate_symmetry", [&] {
+  AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "_fft_fill_with_conjugate_symmetry", [&] {
     using namespace cuda::detail;
     _fft_conjugate_copy_kernel<<<
       GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
         numel,
         static_cast<scalar_t*>(out_data),
         static_cast<const scalar_t*>(in_data),
         input_offset_calculator,
         output_offset_calculator);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
   });
 }
 
 REGISTER_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cuda_);


@@ -10,12 +10,14 @@ import doctest
 import inspect
 from torch.testing._internal.common_utils import \
-    (TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL)
+    (TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL, first_sample, TEST_WITH_ROCM,
+     make_tensor)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, ops, dtypes, onlyNativeDeviceTypes,
-     skipCPUIfNoFFT, deviceCountAtLeast, onlyCUDA, OpDTypes, skipIf)
+     skipCPUIfNoFFT, deviceCountAtLeast, onlyCUDA, OpDTypes, skipIf, toleranceOverride, tol)
 from torch.testing._internal.common_methods_invocations import (
     spectral_funcs, SpectralFuncType)
+from torch.testing._internal.common_cuda import SM53OrLater
 from setuptools import distutils
 from typing import Optional, List

@@ -110,6 +112,20 @@ def _stft_reference(x, hop_length, window):
         X[:, m] = torch.fft.fft(slc * window)
     return X

+
+def skip_helper_for_fft(device, dtype):
+    device_type = torch.device(device).type
+    if dtype not in (torch.half, torch.complex32):
+        return
+
+    if device_type == 'cpu':
+        raise unittest.SkipTest("half and complex32 are not supported on CPU")
+    if TEST_WITH_ROCM:
+        raise unittest.SkipTest("half and complex32 are not supported on ROCM")
+    if not SM53OrLater:
+        raise unittest.SkipTest("half and complex32 are only supported on CUDA device with SM>53")
+
+
 # Tests of functions related to Fourier analysis in the torch.fft namespace
 class TestFFT(TestCase):
     exact_dtype = True

@@ -157,20 +173,39 @@
     @skipCPUIfNoFFT
     @onlyNativeDeviceTypes
-    @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
+    @toleranceOverride({
+        torch.half : tol(1e-2, 1e-2),
+        torch.chalf : tol(1e-2, 1e-2),
+    })
+    @dtypes(torch.half, torch.float, torch.double, torch.complex32, torch.complex64, torch.complex128)
     def test_fft_round_trip(self, device, dtype):
+        skip_helper_for_fft(device, dtype)
         # Test that round trip through ifft(fft(x)) is the identity
-        test_args = list(product(
-            # input
-            (torch.randn(67, device=device, dtype=dtype),
-             torch.randn(80, device=device, dtype=dtype),
-             torch.randn(12, 14, device=device, dtype=dtype),
-             torch.randn(9, 6, 3, device=device, dtype=dtype)),
-            # dim
-            (-1, 0),
-            # norm
-            (None, "forward", "backward", "ortho")
-        ))
+        if dtype not in (torch.half, torch.complex32):
+            test_args = list(product(
+                # input
+                (torch.randn(67, device=device, dtype=dtype),
+                 torch.randn(80, device=device, dtype=dtype),
+                 torch.randn(12, 14, device=device, dtype=dtype),
+                 torch.randn(9, 6, 3, device=device, dtype=dtype)),
+                # dim
+                (-1, 0),
+                # norm
+                (None, "forward", "backward", "ortho")
+            ))
+        else:
+            # cuFFT supports powers of 2 for half and complex half precision
+            test_args = list(product(
+                # input
+                (torch.randn(64, device=device, dtype=dtype),
+                 torch.randn(128, device=device, dtype=dtype),
+                 torch.randn(4, 16, device=device, dtype=dtype),
+                 torch.randn(8, 6, 2, device=device, dtype=dtype)),
+                # dim
+                (-1, 0),
+                # norm
+                (None, "forward", "backward", "ortho")
+            ))

         fft_functions = [(torch.fft.fft, torch.fft.ifft)]

         # Real-only functions

@@ -189,13 +224,17 @@
             }
             y = backward(forward(x, **kwargs), **kwargs)
+            if x.dtype is torch.half and y.dtype is torch.complex32:
+                # Since type promotion currently doesn't work with complex32
+                # manually promote `x` to complex32
+                x = x.to(torch.complex32)
             # For real input, ifft(fft(x)) will convert to complex
             self.assertEqual(x, y, exact_dtype=(
                 forward != torch.fft.fft or x.is_complex()))

     # Note: NumPy will throw a ValueError for an empty input
     @onlyNativeDeviceTypes
-    @ops(spectral_funcs, allowed_dtypes=(torch.float, torch.cfloat))
+    @ops(spectral_funcs, allowed_dtypes=(torch.half, torch.float, torch.complex32, torch.cfloat))
     def test_empty_fft(self, device, dtype, op):
         t = torch.empty(1, 0, device=device, dtype=dtype)
         match = r"Invalid number of data points \([-\d]*\) specified"

@@ -228,8 +267,11 @@
     @skipCPUIfNoFFT
     @onlyNativeDeviceTypes
-    @dtypes(torch.int8, torch.float, torch.double, torch.complex64, torch.complex128)
+    @dtypes(torch.int8, torch.half, torch.float, torch.double,
+            torch.complex32, torch.complex64, torch.complex128)
     def test_fft_type_promotion(self, device, dtype):
+        skip_helper_for_fft(device, dtype)
+
         if dtype.is_complex or dtype.is_floating_point:
             t = torch.randn(64, device=device, dtype=dtype)
         else:

@@ -237,8 +279,10 @@
         PROMOTION_MAP = {
             torch.int8: torch.complex64,
+            torch.half: torch.complex32,
             torch.float: torch.complex64,
             torch.double: torch.complex128,
+            torch.complex32: torch.complex32,
             torch.complex64: torch.complex64,
             torch.complex128: torch.complex128,
         }

@@ -247,17 +291,27 @@
         PROMOTION_MAP_C2R = {
             torch.int8: torch.float,
+            torch.half: torch.half,
             torch.float: torch.float,
             torch.double: torch.double,
+            torch.complex32: torch.half,
             torch.complex64: torch.float,
             torch.complex128: torch.double,
         }
-        R = torch.fft.hfft(t)
+        if dtype in (torch.half, torch.complex32):
+            # cuFFT supports powers of 2 for half and complex half precision
+            # NOTE: With hfft and default args where output_size n=2*(input_size - 1),
+            # we make sure that logical fft size is a power of two.
+            x = torch.randn(65, device=device, dtype=dtype)
+            R = torch.fft.hfft(x)
+        else:
+            R = torch.fft.hfft(t)
         self.assertEqual(R.dtype, PROMOTION_MAP_C2R[dtype])

         if not dtype.is_complex:
             PROMOTION_MAP_R2C = {
                 torch.int8: torch.complex64,
+                torch.half: torch.complex32,
                 torch.float: torch.complex64,
                 torch.double: torch.complex128,
             }

@@ -269,9 +323,32 @@
          allowed_dtypes=[torch.half, torch.bfloat16])
     def test_fft_half_and_bfloat16_errors(self, device, dtype, op):
         # TODO: Remove torch.half error when complex32 is fully implemented
-        x = torch.randn(8, 8, device=device).to(dtype)
-        with self.assertRaisesRegex(RuntimeError, "Unsupported dtype "):
-            op(x)
+        sample = first_sample(self, op.sample_inputs(device, dtype))
+        device_type = torch.device(device).type
+        if dtype is torch.half and device_type == 'cuda' and TEST_WITH_ROCM:
+            err_msg = "Unsupported dtype "
+        elif dtype is torch.half and device_type == 'cuda' and not SM53OrLater:
+            err_msg = "cuFFT doesn't support signals of half type with compute capability less than SM_53"
+        else:
+            err_msg = "Unsupported dtype "
+        with self.assertRaisesRegex(RuntimeError, err_msg):
+            op(sample.input, *sample.args, **sample.kwargs)
+
+    @onlyNativeDeviceTypes
+    @ops(spectral_funcs, allowed_dtypes=(torch.half, torch.chalf))
+    def test_fft_half_and_chalf_not_power_of_two_error(self, device, dtype, op):
+        t = make_tensor(13, 13, device=device, dtype=dtype)
+        err_msg = "cuFFT only supports dimensions whose sizes are powers of two"
+        with self.assertRaisesRegex(RuntimeError, err_msg):
+            op(t)
+
+        if op.ndimensional in (SpectralFuncType.ND, SpectralFuncType.TwoD):
+            kwargs = {'s': (12, 12)}
+        else:
+            kwargs = {'n': 12}
+
+        with self.assertRaisesRegex(RuntimeError, err_msg):
+            op(t, **kwargs)

     # nd-fft tests
     @onlyNativeDeviceTypes

@@ -308,8 +385,15 @@
     @skipCPUIfNoFFT
     @onlyNativeDeviceTypes
-    @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
+    @toleranceOverride({
+        torch.half : tol(1e-2, 1e-2),
+        torch.chalf : tol(1e-2, 1e-2),
+    })
+    @dtypes(torch.half, torch.float, torch.double,
+            torch.complex32, torch.complex64, torch.complex128)
     def test_fftn_round_trip(self, device, dtype):
+        skip_helper_for_fft(device, dtype)
+
         norm_modes = (None, "forward", "backward", "ortho")

         # input_ndim, dim

@@ -331,7 +415,11 @@
                          (torch.fft.ihfftn, torch.fft.hfftn)]

         for input_ndim, dim in transform_desc:
-            shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
+            if dtype in (torch.half, torch.complex32):
+                # cuFFT supports powers of 2 for half and complex half precision
+                shape = itertools.islice(itertools.cycle((2, 4, 8)), input_ndim)
+            else:
+                shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
             x = torch.randn(*shape, device=device, dtype=dtype)

             for (forward, backward), norm in product(fft_functions, norm_modes):

@@ -343,8 +431,13 @@
                 kwargs = {'s': s, 'dim': dim, 'norm': norm}
                 y = backward(forward(x, **kwargs), **kwargs)
                 # For real input, ifftn(fftn(x)) will convert to complex
-                self.assertEqual(x, y, exact_dtype=(
-                    forward != torch.fft.fftn or x.is_complex()))
+                if x.dtype is torch.half and y.dtype is torch.chalf:
+                    # Since type promotion currently doesn't work with complex32
+                    # manually promote `x` to complex32
+                    self.assertEqual(x.to(torch.chalf), y)
+                else:
+                    self.assertEqual(x, y, exact_dtype=(
+                        forward != torch.fft.fftn or x.is_complex()))

     @onlyNativeDeviceTypes
     @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND],

@@ -369,8 +462,13 @@
     @skipCPUIfNoFFT
     @onlyNativeDeviceTypes
-    @dtypes(torch.float, torch.double)
+    @toleranceOverride({
+        torch.half : tol(1e-2, 1e-2),
+    })
+    @dtypes(torch.half, torch.float, torch.double)
     def test_hfftn(self, device, dtype):
+        skip_helper_for_fft(device, dtype)
+
         # input_ndim, dim
         transform_desc = [
             *product(range(2, 5), (None, (0,), (0, -1))),

@@ -383,8 +481,10 @@
         for input_ndim, dim in transform_desc:
             actual_dims = list(range(input_ndim)) if dim is None else dim
-            shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))
+            if dtype is torch.half:
+                shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim))
+            else:
+                shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))

             expect = torch.randn(*shape, device=device, dtype=dtype)
             input = torch.fft.ifftn(expect, dim=dim, norm="ortho")

@@ -401,8 +501,13 @@
     @skipCPUIfNoFFT
     @onlyNativeDeviceTypes
-    @dtypes(torch.float, torch.double)
+    @toleranceOverride({
+        torch.half : tol(1e-2, 1e-2),
+    })
+    @dtypes(torch.half, torch.float, torch.double)
     def test_ihfftn(self, device, dtype):
+        skip_helper_for_fft(device, dtype)
+
         # input_ndim, dim
         transform_desc = [
             *product(range(2, 5), (None, (0,), (0, -1))),

@@ -414,7 +519,11 @@
         ]

         for input_ndim, dim in transform_desc:
-            shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))
+            if dtype is torch.half:
+                shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim))
+            else:
+                shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))
+
             input = torch.randn(*shape, device=device, dtype=dtype)
             expect = torch.fft.ifftn(input, dim=dim, norm="ortho")
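
A condensed, illustrative version of what the updated round-trip tests above exercise (editor's sketch; the comparison is done in complex64 only to sidestep limited complex32 comparison support):

```python
import torch

x = torch.randn(4, 16, device="cuda", dtype=torch.chalf)  # power-of-two sizes
y = torch.fft.ifft(torch.fft.fft(x, dim=-1, norm="ortho"), dim=-1, norm="ortho")

torch.testing.assert_close(y.to(torch.cfloat), x.to(torch.cfloat),
                           rtol=1e-2, atol=1e-2)
```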


@@ -20,7 +20,6 @@ fft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor
 Computes the one dimensional discrete Fourier transform of :attr:`input`.
 
 Note:
     The Fourier domain representation of any real signal satisfies the
     Hermitian property: `X[i] = conj(X[-i])`. This function always returns both
     the positive and negative frequency terms even though, for real inputs, the
@@ -28,6 +27,10 @@ Note:
     more compact one-sided representation where only the positive frequencies
     are returned.
 
+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimension.
+
 Args:
     input (Tensor): the input tensor
     n (int, optional): Signal length. If given, the input will either be zero-padded
@@ -68,6 +71,10 @@ ifft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor
 
 Computes the one dimensional inverse discrete Fourier transform of :attr:`input`.
 
+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimension.
+
 Args:
     input (Tensor): the input tensor
     n (int, optional): Signal length. If given, the input will either be zero-padded
@@ -111,6 +118,10 @@ Note:
     :func:`~torch.fft.rfft2` returns the more compact one-sided representation
     where only the positive frequencies of the last dimension are returned.
 
+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimensions.
+
 Args:
     input (Tensor): the input tensor
     s (Tuple[int], optional): Signal size in the transformed dimensions.
@@ -157,6 +168,10 @@ ifft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor
 
 Computes the 2 dimensional inverse discrete Fourier transform of :attr:`input`.
 Equivalent to :func:`~torch.fft.ifftn` but IFFTs only the last two dimensions by default.
 
+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimensions.
+
 Args:
     input (Tensor): the input tensor
     s (Tuple[int], optional): Signal size in the transformed dimensions.
@@ -203,7 +218,6 @@ fftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor
 Computes the N dimensional discrete Fourier transform of :attr:`input`.
 
 Note:
     The Fourier domain representation of any real signal satisfies the
     Hermitian property: ``X[i_1, ..., i_n] = conj(X[-i_1, ..., -i_n])``. This
     function always returns all positive and negative frequency terms even
@@ -211,6 +225,10 @@ Note:
     :func:`~torch.fft.rfftn` returns the more compact one-sided representation
     where only the positive frequencies of the last dimension are returned.
 
+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimensions.
+
 Args:
     input (Tensor): the input tensor
     s (Tuple[int], optional): Signal size in the transformed dimensions.
@@ -256,6 +274,10 @@ ifftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor
 
 Computes the N dimensional inverse discrete Fourier transform of :attr:`input`.
 
+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimensions.
+
 Args:
     input (Tensor): the input tensor
     s (Tuple[int], optional): Signal size in the transformed dimensions.
@@ -305,6 +327,10 @@ The FFT of a real signal is Hermitian-symmetric, ``X[i] = conj(X[-i])`` so
 the output contains only the positive frequencies below the Nyquist frequency.
 To compute the full output, use :func:`~torch.fft.fft`
 
+Note:
+    Supports torch.half on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimension.
+
 Args:
     input (Tensor): the real input tensor
     n (int, optional): Signal length. If given, the input will either be zero-padded
@@ -367,6 +393,12 @@ Note:
     signal is assumed to be even length and odd signals will not round-trip
     properly. So, it is recommended to always pass the signal length :attr:`n`.
 
+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimension.
+    With default arguments, size of the transformed dimension should be (2^n + 1) as argument
+    `n` defaults to even output size = 2 * (transformed_dim_size - 1)
+
 Args:
     input (Tensor): the input tensor representing a half-Hermitian signal
     n (int, optional): Output signal length. This determines the length of the
@@ -424,6 +456,10 @@ so the full :func:`~torch.fft.fft2` output contains redundant information.
 :func:`~torch.fft.rfft2` instead omits the negative frequencies in the last
 dimension.
 
+Note:
+    Supports torch.half on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimensions.
+
 Args:
     input (Tensor): the input tensor
     s (Tuple[int], optional): Signal size in the transformed dimensions.
@@ -496,6 +532,12 @@ Note:
     signal is assumed to be even length and odd signals will not round-trip
     properly. So, it is recommended to always pass the signal shape :attr:`s`.
 
+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimensions.
+    With default arguments, the size of last dimension should be (2^n + 1) as argument
+    `s` defaults to even output size = 2 * (last_dim_size - 1)
+
 Args:
     input (Tensor): the input tensor
     s (Tuple[int], optional): Signal size in the transformed dimensions.
@@ -557,6 +599,10 @@ The FFT of a real signal is Hermitian-symmetric,
 :func:`~torch.fft.rfftn` instead omits the negative frequencies in the
 last dimension.
 
+Note:
+    Supports torch.half on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimensions.
+
 Args:
     input (Tensor): the input tensor
     s (Tuple[int], optional): Signal size in the transformed dimensions.
@@ -628,6 +674,12 @@ Note:
     signal is assumed to be even length and odd signals will not round-trip
     properly. So, it is recommended to always pass the signal shape :attr:`s`.
 
+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimensions.
+    With default arguments, the size of last dimension should be (2^n + 1) as argument
+    `s` defaults to even output size = 2 * (last_dim_size - 1)
+
 Args:
     input (Tensor): the input tensor
     s (Tuple[int], optional): Signal size in the transformed dimensions.
@@ -709,6 +761,12 @@ Note:
     signal is assumed to be even length and odd signals will not round-trip
     properly. So, it is recommended to always pass the signal length :attr:`n`.
 
+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimension.
+    With default arguments, size of the transformed dimension should be (2^n + 1) as argument
+    `n` defaults to even output size = 2 * (transformed_dim_size - 1)
+
 Args:
     input (Tensor): the input tensor representing a half-Hermitian signal
     n (int, optional): Output signal length. This determines the length of the
@@ -771,6 +829,10 @@ The IFFT of a real signal is Hermitian-symmetric, ``X[i] = conj(X[-i])``.
 positive frequencies below the Nyquist frequency are included. To compute the
 full output, use :func:`~torch.fft.ifft`.
 
+Note:
+    Supports torch.half on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimension.
+
 Args:
     input (Tensor): the real input tensor
     n (int, optional): Signal length. If given, the input will either be zero-padded
@@ -818,6 +880,12 @@ transforms the last two dimensions by default.
 :attr:`input` is interpreted as a one-sided Hermitian signal in the time
 domain. By the Hermitian property, the Fourier transform will be real-valued.
 
+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimensions.
+    With default arguments, the size of last dimension should be (2^n + 1) as argument
+    `s` defaults to even output size = 2 * (last_dim_size - 1)
+
 Args:
     input (Tensor): the input tensor
     s (Tuple[int], optional): Signal size in the transformed dimensions.
@@ -878,6 +946,10 @@ Computes the 2-dimensional inverse discrete Fourier transform of real
 :attr:`input`. Equivalent to :func:`~torch.fft.ihfftn` but transforms only the
 two last dimensions by default.
 
+Note:
+    Supports torch.half on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimensions.
+
 Args:
     input (Tensor): the input tensor
     s (Tuple[int], optional): Signal size in the transformed dimensions.
@@ -960,6 +1032,12 @@ Note:
     signal is assumed to be even length and odd signals will not round-trip
     properly. It is recommended to always pass the signal shape :attr:`s`.
 
+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimensions.
+    With default arguments, the size of last dimension should be (2^n + 1) as argument
+    `s` defaults to even output size = 2 * (last_dim_size - 1)
+
 Args:
     input (Tensor): the input tensor
     s (Tuple[int], optional): Signal size in the transformed dimensions.
@@ -1025,6 +1103,10 @@ this in the one-sided form where only the positive frequencies below the
 Nyquist frequency are included in the last signal dimension. To compute the
 full output, use :func:`~torch.fft.ifftn`.
 
+Note:
+    Supports torch.half on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimensions.
+
 Args:
     input (Tensor): the input tensor
     s (Tuple[int], optional): Signal size in the transformed dimensions.
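
To make the repeated hfft/irfft note concrete (editor's illustration, not part of the docstrings): with default arguments the output length is `2 * (input_size - 1)`, so an input of size 2^k + 1 yields a power-of-two logical FFT size, which is what the half/chalf constraint requires:

```python
import torch

x = torch.randn(65, device="cuda", dtype=torch.chalf)  # 65 = 2**6 + 1
out = torch.fft.hfft(x)          # default n = 2 * (65 - 1) = 128, a power of two
print(out.shape, out.dtype)      # torch.Size([128]) torch.float16
```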


@@ -5810,15 +5810,33 @@ def np_unary_ufunc_integer_promotion_wrapper(fn):
     return wrapped_fn
 
 
 def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwargs):
-    nd_tensor = partial(make_tensor, (S, S + 1, S + 2), device=device,
-                        dtype=dtype, requires_grad=requires_grad)
-    oned_tensor = partial(make_tensor, (31,), device=device,
-                          dtype=dtype, requires_grad=requires_grad)
+    is_fp16_or_chalf = dtype == torch.complex32 or dtype == torch.half
+    if not is_fp16_or_chalf:
+        nd_tensor = partial(make_tensor, (S, S + 1, S + 2), device=device,
+                            dtype=dtype, requires_grad=requires_grad)
+        oned_tensor = partial(make_tensor, (31,), device=device,
+                              dtype=dtype, requires_grad=requires_grad)
+    else:
+        # cuFFT supports powers of 2 for half and complex half precision
+        # NOTE: For hfft, hfft2, hfftn, irfft, irfft2, irfftn with default args
+        # where output_size n=2*(input_size - 1), we make sure that logical fft size is a power of two
+        if self.name in ['fft.hfft', 'fft.irfft']:
+            shapes = ((2, 9, 9), (33,))
+        elif self.name in ['fft.hfft2', 'fft.irfft2']:
+            shapes = ((2, 8, 9), (33,))
+        elif self.name in ['fft.hfftn', 'fft.irfftn']:
+            shapes = ((2, 2, 33), (33,))
+        else:
+            shapes = ((2, 8, 16), (32,))
+        nd_tensor = partial(make_tensor, shapes[0], device=device,
+                            dtype=dtype, requires_grad=requires_grad)
+        oned_tensor = partial(make_tensor, shapes[1], device=device,
+                              dtype=dtype, requires_grad=requires_grad)
 
     if self.ndimensional == SpectralFuncType.ND:
         return [
             SampleInput(nd_tensor(),
-                        kwargs=dict(s=(3, 10), dim=(1, 2), norm='ortho')),
+                        kwargs=dict(s=(3, 10) if not is_fp16_or_chalf else (4, 8), dim=(1, 2), norm='ortho')),
             SampleInput(nd_tensor(),
                         kwargs=dict(norm='ortho')),
             SampleInput(nd_tensor(),
@@ -5832,11 +5850,11 @@ def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwarg
     elif self.ndimensional == SpectralFuncType.TwoD:
         return [
             SampleInput(nd_tensor(),
-                        kwargs=dict(s=(3, 10), dim=(1, 2), norm='ortho')),
+                        kwargs=dict(s=(3, 10) if not is_fp16_or_chalf else (4, 8), dim=(1, 2), norm='ortho')),
             SampleInput(nd_tensor(),
                         kwargs=dict(norm='ortho')),
             SampleInput(nd_tensor(),
-                        kwargs=dict(s=(6, 8))),
+                        kwargs=dict(s=(6, 8) if not is_fp16_or_chalf else (4, 8))),
             SampleInput(nd_tensor(),
                         kwargs=dict(dim=0)),
             SampleInput(nd_tensor(),
@@ -5847,11 +5865,12 @@ def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwarg
     else:
         return [
             SampleInput(nd_tensor(),
-                        kwargs=dict(n=10, dim=1, norm='ortho')),
+                        kwargs=dict(n=10 if not is_fp16_or_chalf else 8, dim=1, norm='ortho')),
             SampleInput(nd_tensor(),
                         kwargs=dict(norm='ortho')),
             SampleInput(nd_tensor(),
-                        kwargs=dict(n=7)),
+                        kwargs=dict(n=7 if not is_fp16_or_chalf else 8)
+                        ),
             SampleInput(oned_tensor()),
 
             *(SampleInput(nd_tensor(),
@@ -5887,6 +5906,8 @@ class SpectralFuncInfo(OpInfo):
         decorators = list(decorators) if decorators is not None else []
         decorators += [
             skipCPUIfNoFFT,
+            DecorateInfo(toleranceOverride({torch.chalf: tol(4e-2, 4e-2)}),
+                         "TestCommon", "test_complex_half_reference_testing")
         ]
 
         super().__init__(name=name,
@@ -10628,6 +10649,10 @@ op_db: List[OpInfo] = [
                      ref=np.fft.fft,
                      ndimensional=SpectralFuncType.OneD,
                      dtypes=all_types_and_complex_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and_complex_and(
+                         torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      ),
@@ -10636,6 +10661,10 @@ op_db: List[OpInfo] = [
                      ref=np.fft.fft2,
                      ndimensional=SpectralFuncType.TwoD,
                      dtypes=all_types_and_complex_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and_complex_and(
+                         torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      decorators=[precisionOverride(
@@ -10646,6 +10675,10 @@ op_db: List[OpInfo] = [
                      ref=np.fft.fftn,
                      ndimensional=SpectralFuncType.ND,
                      dtypes=all_types_and_complex_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and_complex_and(
+                         torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      decorators=[precisionOverride(
@@ -10656,6 +10689,10 @@ op_db: List[OpInfo] = [
                      ref=np.fft.hfft,
                      ndimensional=SpectralFuncType.OneD,
                      dtypes=all_types_and_complex_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and_complex_and(
+                         torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      check_batched_gradgrad=False),
@@ -10664,6 +10701,10 @@ op_db: List[OpInfo] = [
                      ref=scipy.fft.hfft2 if has_scipy_fft else None,
                      ndimensional=SpectralFuncType.TwoD,
                      dtypes=all_types_and_complex_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and_complex_and(
+                         torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      check_batched_gradgrad=False,
@@ -10677,6 +10718,10 @@ op_db: List[OpInfo] = [
                      ref=scipy.fft.hfftn if has_scipy_fft else None,
                      ndimensional=SpectralFuncType.ND,
                      dtypes=all_types_and_complex_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and_complex_and(
+                         torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      check_batched_gradgrad=False,
@@ -10690,6 +10735,9 @@ op_db: List[OpInfo] = [
                      ref=np.fft.rfft,
                      ndimensional=SpectralFuncType.OneD,
                      dtypes=all_types_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)),
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      check_batched_grad=False,
@@ -10701,6 +10749,9 @@ op_db: List[OpInfo] = [
                      ref=np.fft.rfft2,
                      ndimensional=SpectralFuncType.TwoD,
                      dtypes=all_types_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)),
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      check_batched_grad=False,
@@ -10713,6 +10764,9 @@ op_db: List[OpInfo] = [
                      ref=np.fft.rfftn,
                      ndimensional=SpectralFuncType.ND,
                      dtypes=all_types_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)),
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      check_batched_grad=False,
@@ -10726,7 +10780,11 @@ op_db: List[OpInfo] = [
                      ndimensional=SpectralFuncType.OneD,
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
-                     dtypes=all_types_and_complex_and(torch.bool)),
+                     dtypes=all_types_and_complex_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and_complex_and(
+                         torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),),
     SpectralFuncInfo('fft.ifft2',
                      aten_name='fft_ifft2',
                      ref=np.fft.ifft2,
@@ -10734,6 +10792,10 @@ op_db: List[OpInfo] = [
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      dtypes=all_types_and_complex_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and_complex_and(
+                         torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
                      decorators=[
                          DecorateInfo(
                              precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
@@ -10746,6 +10808,10 @@ op_db: List[OpInfo] = [
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      dtypes=all_types_and_complex_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and_complex_and(
+                         torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
                      decorators=[
                          DecorateInfo(
                              precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
@@ -10758,6 +10824,9 @@ op_db: List[OpInfo] = [
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      dtypes=all_types_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)),
                      skips=(
                      ),
                      check_batched_grad=False),
@@ -10768,6 +10837,9 @@ op_db: List[OpInfo] = [
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      dtypes=all_types_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)),
                      check_batched_grad=False,
                      check_batched_gradgrad=False,
                      decorators=(
@@ -10784,6 +10856,9 @@ op_db: List[OpInfo] = [
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      dtypes=all_types_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)),
                      check_batched_grad=False,
                      check_batched_gradgrad=False,
                      decorators=[
@@ -10802,6 +10877,10 @@ op_db: List[OpInfo] = [
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      dtypes=all_types_and_complex_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and_complex_and(
+                         torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
                      check_batched_gradgrad=False),
     SpectralFuncInfo('fft.irfft2',
                      aten_name='fft_irfft2',
@@ -10810,6 +10889,10 @@ op_db: List[OpInfo] = [
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      dtypes=all_types_and_complex_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and_complex_and(
+                         torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
                      check_batched_gradgrad=False,
                      decorators=[
                          DecorateInfo(
@@ -10823,6 +10906,10 @@ op_db: List[OpInfo] = [
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      dtypes=all_types_and_complex_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and_complex_and(
+                         torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
                      check_batched_gradgrad=False,
                      decorators=[
                          DecorateInfo(
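
A side note on the `dtypesIfCUDA` pattern repeated in the OpInfo entries above: the `*() if ... else (...)` argument unpacks either nothing or the extra dtypes into the call. A standalone sketch (the flag values and the `collect` helper here are assumptions for illustration):

```python
import torch

TEST_WITH_ROCM = False   # assumed values for illustration
SM53OrLater = True

def collect(*dtypes):
    return dtypes

print(collect(torch.bool,
              *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)))
# (torch.bool, torch.float16, torch.complex32)
```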