pytorch/torch/_refs/fft.py
Peter Bell d796518485 [refs] Fix size check from #108360 (#109083)
PR #108360 uses the same default `last_dim_size` formula from complex-to-real (C2R) transforms for
complex-to-complex (C2C) and real-to-complex (R2C). However, this is not correct because for C2R
the input is only about half the size of the full signal, which is not the case for C2C and R2C.

This error is mostly benign since `last_dim_size` was only used for the `>= 1` condition which is
almost always met anyway.

In this PR, `last_dim_size` is now passed as the argument to `_apply_norm`, which makes it
load-bearing for correctness, so it is now thoroughly tested.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109083
Approved by: https://github.com/lezcano
2023-09-27 23:59:29 +00:00

591 lines
18 KiB
Python

import math
from typing import Iterable, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union
import torch
import torch._prims as prims
import torch._prims_common as utils
from torch._decomp import register_decomposition
from torch._prims_common import DimsType, ShapeType, TensorLikeType
from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper
# Public API of this module: the FFT transforms, then the shift helpers.
__all__ = [
    # Transforms
    "fft", "fft2", "fftn",
    "hfft", "hfft2", "hfftn",
    "rfft", "rfft2", "rfftn",
    "ifft", "ifft2", "ifftn",
    "ihfft", "ihfft2", "ihfftn",
    "irfft", "irfft2", "irfftn",
    # Helpers
    "fftshift", "ifftshift",
]

# Type of the ``norm`` argument accepted by every transform below.
NormType = Union[None, Literal["forward"], Literal["backward"], Literal["ortho"]]
# Runtime set of valid ``norm`` values; checked in ``_apply_norm``.
_NORM_VALUES = {None, "forward", "backward", "ortho"}

# Shorthand for the aten op namespace used by ``register_decomposition``.
aten = torch._ops.ops.aten
def _apply_norm(
    x: TensorLikeType, norm: NormType, signal_numel: int, forward: bool
) -> TensorLikeType:
    """Scale an un-normalized FFT result according to ``norm``.

    Args:
        x: raw transform output.
        norm: one of ``None``, ``"forward"``, ``"backward"`` or ``"ortho"``.
        signal_numel: number of signal points the scale factor is based on.
        forward: True for a forward transform, False for an inverse one.

    Returns:
        ``x * (1 / sqrt(signal_numel))`` for ``"ortho"``;
        ``x * (1 / signal_numel)`` when the mode applies to this direction;
        otherwise ``x`` unchanged.
    """
    torch._check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}")
    if norm == "ortho":
        return x * (1 / math.sqrt(signal_numel))
    # "forward" scales forward transforms; "backward" (and the default None)
    # scales inverse transforms.
    scales_forward = norm == "forward"
    scales_inverse = norm is None or norm == "backward"
    if (forward and scales_forward) or (not forward and scales_inverse):
        return x * (1 / signal_numel)
    return x
def _promote_type_fft(
dtype: torch.dtype, require_complex: bool, device: torch.device
) -> torch.dtype:
"""Helper to promote a dtype to one supported by the FFT primitives"""
if dtype.is_complex:
return dtype
# Promote integral to default float type
if not dtype.is_floating_point:
dtype = torch.get_default_dtype()
allowed_types = [torch.float32, torch.float64]
maybe_support_half = device.type in ["cuda", "meta"] and not torch.version.hip
if maybe_support_half:
allowed_types.append(torch.float16)
torch._check(dtype in allowed_types, lambda: f"Unsupported dtype {dtype}")
if require_complex:
dtype = utils.corresponding_complex_dtype(dtype)
return dtype
def _maybe_promote_tensor_fft(
    t: TensorLikeType, require_complex: bool = False
) -> TensorLikeType:
    """Convert ``t`` to a dtype the FFT primitives accept.

    The target dtype is chosen by ``_promote_type_fft`` from ``t``'s dtype and
    device; the conversion is delegated to ``_maybe_convert_to_dtype``.
    """
    target_dtype = _promote_type_fft(t.dtype, require_complex, t.device)
    return _maybe_convert_to_dtype(t, target_dtype)  # type: ignore[return-value]
def _resize_fft_input(
    x: TensorLikeType, dims: Tuple[int, ...], sizes: Tuple[int, ...]
) -> TensorLikeType:
    """
    Make ``x.size(dims[i]) == sizes[i]`` for every i.

    Too-long dimensions keep their leading ``sizes[i]`` elements (via
    ``narrow``); too-short ones are zero-padded at the end. A size of -1
    leaves that dimension untouched. ``dims`` are expected to be non-negative
    (already canonicalized), since they index the pad list directly.
    Returns ``x`` itself when no padding is needed.
    """
    assert len(dims) == len(sizes)
    original_shape = x.shape
    ndim = len(original_shape)
    # constant_pad_nd takes (left, right) pairs starting from the LAST dim.
    padding = [0] * (2 * ndim)
    needs_pad = False
    for d, target in zip(dims, sizes):
        if target == -1:
            continue
        current = original_shape[d]
        if current < target:
            needs_pad = True
            # Right-hand pad slot of dim ``d`` in the reversed pair layout.
            padding[2 * (ndim - d) - 1] = target - current
        if current > target:
            x = x.narrow(d, 0, target)
    if not needs_pad:
        return x
    return torch.constant_pad_nd(x, padding)
def _fft_c2r(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Shared implementation of the 1-D complex-to-real FFTs (irfft and hfft).

    Args:
        func_name: public op name (not used in this helper).
        input: one-sided complex input; promoted to a complex FFT dtype.
        n: real output length along ``dim``; defaults to
            ``2 * (input.shape[dim] - 1)``.
        dim: dimension to transform.
        norm: normalization mode, see ``_apply_norm``.
        forward: True for hfft (input is conjugated first), False for irfft.
    """
    input = _maybe_promote_tensor_fft(input, require_complex=True)
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
    if n is None:
        last_dim_size = 2 * (input.shape[dim] - 1)
    else:
        last_dim_size = n
    torch._check(
        last_dim_size >= 1,
        lambda: f"Invalid number of data points ({last_dim_size}) specified",
    )
    if n is not None:
        # n real outputs are produced from n // 2 + 1 complex inputs.
        input = _resize_fft_input(input, dims=dims, sizes=(last_dim_size // 2 + 1,))

    spectrum = torch.conj(input) if forward else input
    output = prims.fft_c2r(spectrum, dim=dims, last_dim_size=last_dim_size)
    return _apply_norm(output, norm=norm, signal_numel=last_dim_size, forward=forward)
def _fft_r2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
    onesided: bool,
) -> TensorLikeType:
    """Shared implementation of the 1-D real-to-complex FFTs.

    Used by rfft / ihfft (``onesided=True``) and by fft / ifft on real input
    (``onesided=False``).

    Args:
        func_name: public op name, used in the dtype error message.
        input: real input tensor (complex input is rejected).
        n: signal length along ``dim``; the input is zero-padded or truncated
            to it. Defaults to ``input.shape[dim]``.
        dim: dimension to transform.
        norm: normalization mode, see ``_apply_norm``.
        forward: False for the inverse transforms, which return the conjugate.
        onesided: forwarded to ``prims.fft_r2c``.
    """
    torch._check(
        not input.dtype.is_complex,
        lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}",
    )
    input = _maybe_promote_tensor_fft(input)
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
    dim_size = input.shape[dim] if n is None else n
    torch._check(
        dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified"
    )
    if n is not None:
        input = _resize_fft_input(input, dims, (n,))

    raw = prims.fft_r2c(input, dim=dims, onesided=onesided)
    scaled = _apply_norm(raw, norm, dim_size, forward)
    if forward:
        return scaled
    # fft_r2c has no direction flag; for real input the inverse transform is
    # the conjugate of the (separately scaled) forward transform.
    return torch.conj(scaled)
def _fft_c2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Shared implementation of the 1-D complex-to-complex FFTs (fft and ifft).

    Args:
        func_name: public op name, used in the dtype error message.
        input: complex input tensor (real input is rejected).
        n: signal length along ``dim``; the input is zero-padded or truncated
            to it. Defaults to ``input.shape[dim]``.
        dim: dimension to transform.
        norm: normalization mode, see ``_apply_norm``.
        forward: direction of the transform.
    """
    torch._check(
        input.dtype.is_complex,
        lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}",
    )
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
    if n is None:
        dim_size = input.shape[dim]
    else:
        dim_size = n
    torch._check(
        dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified"
    )
    signal = input if n is None else _resize_fft_input(input, dims, (n,))
    transformed = prims.fft_c2c(signal, dim=dims, forward=forward)
    return _apply_norm(transformed, norm, dim_size, forward)
@register_decomposition(aten.fft_fft)
@out_wrapper()
def fft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """1-D forward FFT; real input goes through a two-sided r2c transform."""
    if not input.dtype.is_complex:
        return _fft_r2c("fft", input, n, dim, norm, forward=True, onesided=False)
    return _fft_c2c("fft", input, n, dim, norm, forward=True)
@register_decomposition(aten.fft_ifft)
@out_wrapper()
def ifft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """1-D inverse FFT; real input goes through a two-sided r2c transform."""
    if not input.dtype.is_complex:
        return _fft_r2c("ifft", input, n, dim, norm, forward=False, onesided=False)
    return _fft_c2c("ifft", input, n, dim, norm, forward=False)
@register_decomposition(aten.fft_rfft)
@out_wrapper()
def rfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """1-D forward FFT of real input, returning the one-sided spectrum."""
    return _fft_r2c(
        func_name="rfft",
        input=input,
        n=n,
        dim=dim,
        norm=norm,
        forward=True,
        onesided=True,
    )
@register_decomposition(aten.fft_irfft)
@out_wrapper()
def irfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Inverse of ``rfft``: one-sided complex input to a real output of length n."""
    return _fft_c2r(
        func_name="irfft",
        input=input,
        n=n,
        dim=dim,
        norm=norm,
        forward=False,
    )
@register_decomposition(aten.fft_hfft)
@out_wrapper()
def hfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Forward complex-to-real FFT of a one-sided (Hermitian) input."""
    return _fft_c2r(
        func_name="hfft",
        input=input,
        n=n,
        dim=dim,
        norm=norm,
        forward=True,
    )
@register_decomposition(aten.fft_ihfft)
@out_wrapper()
def ihfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Inverse of ``hfft``: real input to a one-sided complex output."""
    return _fft_r2c(
        func_name="ihfft",
        input=input,
        n=n,
        dim=dim,
        norm=norm,
        forward=False,
        onesided=True,
    )
class _ShapeAndDims(NamedTuple):
shape: Tuple[int, ...]
dims: Tuple[int, ...]
def _canonicalize_fft_shape_and_dim_args(
    input: TensorLikeType, shape: Optional[ShapeType], dim: Optional[DimsType]
) -> _ShapeAndDims:
    """Resolve the optional shape/dim arguments of an n-dimensional FFT.

    Rules:
      * neither given: transform every dimension at its current size;
      * only ``dim`` given: transform those dimensions at their current size;
      * ``shape`` given: ``dim`` defaults to the last ``len(shape)``
        dimensions, and a ``-1`` entry in ``shape`` keeps the current size.

    A scalar is accepted for either argument. Raises (via ``torch._check``)
    for duplicate dims, mismatched ``dim``/``shape`` lengths, more shape
    entries than input dims, or any resulting size that is not positive.
    """
    input_dim = input.ndim
    sizes = input.shape

    if dim is not None:
        dim = dim if isinstance(dim, Sequence) else (dim,)
        ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False)
        torch._check(
            len(set(ret_dims)) == len(ret_dims), lambda: "FFT dims must be unique"
        )

    if shape is None:
        if dim is None:
            ret_dims = tuple(range(input_dim))
            ret_shape = tuple(sizes)
        else:
            ret_shape = tuple(sizes[d] for d in ret_dims)
    else:
        shape = shape if isinstance(shape, Sequence) else (shape,)
        torch._check(
            dim is None or len(dim) == len(shape),
            lambda: "When given, dim and shape arguments must have the same length",
        )
        transform_ndim = len(shape)
        torch._check(
            transform_ndim <= input_dim,
            lambda: f"Got shape with {transform_ndim} values but input tensor "
            f"only has {input_dim} dimensions.",
        )
        if dim is None:
            # Default to the trailing len(shape) dimensions.
            ret_dims = tuple(range(input_dim - transform_ndim, input_dim))
        # A -1 entry keeps that dimension's current length.
        ret_shape = tuple(
            sizes[d] if s == -1 else s for s, d in zip(shape, ret_dims)
        )

    for n in ret_shape:
        torch._check(n > 0, lambda: f"Invalid number of data points ({n}) specified")

    return _ShapeAndDims(shape=ret_shape, dims=ret_dims)
def _prod(xs: Iterable[int]) -> int:
"""Compute product of a list"""
prod = 1
for x in xs:
prod *= x
return prod
def _fftn_c2c(
    function_name: str,
    input: TensorLikeType,
    shape: Tuple[int, ...],
    dim: Tuple[int, ...],
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Shared implementation of fftn / ifftn for complex input.

    ``shape`` and ``dim`` are the already-canonicalized arguments. ``input``
    is resized to ``shape`` along ``dim`` and transformed in one c2c call.
    The result is normalized over ``prod(shape)`` points.
    """
    torch._check(
        input.dtype.is_complex,
        lambda: f"{function_name} expects a complex input tensor, "
        f"but got {input.dtype}",
    )
    resized = _resize_fft_input(input, dim, shape)
    transformed = prims.fft_c2c(resized, dim=dim, forward=forward)
    signal_numel = _prod(shape)
    return _apply_norm(transformed, norm=norm, signal_numel=signal_numel, forward=forward)
@register_decomposition(aten.fft_fftn)
@out_wrapper()
def fftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """n-dimensional forward FFT; real input is promoted to complex first."""
    canonical = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    promoted = _maybe_promote_tensor_fft(input, require_complex=True)
    return _fftn_c2c(
        "fftn", promoted, canonical.shape, canonical.dims, norm, forward=True
    )
@register_decomposition(aten.fft_ifftn)
@out_wrapper()
def ifftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """n-dimensional inverse FFT; real input is promoted to complex first."""
    canonical = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    promoted = _maybe_promote_tensor_fft(input, require_complex=True)
    return _fftn_c2c(
        "ifftn", promoted, canonical.shape, canonical.dims, norm, forward=False
    )
@register_decomposition(aten.fft_rfftn)
@out_wrapper()
def rfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """n-dimensional forward FFT of real input, one-sided in a single r2c call.

    Raises (via ``torch._check``) for complex input.
    """
    torch._check(
        not input.dtype.is_complex,
        lambda: f"rfftn expects a real-valued input tensor, but got {input.dtype}",
    )
    shape, dims = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    promoted = _maybe_promote_tensor_fft(input, require_complex=False)
    resized = _resize_fft_input(promoted, dims, shape)
    spectrum = prims.fft_r2c(resized, dim=dims, onesided=True)
    return _apply_norm(spectrum, norm=norm, signal_numel=_prod(shape), forward=True)
@register_decomposition(aten.fft_ihfftn)
@out_wrapper()
def ihfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """n-dimensional inverse of ``hfftn``: real input, one-sided complex output.

    The last transformed dim uses a one-sided r2c transform. Any remaining
    dims use an inverse c2c transform on the conjugated intermediate.
    Raises (via ``torch._check``) for complex input or when no dims are
    transformed.
    """
    torch._check(
        not input.dtype.is_complex,
        lambda: f"ihfftn expects a real-valued input tensor, but got {input.dtype}",
    )
    shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    torch._check(len(shape) > 0, lambda: "ihfftn must transform at least one axis")
    signal = _maybe_promote_tensor_fft(input, require_complex=False)
    signal = _resize_fft_input(signal, dim, shape)

    half = prims.fft_r2c(signal, dim=dim[-1:], onesided=True)

    if len(dim) == 1:
        half = _apply_norm(half, norm=norm, signal_numel=shape[0], forward=False)
        return prims.conj(half)

    full = prims.fft_c2c(prims.conj_physical(half), dim=dim[:-1], forward=False)
    return _apply_norm(full, norm=norm, signal_numel=_prod(shape), forward=False)
class _CanonicalizeC2rReturn(NamedTuple):
shape: Tuple[int, ...]
dim: Tuple[int, ...]
last_dim_size: int
def _canonicalize_fft_c2r_shape_and_dim_args(
    fname: str,
    input: TensorLikeType,
    s: Optional[ShapeType],
    dim: Optional[DimsType],
) -> _CanonicalizeC2rReturn:
    """Canonicalize shape and dim arguments for n-dimensional c2r transforms.

    Also computes ``last_dim_size``, the real output length along ``dim[-1]``.
    It is ``s[-1]`` when given, otherwise ``2 * (input.shape[dim[-1]] - 1)``.
    The returned ``shape`` is the size each transformed dim of the complex
    input is resized to, so its last entry is ``last_dim_size // 2 + 1``.

    Args:
        fname: public op name, used in error messages.
        input: complex input tensor.
        s: requested output signal sizes. A scalar is treated as a 1-tuple,
            and a ``-1`` entry selects the default size.
        dim: dimensions to transform.

    Raises (via ``torch._check``) when no dimension is transformed or the
    resulting ``last_dim_size`` is < 1.
    """
    # _canonicalize_fft_shape_and_dim_args accepts a scalar shape. Normalize it
    # here too, otherwise the ``s[-1]`` lookup below raises TypeError on an int.
    if s is not None and not isinstance(s, Sequence):
        s = (s,)
    (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    torch._check(len(shape) > 0, lambda: f"{fname} must transform at least one axis")

    # ``shape`` already has -1 replaced by the input size, so the raw ``s``
    # must be consulted to detect a defaulted last dimension.
    if s is None or s[-1] == -1:
        last_dim_size = 2 * (input.shape[dim[-1]] - 1)
    else:
        last_dim_size = shape[-1]

    torch._check(
        last_dim_size >= 1,
        lambda: f"Invalid number of data points ({last_dim_size}) specified",
    )

    # The complex input only needs the one-sided half along the last dim.
    shape_list = list(shape)
    shape_list[-1] = last_dim_size // 2 + 1
    return _CanonicalizeC2rReturn(
        shape=tuple(shape_list), dim=dim, last_dim_size=last_dim_size
    )
@register_decomposition(aten.fft_irfftn)
@out_wrapper()
def irfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """n-dimensional inverse of ``rfftn``: one-sided complex input, real output."""
    canonical = _canonicalize_fft_c2r_shape_and_dim_args("irfftn", input, s, dim)
    signal = _maybe_promote_tensor_fft(input, require_complex=True)
    signal = _resize_fft_input(signal, canonical.dim, canonical.shape)
    out = prims.fft_c2r(
        signal, dim=canonical.dim, last_dim_size=canonical.last_dim_size
    )
    # Normalize over the number of real output points in the transformed dims.
    signal_numel = _prod(out.shape[d] for d in canonical.dim)
    return _apply_norm(out, norm, signal_numel, forward=False)
@register_decomposition(aten.fft_hfftn)
@out_wrapper()
def hfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """n-dimensional forward FFT of a one-sided (Hermitian) input, real output.

    All but the last transformed dim use a forward c2c transform. The last dim
    is conjugated and then turned into a real output with a c2r transform.
    """
    shape, dims, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args(
        "hfftn", input, s, dim
    )
    x = _maybe_promote_tensor_fft(input, require_complex=True)
    x = _resize_fft_input(x, dims, shape)

    if len(dims) > 1:
        x = prims.fft_c2c(x, dim=dims[:-1], forward=True)
    x = _apply_norm(x, norm, _prod(shape[:-1]), forward=True)

    x = prims.conj_physical(x)
    out = prims.fft_c2r(x, dim=dims[-1:], last_dim_size=last_dim_size)
    return _apply_norm(out, norm, last_dim_size, forward=True)
@register_decomposition(aten.fft_fft2)
@out_wrapper()
def fft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """2-D forward FFT: ``torch.fft.fftn`` over dims (-2, -1) by default."""
    return torch.fft.fftn(input, norm=norm, dim=dim, s=s)
@register_decomposition(aten.fft_ifft2)
@out_wrapper()
def ifft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """2-D inverse FFT: ``torch.fft.ifftn`` over dims (-2, -1) by default."""
    return torch.fft.ifftn(input, norm=norm, dim=dim, s=s)
@register_decomposition(aten.fft_rfft2)
@out_wrapper()
def rfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """2-D real-input FFT: ``torch.fft.rfftn`` over dims (-2, -1) by default."""
    return torch.fft.rfftn(input, norm=norm, dim=dim, s=s)
@register_decomposition(aten.fft_irfft2)
@out_wrapper()
def irfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """2-D inverse of ``rfft2``: ``torch.fft.irfftn`` over dims (-2, -1) by default."""
    return torch.fft.irfftn(input, norm=norm, dim=dim, s=s)
@register_decomposition(aten.fft_hfft2)
@out_wrapper()
def hfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """2-D Hermitian-input FFT: ``torch.fft.hfftn`` over dims (-2, -1) by default."""
    return torch.fft.hfftn(input, norm=norm, dim=dim, s=s)
@register_decomposition(aten.fft_ihfft2)
@out_wrapper()
def ihfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """2-D inverse of ``hfft2``: ``torch.fft.ihfftn`` over dims (-2, -1) by default."""
    return torch.fft.ihfftn(input, norm=norm, dim=dim, s=s)
def _default_alldims(dim: Optional[DimsType], x: TensorLikeType) -> List[int]:
    """Normalize ``dim`` to a list.

    ``None`` selects every dimension of ``x``, a sequence is copied into a
    list, and a scalar becomes a one-element list.
    """
    if dim is None:
        return [*range(x.ndim)]
    if isinstance(dim, Sequence):
        return [*dim]
    return [dim]
@register_decomposition(aten.fft_fftshift)
def fftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
    """Roll each selected dim (default: all) by ``size // 2``.

    This moves index 0 of every selected dimension to index ``size // 2``.
    """
    dims = _default_alldims(dim, input)
    return torch.roll(input, [input.shape[d] // 2 for d in dims], dims)
@register_decomposition(aten.fft_ifftshift)
def ifftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
    """Roll each selected dim (default: all) by ``(size + 1) // 2``.

    Since ``size // 2 + (size + 1) // 2 == size``, this undoes ``fftshift``
    for both even and odd sizes.
    """
    dims = _default_alldims(dim, input)
    return torch.roll(input, [(input.shape[d] + 1) // 2 for d in dims], dims)