[complex32] fft support (cuda only) (#74857)
`half` and `complex32` support for `torch.fft.{fft, fft2, fftn, hfft, hfft2, hfftn, ifft, ifft2, ifftn, ihfft, ihfft2, ihfftn, irfft, irfft2, irfftn, rfft, rfft2, rfftn}`
* We only add support for `CUDA`, as `cuFFT` supports these precisions.
* We still error out on `CPU` and `ROCm`, as their respective backends don't support these precisions.
For `cuFFT`, the following constraints apply to these precisions (a usage sketch follows the reference below):
* Minimum GPU architecture is SM_53
* Sizes are restricted to powers of two only
* Strides on the real part of real-to-complex and complex-to-real transforms are not supported
* More than one GPU is not supported
* Transforms spanning more than 4 billion elements are not supported
Ref: https://docs.nvidia.com/cuda/cufft/#half-precision-transforms
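A minimal usage sketch of the feature (editor's illustration, not part of the commit; assumes a CUDA device with compute capability SM_53 or later, and power-of-two sizes):

```python
import torch

x = torch.randn(128, device='cuda', dtype=torch.half)  # power-of-two length
X = torch.fft.fft(x)    # half input -> complex32 (chalf) output
y = torch.fft.ifft(X)   # round trip stays in complex32
# Compare in complex64 to sidestep any gaps in complex32 op coverage.
torch.testing.assert_close(y.to(torch.complex64),
                           x.to(torch.complex64), rtol=1e-2, atol=1e-2)
```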
TODO:
* [x] Update docs about the restrictions
* [x] Check the correct way to detect a `hip` device (it seems `device.is_cuda()` is true for HIP as well). (Thanks @peterbell10)
Ref for the second TODO item: e424e7d214/aten/src/ATen/native/SpectralOps.cpp (L31)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74857
Approved by: https://github.com/anjali411, https://github.com/peterbell10
This commit is contained in:
parent b825e1d472
commit ada65fdd67
@@ -139,6 +139,14 @@ bool CUDAHooks::hasCuSOLVER() const {
 #endif
 }

+bool CUDAHooks::hasROCM() const {
+  // Currently, this is same as `compiledWithMIOpen`.
+  // But in future if there are ROCm builds without MIOpen,
+  // then `hasROCM` should return true while `compiledWithMIOpen`
+  // should return false
+  return AT_ROCM_ENABLED();
+}
+
 #if defined(USE_DIRECT_NVRTC)
 static std::pair<std::unique_ptr<at::DynamicLibrary>, at::cuda::NVRTC*> load_nvrtc() {
   return std::make_pair(nullptr, at::cuda::load_nvrtc());
@@ -27,6 +27,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
   bool hasMAGMA() const override;
   bool hasCuDNN() const override;
   bool hasCuSOLVER() const override;
+  bool hasROCM() const override;
   const at::cuda::NVRTC& nvrtc() const override;
   int64_t current_device() const override;
   bool hasPrimaryContext(int64_t device_index) const override;
@@ -107,6 +107,10 @@ struct TORCH_API CUDAHooksInterface {
     return false;
   }

+  virtual bool hasROCM() const {
+    return false;
+  }
+
   virtual const at::cuda::NVRTC& nvrtc() const {
     TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP);
   }
@@ -19,7 +19,7 @@ namespace {
 // * Integers are promoted to the default floating type
 // * If require_complex=True, all types are promoted to complex
 // * Raises an error for half-precision dtypes to allow future support
-ScalarType promote_type_fft(ScalarType type, bool require_complex) {
+ScalarType promote_type_fft(ScalarType type, bool require_complex, Device device) {
   if (at::isComplexType(type)) {
     return type;
   }

@@ -28,7 +28,11 @@ ScalarType promote_type_fft(ScalarType type, bool require_complex) {
     type = c10::typeMetaToScalarType(c10::get_default_dtype());
   }

-  TORCH_CHECK(type == kFloat || type == kDouble, "Unsupported dtype ", type);
+  if (device.is_cuda() && !at::detail::getCUDAHooks().hasROCM()) {
+    TORCH_CHECK(type == kHalf || type == kFloat || type == kDouble, "Unsupported dtype ", type);
+  } else {
+    TORCH_CHECK(type == kFloat || type == kDouble, "Unsupported dtype ", type);
+  }

   if (!require_complex) {
     return type;

@@ -36,6 +40,7 @@ ScalarType promote_type_fft(ScalarType type, bool require_complex) {

   // Promote to complex
   switch (type) {
+    case kHalf: return kComplexHalf;
     case kFloat: return kComplexFloat;
     case kDouble: return kComplexDouble;
     default: TORCH_INTERNAL_ASSERT(false, "Unhandled dtype");

@@ -45,7 +50,7 @@ ScalarType promote_type_fft(ScalarType type, bool require_complex) {
 // Promote a tensor's dtype according to promote_type_fft
 Tensor promote_tensor_fft(const Tensor& t, bool require_complex=false) {
   auto cur_type = t.scalar_type();
-  auto new_type = promote_type_fft(cur_type, require_complex);
+  auto new_type = promote_type_fft(cur_type, require_complex, t.device());
   return (cur_type == new_type) ? t : t.to(new_type);
 }
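As an illustration of the promotion rule added above (an editor's sketch, not part of the commit; assumes a CUDA build without ROCm and an SM_53+ GPU):

```python
import torch

# kHalf now promotes to kComplexHalf for complex-output transforms on CUDA.
x = torch.randn(64, device='cuda', dtype=torch.half)
assert torch.fft.fft(x).dtype == torch.complex32
# On CPU (or ROCm) the same call still raises RuntimeError("Unsupported dtype ...").
```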
@@ -106,17 +106,17 @@ void _fft_fill_with_conjugate_symmetry_cuda_(
       signal_half_sizes, out_strides, mirror_dims, element_size);

   const auto numel = c10::multiply_integers(signal_half_sizes);
-  AT_DISPATCH_COMPLEX_TYPES(dtype, "_fft_fill_with_conjugate_symmetry", [&] {
+  AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "_fft_fill_with_conjugate_symmetry", [&] {
     using namespace cuda::detail;
     _fft_conjugate_copy_kernel<<<
       GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
         numel,
         static_cast<scalar_t*>(out_data),
         static_cast<const scalar_t*>(in_data),
         input_offset_calculator,
         output_offset_calculator);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
   });
 }

 REGISTER_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cuda_);
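The kernel above fills the redundant half of a real-to-complex transform using conjugate symmetry, now dispatched for `complex32` as well. A Python-level sketch of that symmetry (hypothetical sizes; assumes an SM_53+ CUDA GPU):

```python
import torch

x = torch.randn(8, device='cuda', dtype=torch.half)
full = torch.fft.fft(x)    # two-sided spectrum, complex32, length 8
half = torch.fft.rfft(x)   # one-sided spectrum, complex32, length 5
# The one-sided output matches the first half of the full spectrum; the
# remaining entries are its mirrored complex conjugates. Compare in
# complex64 to sidestep any gaps in complex32 op coverage.
torch.testing.assert_close(full[:5].to(torch.complex64),
                           half.to(torch.complex64), rtol=1e-2, atol=1e-2)
```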
@@ -10,12 +10,14 @@ import doctest
 import inspect

 from torch.testing._internal.common_utils import \
-    (TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL)
+    (TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL, first_sample, TEST_WITH_ROCM,
+     make_tensor)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, ops, dtypes, onlyNativeDeviceTypes,
-     skipCPUIfNoFFT, deviceCountAtLeast, onlyCUDA, OpDTypes, skipIf)
+     skipCPUIfNoFFT, deviceCountAtLeast, onlyCUDA, OpDTypes, skipIf, toleranceOverride, tol)
 from torch.testing._internal.common_methods_invocations import (
     spectral_funcs, SpectralFuncType)
+from torch.testing._internal.common_cuda import SM53OrLater

 from setuptools import distutils
 from typing import Optional, List

@@ -110,6 +112,20 @@ def _stft_reference(x, hop_length, window):
         X[:, m] = torch.fft.fft(slc * window)
     return X

+
+def skip_helper_for_fft(device, dtype):
+    device_type = torch.device(device).type
+    if dtype not in (torch.half, torch.complex32):
+        return
+
+    if device_type == 'cpu':
+        raise unittest.SkipTest("half and complex32 are not supported on CPU")
+    if TEST_WITH_ROCM:
+        raise unittest.SkipTest("half and complex32 are not supported on ROCM")
+    if not SM53OrLater:
+        raise unittest.SkipTest("half and complex32 are only supported on CUDA device with SM>53")
+
+
 # Tests of functions related to Fourier analysis in the torch.fft namespace
 class TestFFT(TestCase):
     exact_dtype = True

@@ -157,20 +173,39 @@ class TestFFT(TestCase):

     @skipCPUIfNoFFT
     @onlyNativeDeviceTypes
-    @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
+    @toleranceOverride({
+        torch.half : tol(1e-2, 1e-2),
+        torch.chalf : tol(1e-2, 1e-2),
+    })
+    @dtypes(torch.half, torch.float, torch.double, torch.complex32, torch.complex64, torch.complex128)
     def test_fft_round_trip(self, device, dtype):
+        skip_helper_for_fft(device, dtype)
         # Test that round trip through ifft(fft(x)) is the identity
-        test_args = list(product(
-            # input
-            (torch.randn(67, device=device, dtype=dtype),
-             torch.randn(80, device=device, dtype=dtype),
-             torch.randn(12, 14, device=device, dtype=dtype),
-             torch.randn(9, 6, 3, device=device, dtype=dtype)),
-            # dim
-            (-1, 0),
-            # norm
-            (None, "forward", "backward", "ortho")
-        ))
+        if dtype not in (torch.half, torch.complex32):
+            test_args = list(product(
+                # input
+                (torch.randn(67, device=device, dtype=dtype),
+                 torch.randn(80, device=device, dtype=dtype),
+                 torch.randn(12, 14, device=device, dtype=dtype),
+                 torch.randn(9, 6, 3, device=device, dtype=dtype)),
+                # dim
+                (-1, 0),
+                # norm
+                (None, "forward", "backward", "ortho")
+            ))
+        else:
+            # cuFFT supports powers of 2 for half and complex half precision
+            test_args = list(product(
+                # input
+                (torch.randn(64, device=device, dtype=dtype),
+                 torch.randn(128, device=device, dtype=dtype),
+                 torch.randn(4, 16, device=device, dtype=dtype),
+                 torch.randn(8, 6, 2, device=device, dtype=dtype)),
+                # dim
+                (-1, 0),
+                # norm
+                (None, "forward", "backward", "ortho")
+            ))

        fft_functions = [(torch.fft.fft, torch.fft.ifft)]
        # Real-only functions

@@ -189,13 +224,17 @@ class TestFFT(TestCase):
             }

             y = backward(forward(x, **kwargs), **kwargs)
+            if x.dtype is torch.half and y.dtype is torch.complex32:
+                # Since type promotion currently doesn't work with complex32
+                # manually promote `x` to complex32
+                x = x.to(torch.complex32)
             # For real input, ifft(fft(x)) will convert to complex
             self.assertEqual(x, y, exact_dtype=(
                 forward != torch.fft.fft or x.is_complex()))

     # Note: NumPy will throw a ValueError for an empty input
     @onlyNativeDeviceTypes
-    @ops(spectral_funcs, allowed_dtypes=(torch.float, torch.cfloat))
+    @ops(spectral_funcs, allowed_dtypes=(torch.half, torch.float, torch.complex32, torch.cfloat))
     def test_empty_fft(self, device, dtype, op):
         t = torch.empty(1, 0, device=device, dtype=dtype)
         match = r"Invalid number of data points \([-\d]*\) specified"

@@ -228,8 +267,11 @@ class TestFFT(TestCase):

     @skipCPUIfNoFFT
     @onlyNativeDeviceTypes
-    @dtypes(torch.int8, torch.float, torch.double, torch.complex64, torch.complex128)
+    @dtypes(torch.int8, torch.half, torch.float, torch.double,
+            torch.complex32, torch.complex64, torch.complex128)
     def test_fft_type_promotion(self, device, dtype):
+        skip_helper_for_fft(device, dtype)
+
         if dtype.is_complex or dtype.is_floating_point:
             t = torch.randn(64, device=device, dtype=dtype)
         else:

@@ -237,8 +279,10 @@ class TestFFT(TestCase):

         PROMOTION_MAP = {
             torch.int8: torch.complex64,
+            torch.half: torch.complex32,
             torch.float: torch.complex64,
             torch.double: torch.complex128,
+            torch.complex32: torch.complex32,
             torch.complex64: torch.complex64,
             torch.complex128: torch.complex128,
         }

@@ -247,17 +291,27 @@ class TestFFT(TestCase):

         PROMOTION_MAP_C2R = {
             torch.int8: torch.float,
+            torch.half: torch.half,
             torch.float: torch.float,
             torch.double: torch.double,
+            torch.complex32: torch.half,
             torch.complex64: torch.float,
             torch.complex128: torch.double,
         }
-        R = torch.fft.hfft(t)
+        if dtype in (torch.half, torch.complex32):
+            # cuFFT supports powers of 2 for half and complex half precision
+            # NOTE: With hfft and default args where output_size n=2*(input_size - 1),
+            # we make sure that logical fft size is a power of two.
+            x = torch.randn(65, device=device, dtype=dtype)
+            R = torch.fft.hfft(x)
+        else:
+            R = torch.fft.hfft(t)
         self.assertEqual(R.dtype, PROMOTION_MAP_C2R[dtype])

         if not dtype.is_complex:
             PROMOTION_MAP_R2C = {
                 torch.int8: torch.complex64,
+                torch.half: torch.complex32,
                 torch.float: torch.complex64,
                 torch.double: torch.complex128,
             }

@@ -269,9 +323,32 @@ class TestFFT(TestCase):
          allowed_dtypes=[torch.half, torch.bfloat16])
     def test_fft_half_and_bfloat16_errors(self, device, dtype, op):
         # TODO: Remove torch.half error when complex32 is fully implemented
-        x = torch.randn(8, 8, device=device).to(dtype)
-        with self.assertRaisesRegex(RuntimeError, "Unsupported dtype "):
-            op(x)
+        sample = first_sample(self, op.sample_inputs(device, dtype))
+        device_type = torch.device(device).type
+        if dtype is torch.half and device_type == 'cuda' and TEST_WITH_ROCM:
+            err_msg = "Unsupported dtype "
+        elif dtype is torch.half and device_type == 'cuda' and not SM53OrLater:
+            err_msg = "cuFFT doesn't support signals of half type with compute capability less than SM_53"
+        else:
+            err_msg = "Unsupported dtype "
+        with self.assertRaisesRegex(RuntimeError, err_msg):
+            op(sample.input, *sample.args, **sample.kwargs)
+
+    @onlyNativeDeviceTypes
+    @ops(spectral_funcs, allowed_dtypes=(torch.half, torch.chalf))
+    def test_fft_half_and_chalf_not_power_of_two_error(self, device, dtype, op):
+        t = make_tensor(13, 13, device=device, dtype=dtype)
+        err_msg = "cuFFT only supports dimensions whose sizes are powers of two"
+        with self.assertRaisesRegex(RuntimeError, err_msg):
+            op(t)
+
+        if op.ndimensional in (SpectralFuncType.ND, SpectralFuncType.TwoD):
+            kwargs = {'s': (12, 12)}
+        else:
+            kwargs = {'n': 12}
+
+        with self.assertRaisesRegex(RuntimeError, err_msg):
+            op(t, **kwargs)

     # nd-fft tests
     @onlyNativeDeviceTypes

@@ -308,8 +385,15 @@ class TestFFT(TestCase):

     @skipCPUIfNoFFT
     @onlyNativeDeviceTypes
-    @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
+    @toleranceOverride({
+        torch.half : tol(1e-2, 1e-2),
+        torch.chalf : tol(1e-2, 1e-2),
+    })
+    @dtypes(torch.half, torch.float, torch.double,
+            torch.complex32, torch.complex64, torch.complex128)
     def test_fftn_round_trip(self, device, dtype):
+        skip_helper_for_fft(device, dtype)
+
         norm_modes = (None, "forward", "backward", "ortho")

         # input_ndim, dim

@@ -331,7 +415,11 @@ class TestFFT(TestCase):
                          (torch.fft.ihfftn, torch.fft.hfftn)]

         for input_ndim, dim in transform_desc:
-            shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
+            if dtype in (torch.half, torch.complex32):
+                # cuFFT supports powers of 2 for half and complex half precision
+                shape = itertools.islice(itertools.cycle((2, 4, 8)), input_ndim)
+            else:
+                shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
             x = torch.randn(*shape, device=device, dtype=dtype)

             for (forward, backward), norm in product(fft_functions, norm_modes):

@@ -343,8 +431,13 @@ class TestFFT(TestCase):
                 kwargs = {'s': s, 'dim': dim, 'norm': norm}
                 y = backward(forward(x, **kwargs), **kwargs)
                 # For real input, ifftn(fftn(x)) will convert to complex
-                self.assertEqual(x, y, exact_dtype=(
-                    forward != torch.fft.fftn or x.is_complex()))
+                if x.dtype is torch.half and y.dtype is torch.chalf:
+                    # Since type promotion currently doesn't work with complex32
+                    # manually promote `x` to complex32
+                    self.assertEqual(x.to(torch.chalf), y)
+                else:
+                    self.assertEqual(x, y, exact_dtype=(
+                        forward != torch.fft.fftn or x.is_complex()))

     @onlyNativeDeviceTypes
     @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND],

@@ -369,8 +462,13 @@ class TestFFT(TestCase):

     @skipCPUIfNoFFT
     @onlyNativeDeviceTypes
-    @dtypes(torch.float, torch.double)
+    @toleranceOverride({
+        torch.half : tol(1e-2, 1e-2),
+    })
+    @dtypes(torch.half, torch.float, torch.double)
     def test_hfftn(self, device, dtype):
+        skip_helper_for_fft(device, dtype)
+
         # input_ndim, dim
         transform_desc = [
             *product(range(2, 5), (None, (0,), (0, -1))),

@@ -383,8 +481,10 @@ class TestFFT(TestCase):

         for input_ndim, dim in transform_desc:
             actual_dims = list(range(input_ndim)) if dim is None else dim
-            shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))
+            if dtype is torch.half:
+                shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim))
+            else:
+                shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))
             expect = torch.randn(*shape, device=device, dtype=dtype)
             input = torch.fft.ifftn(expect, dim=dim, norm="ortho")

@@ -401,8 +501,13 @@ class TestFFT(TestCase):

     @skipCPUIfNoFFT
     @onlyNativeDeviceTypes
-    @dtypes(torch.float, torch.double)
+    @toleranceOverride({
+        torch.half : tol(1e-2, 1e-2),
+    })
+    @dtypes(torch.half, torch.float, torch.double)
     def test_ihfftn(self, device, dtype):
+        skip_helper_for_fft(device, dtype)
+
         # input_ndim, dim
         transform_desc = [
             *product(range(2, 5), (None, (0,), (0, -1))),

@@ -414,7 +519,11 @@ class TestFFT(TestCase):
         ]

         for input_ndim, dim in transform_desc:
-            shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))
+            if dtype is torch.half:
+                shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim))
+            else:
+                shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))

             input = torch.randn(*shape, device=device, dtype=dtype)
             expect = torch.fft.ifftn(input, dim=dim, norm="ortho")
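A sketch of the error path the new non-power-of-two test exercises (editor's illustration; assumes an SM_53+ CUDA GPU; the message is the one asserted in the test above):

```python
import torch

t = torch.randn(13, device='cuda', dtype=torch.half)
try:
    torch.fft.fft(t)   # 13 is not a power of two
except RuntimeError as e:
    assert "powers of two" in str(e)
```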
@@ -20,7 +20,6 @@ fft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor
 Computes the one dimensional discrete Fourier transform of :attr:`input`.

 Note:
-
     The Fourier domain representation of any real signal satisfies the
     Hermitian property: `X[i] = conj(X[-i])`. This function always returns both
     the positive and negative frequency terms even though, for real inputs, the

@@ -28,6 +27,10 @@ Note:
     more compact one-sided representation where only the positive frequencies
     are returned.

+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimension.
+
 Args:
     input (Tensor): the input tensor
     n (int, optional): Signal length. If given, the input will either be zero-padded

@@ -68,6 +71,10 @@ ifft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor

 Computes the one dimensional inverse discrete Fourier transform of :attr:`input`.

+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimension.
+
 Args:
     input (Tensor): the input tensor
     n (int, optional): Signal length. If given, the input will either be zero-padded

@@ -111,6 +118,10 @@ Note:
     :func:`~torch.fft.rfft2` returns the more compact one-sided representation
     where only the positive frequencies of the last dimension are returned.

+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimensions.
+
 Args:
     input (Tensor): the input tensor
     s (Tuple[int], optional): Signal size in the transformed dimensions.

@@ -157,6 +168,10 @@ ifft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor
 Computes the 2 dimensional inverse discrete Fourier transform of :attr:`input`.
 Equivalent to :func:`~torch.fft.ifftn` but IFFTs only the last two dimensions by default.

+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimensions.
+
 Args:
     input (Tensor): the input tensor
     s (Tuple[int], optional): Signal size in the transformed dimensions.

@@ -203,7 +218,6 @@ fftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor
 Computes the N dimensional discrete Fourier transform of :attr:`input`.

 Note:
-
     The Fourier domain representation of any real signal satisfies the
     Hermitian property: ``X[i_1, ..., i_n] = conj(X[-i_1, ..., -i_n])``. This
     function always returns all positive and negative frequency terms even

@@ -211,6 +225,10 @@ Note:
     :func:`~torch.fft.rfftn` returns the more compact one-sided representation
     where only the positive frequencies of the last dimension are returned.

+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimensions.
+
 Args:
     input (Tensor): the input tensor
     s (Tuple[int], optional): Signal size in the transformed dimensions.

@@ -256,6 +274,10 @@ ifftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor

 Computes the N dimensional inverse discrete Fourier transform of :attr:`input`.

+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimensions.
+
 Args:
     input (Tensor): the input tensor
     s (Tuple[int], optional): Signal size in the transformed dimensions.

@@ -305,6 +327,10 @@ The FFT of a real signal is Hermitian-symmetric, ``X[i] = conj(X[-i])`` so
 the output contains only the positive frequencies below the Nyquist frequency.
 To compute the full output, use :func:`~torch.fft.fft`

+Note:
+    Supports torch.half on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimension.
+
 Args:
     input (Tensor): the real input tensor
     n (int, optional): Signal length. If given, the input will either be zero-padded

@@ -367,6 +393,12 @@ Note:
     signal is assumed to be even length and odd signals will not round-trip
     properly. So, it is recommended to always pass the signal length :attr:`n`.

+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimension.
+    With default arguments, size of the transformed dimension should be (2^n + 1) as argument
+    `n` defaults to even output size = 2 * (transformed_dim_size - 1)
+
 Args:
     input (Tensor): the input tensor representing a half-Hermitian signal
     n (int, optional): Output signal length. This determines the length of the

@@ -424,6 +456,10 @@ so the full :func:`~torch.fft.fft2` output contains redundant information.
 :func:`~torch.fft.rfft2` instead omits the negative frequencies in the last
 dimension.

+Note:
+    Supports torch.half on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimensions.
+
 Args:
     input (Tensor): the input tensor
     s (Tuple[int], optional): Signal size in the transformed dimensions.

@@ -496,6 +532,12 @@ Note:
     signal is assumed to be even length and odd signals will not round-trip
     properly. So, it is recommended to always pass the signal shape :attr:`s`.

+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimensions.
+    With default arguments, the size of last dimension should be (2^n + 1) as argument
+    `s` defaults to even output size = 2 * (last_dim_size - 1)
+
 Args:
     input (Tensor): the input tensor
     s (Tuple[int], optional): Signal size in the transformed dimensions.

@@ -557,6 +599,10 @@ The FFT of a real signal is Hermitian-symmetric,
 :func:`~torch.fft.rfftn` instead omits the negative frequencies in the
 last dimension.

+Note:
+    Supports torch.half on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimensions.
+
 Args:
     input (Tensor): the input tensor
     s (Tuple[int], optional): Signal size in the transformed dimensions.

@@ -628,6 +674,12 @@ Note:
     signal is assumed to be even length and odd signals will not round-trip
     properly. So, it is recommended to always pass the signal shape :attr:`s`.

+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimensions.
+    With default arguments, the size of last dimension should be (2^n + 1) as argument
+    `s` defaults to even output size = 2 * (last_dim_size - 1)
+
 Args:
     input (Tensor): the input tensor
     s (Tuple[int], optional): Signal size in the transformed dimensions.

@@ -709,6 +761,12 @@ Note:
     signal is assumed to be even length and odd signals will not round-trip
     properly. So, it is recommended to always pass the signal length :attr:`n`.

+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimension.
+    With default arguments, size of the transformed dimension should be (2^n + 1) as argument
+    `n` defaults to even output size = 2 * (transformed_dim_size - 1)
+
 Args:
     input (Tensor): the input tensor representing a half-Hermitian signal
     n (int, optional): Output signal length. This determines the length of the

@@ -771,6 +829,10 @@ The IFFT of a real signal is Hermitian-symmetric, ``X[i] = conj(X[-i])``.
 positive frequencies below the Nyquist frequency are included. To compute the
 full output, use :func:`~torch.fft.ifft`.

+Note:
+    Supports torch.half on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimension.
+
 Args:
     input (Tensor): the real input tensor
     n (int, optional): Signal length. If given, the input will either be zero-padded

@@ -818,6 +880,12 @@ transforms the last two dimensions by default.
 :attr:`input` is interpreted as a one-sided Hermitian signal in the time
 domain. By the Hermitian property, the Fourier transform will be real-valued.

+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimensions.
+    With default arguments, the size of last dimension should be (2^n + 1) as argument
+    `s` defaults to even output size = 2 * (last_dim_size - 1)
+
 Args:
     input (Tensor): the input tensor
     s (Tuple[int], optional): Signal size in the transformed dimensions.

@@ -878,6 +946,10 @@ Computes the 2-dimensional inverse discrete Fourier transform of real
 :attr:`input`. Equivalent to :func:`~torch.fft.ihfftn` but transforms only the
 two last dimensions by default.

+Note:
+    Supports torch.half on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimensions.
+
 Args:
     input (Tensor): the input tensor
     s (Tuple[int], optional): Signal size in the transformed dimensions.

@@ -960,6 +1032,12 @@ Note:
     signal is assumed to be even length and odd signals will not round-trip
     properly. It is recommended to always pass the signal shape :attr:`s`.

+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimensions.
+    With default arguments, the size of last dimension should be (2^n + 1) as argument
+    `s` defaults to even output size = 2 * (last_dim_size - 1)
+
 Args:
     input (Tensor): the input tensor
     s (Tuple[int], optional): Signal size in the transformed dimensions.

@@ -1025,6 +1103,10 @@ this in the one-sided form where only the positive frequencies below the
 Nyquist frequency are included in the last signal dimension. To compute the
 full output, use :func:`~torch.fft.ifftn`.

+Note:
+    Supports torch.half on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimensions.
+
 Args:
     input (Tensor): the input tensor
     s (Tuple[int], optional): Signal size in the transformed dimensions.
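To unpack the "(2^n + 1)" wording in the notes above: with default arguments, `hfft`/`irfft` produce `2 * (input_size - 1)` output points, so an input whose transformed dimension has size 2^n + 1 yields a power-of-two logical FFT size. An editor's sketch (assumes an SM_53+ CUDA GPU):

```python
import torch

x = torch.randn(33, device='cuda', dtype=torch.half)  # 33 == 2**5 + 1
out = torch.fft.hfft(x)   # logical FFT size 2 * (33 - 1) == 64, a power of two
assert out.shape[-1] == 64 and out.dtype == torch.half
```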
@@ -5810,15 +5810,33 @@ def np_unary_ufunc_integer_promotion_wrapper(fn):
     return wrapped_fn

 def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwargs):
-    nd_tensor = partial(make_tensor, (S, S + 1, S + 2), device=device,
-                        dtype=dtype, requires_grad=requires_grad)
-    oned_tensor = partial(make_tensor, (31,), device=device,
-                          dtype=dtype, requires_grad=requires_grad)
+    is_fp16_or_chalf = dtype == torch.complex32 or dtype == torch.half
+    if not is_fp16_or_chalf:
+        nd_tensor = partial(make_tensor, (S, S + 1, S + 2), device=device,
+                            dtype=dtype, requires_grad=requires_grad)
+        oned_tensor = partial(make_tensor, (31,), device=device,
+                              dtype=dtype, requires_grad=requires_grad)
+    else:
+        # cuFFT supports powers of 2 for half and complex half precision
+        # NOTE: For hfft, hfft2, hfftn, irfft, irfft2, irfftn with default args
+        # where output_size n=2*(input_size - 1), we make sure that logical fft size is a power of two
+        if self.name in ['fft.hfft', 'fft.irfft']:
+            shapes = ((2, 9, 9), (33,))
+        elif self.name in ['fft.hfft2', 'fft.irfft2']:
+            shapes = ((2, 8, 9), (33,))
+        elif self.name in ['fft.hfftn', 'fft.irfftn']:
+            shapes = ((2, 2, 33), (33,))
+        else:
+            shapes = ((2, 8, 16), (32,))
+        nd_tensor = partial(make_tensor, shapes[0], device=device,
+                            dtype=dtype, requires_grad=requires_grad)
+        oned_tensor = partial(make_tensor, shapes[1], device=device,
+                              dtype=dtype, requires_grad=requires_grad)

     if self.ndimensional == SpectralFuncType.ND:
         return [
             SampleInput(nd_tensor(),
-                        kwargs=dict(s=(3, 10), dim=(1, 2), norm='ortho')),
+                        kwargs=dict(s=(3, 10) if not is_fp16_or_chalf else (4, 8), dim=(1, 2), norm='ortho')),
             SampleInput(nd_tensor(),
                         kwargs=dict(norm='ortho')),
             SampleInput(nd_tensor(),

@@ -5832,11 +5850,11 @@ def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwargs):
     elif self.ndimensional == SpectralFuncType.TwoD:
         return [
             SampleInput(nd_tensor(),
-                        kwargs=dict(s=(3, 10), dim=(1, 2), norm='ortho')),
+                        kwargs=dict(s=(3, 10) if not is_fp16_or_chalf else (4, 8), dim=(1, 2), norm='ortho')),
             SampleInput(nd_tensor(),
                         kwargs=dict(norm='ortho')),
             SampleInput(nd_tensor(),
-                        kwargs=dict(s=(6, 8))),
+                        kwargs=dict(s=(6, 8) if not is_fp16_or_chalf else (4, 8))),
             SampleInput(nd_tensor(),
                         kwargs=dict(dim=0)),
             SampleInput(nd_tensor(),

@@ -5847,11 +5865,12 @@ def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwargs):
     else:
         return [
             SampleInput(nd_tensor(),
-                        kwargs=dict(n=10, dim=1, norm='ortho')),
+                        kwargs=dict(n=10 if not is_fp16_or_chalf else 8, dim=1, norm='ortho')),
             SampleInput(nd_tensor(),
                         kwargs=dict(norm='ortho')),
             SampleInput(nd_tensor(),
-                        kwargs=dict(n=7)),
+                        kwargs=dict(n=7 if not is_fp16_or_chalf else 8)
+                        ),
             SampleInput(oned_tensor()),

             *(SampleInput(nd_tensor(),

@@ -5887,6 +5906,8 @@ class SpectralFuncInfo(OpInfo):
         decorators = list(decorators) if decorators is not None else []
         decorators += [
             skipCPUIfNoFFT,
+            DecorateInfo(toleranceOverride({torch.chalf: tol(4e-2, 4e-2)}),
+                         "TestCommon", "test_complex_half_reference_testing")
         ]

         super().__init__(name=name,

@@ -10628,6 +10649,10 @@ op_db: List[OpInfo] = [
                      ref=np.fft.fft,
                      ndimensional=SpectralFuncType.OneD,
                      dtypes=all_types_and_complex_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and_complex_and(
+                         torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      ),

@@ -10636,6 +10661,10 @@ op_db: List[OpInfo] = [
                      ref=np.fft.fft2,
                      ndimensional=SpectralFuncType.TwoD,
                      dtypes=all_types_and_complex_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and_complex_and(
+                         torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      decorators=[precisionOverride(

@@ -10646,6 +10675,10 @@ op_db: List[OpInfo] = [
                      ref=np.fft.fftn,
                      ndimensional=SpectralFuncType.ND,
                      dtypes=all_types_and_complex_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and_complex_and(
+                         torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      decorators=[precisionOverride(

@@ -10656,6 +10689,10 @@ op_db: List[OpInfo] = [
                      ref=np.fft.hfft,
                      ndimensional=SpectralFuncType.OneD,
                      dtypes=all_types_and_complex_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and_complex_and(
+                         torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      check_batched_gradgrad=False),

@@ -10664,6 +10701,10 @@ op_db: List[OpInfo] = [
                      ref=scipy.fft.hfft2 if has_scipy_fft else None,
                      ndimensional=SpectralFuncType.TwoD,
                      dtypes=all_types_and_complex_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and_complex_and(
+                         torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      check_batched_gradgrad=False,

@@ -10677,6 +10718,10 @@ op_db: List[OpInfo] = [
                      ref=scipy.fft.hfftn if has_scipy_fft else None,
                      ndimensional=SpectralFuncType.ND,
                      dtypes=all_types_and_complex_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and_complex_and(
+                         torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      check_batched_gradgrad=False,

@@ -10690,6 +10735,9 @@ op_db: List[OpInfo] = [
                      ref=np.fft.rfft,
                      ndimensional=SpectralFuncType.OneD,
                      dtypes=all_types_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)),
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      check_batched_grad=False,

@@ -10701,6 +10749,9 @@ op_db: List[OpInfo] = [
                      ref=np.fft.rfft2,
                      ndimensional=SpectralFuncType.TwoD,
                      dtypes=all_types_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)),
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      check_batched_grad=False,

@@ -10713,6 +10764,9 @@ op_db: List[OpInfo] = [
                      ref=np.fft.rfftn,
                      ndimensional=SpectralFuncType.ND,
                      dtypes=all_types_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)),
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      check_batched_grad=False,

@@ -10726,7 +10780,11 @@ op_db: List[OpInfo] = [
                      ndimensional=SpectralFuncType.OneD,
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
-                     dtypes=all_types_and_complex_and(torch.bool)),
+                     dtypes=all_types_and_complex_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and_complex_and(
+                         torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),),
     SpectralFuncInfo('fft.ifft2',
                      aten_name='fft_ifft2',
                      ref=np.fft.ifft2,

@@ -10734,6 +10792,10 @@ op_db: List[OpInfo] = [
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      dtypes=all_types_and_complex_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and_complex_and(
+                         torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
                      decorators=[
                          DecorateInfo(
                              precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),

@@ -10746,6 +10808,10 @@ op_db: List[OpInfo] = [
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      dtypes=all_types_and_complex_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and_complex_and(
+                         torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
                      decorators=[
                          DecorateInfo(
                              precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),

@@ -10758,6 +10824,9 @@ op_db: List[OpInfo] = [
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      dtypes=all_types_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)),
                      skips=(
                      ),
                      check_batched_grad=False),

@@ -10768,6 +10837,9 @@ op_db: List[OpInfo] = [
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      dtypes=all_types_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)),
                      check_batched_grad=False,
                      check_batched_gradgrad=False,
                      decorators=(

@@ -10784,6 +10856,9 @@ op_db: List[OpInfo] = [
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      dtypes=all_types_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archss
+                     dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)),
                      check_batched_grad=False,
                      check_batched_gradgrad=False,
                      decorators=[

@@ -10802,6 +10877,10 @@ op_db: List[OpInfo] = [
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      dtypes=all_types_and_complex_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and_complex_and(
+                         torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
                      check_batched_gradgrad=False),
     SpectralFuncInfo('fft.irfft2',
                      aten_name='fft_irfft2',

@@ -10810,6 +10889,10 @@ op_db: List[OpInfo] = [
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      dtypes=all_types_and_complex_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and_complex_and(
+                         torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
                      check_batched_gradgrad=False,
                      decorators=[
                          DecorateInfo(

@@ -10823,6 +10906,10 @@ op_db: List[OpInfo] = [
                      supports_forward_ad=True,
                      supports_fwgrad_bwgrad=True,
                      dtypes=all_types_and_complex_and(torch.bool),
+                     # rocFFT doesn't support Half/Complex Half Precision FFT
+                     # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+                     dtypesIfCUDA=all_types_and_complex_and(
+                         torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),
                      check_batched_gradgrad=False,
                      decorators=[
                          DecorateInfo(