mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
PR #108360 uses the same default `last_dim_size` formula from complex-to-real (C2R) transforms for complex-to-complex (C2C) and real-to-complex (R2C) transforms. However, this is not correct: for C2R the input is only half the size of the full tensor, which is not the case for C2C and R2C. This error is mostly benign, since `last_dim_size` was only used for the `>= 1` condition, which is almost always met anyway. In this PR I now pass it as the argument to `_apply_norm`, which makes it load-bearing for correctness, so it is now thoroughly tested. Pull Request resolved: https://github.com/pytorch/pytorch/pull/109083 Approved by: https://github.com/lezcano
591 lines
18 KiB
Python
591 lines
18 KiB
Python
import math
|
|
|
|
from typing import Iterable, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union
|
|
|
|
import torch
|
|
import torch._prims as prims
|
|
import torch._prims_common as utils
|
|
from torch._decomp import register_decomposition
|
|
from torch._prims_common import DimsType, ShapeType, TensorLikeType
|
|
from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper
|
|
|
|
# Names exported by ``from ... import *``: the FFT transforms and shift helpers.
__all__ = [
    # Transforms
    "fft",
    "fft2",
    "fftn",
    "hfft",
    "hfft2",
    "hfftn",
    "rfft",
    "rfft2",
    "rfftn",
    "ifft",
    "ifft2",
    "ifftn",
    "ihfft",
    "ihfft2",
    "ihfftn",
    "irfft",
    "irfft2",
    "irfftn",
    # Helpers
    "fftshift",
    "ifftshift",
]

# Static type of the ``norm`` argument accepted by every transform in this module.
NormType = Union[None, Literal["forward"], Literal["backward"], Literal["ortho"]]
# Runtime counterpart of ``NormType``; ``_apply_norm`` validates ``norm`` against it.
_NORM_VALUES = {None, "forward", "backward", "ortho"}
# ATen operator namespace; selects the op each decomposition below is registered for.
aten = torch._ops.ops.aten
|
|
|
|
|
|
def _apply_norm(
    x: TensorLikeType, norm: NormType, signal_numel: int, forward: bool
) -> TensorLikeType:
    """Scale an un-normalized FFT result according to ``norm``.

    Args:
        x: raw (un-normalized) transform output.
        norm: one of ``None``, ``"forward"``, ``"backward"`` or ``"ortho"``.
        signal_numel: number of points in the transformed signal.
        forward: whether ``x`` came from a forward transform.

    Returns:
        ``x * (1 / sqrt(signal_numel))`` for ``"ortho"``. ``x * (1 / signal_numel)``
        when the mode puts the full scaling on this direction: ``"forward"``
        scales forward transforms, while ``None``/``"backward"`` scale inverse
        ones. ``x`` unchanged otherwise.

    Raises:
        RuntimeError: if ``norm`` is not a recognised mode.
    """
    torch._check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}")

    if norm == "ortho":
        return x * (1 / math.sqrt(signal_numel))

    # Only "forward" attaches the 1/n factor to the forward direction;
    # None and "backward" attach it to the inverse direction instead.
    scales_forward_direction = norm == "forward"
    if bool(forward) != scales_forward_direction:
        return x
    return x * (1 / signal_numel)
|
|
|
|
|
|
def _promote_type_fft(
|
|
dtype: torch.dtype, require_complex: bool, device: torch.device
|
|
) -> torch.dtype:
|
|
"""Helper to promote a dtype to one supported by the FFT primitives"""
|
|
if dtype.is_complex:
|
|
return dtype
|
|
|
|
# Promote integral to default float type
|
|
if not dtype.is_floating_point:
|
|
dtype = torch.get_default_dtype()
|
|
|
|
allowed_types = [torch.float32, torch.float64]
|
|
maybe_support_half = device.type in ["cuda", "meta"] and not torch.version.hip
|
|
|
|
if maybe_support_half:
|
|
allowed_types.append(torch.float16)
|
|
torch._check(dtype in allowed_types, lambda: f"Unsupported dtype {dtype}")
|
|
|
|
if require_complex:
|
|
dtype = utils.corresponding_complex_dtype(dtype)
|
|
|
|
return dtype
|
|
|
|
|
|
def _maybe_promote_tensor_fft(
    t: TensorLikeType, require_complex: bool = False
) -> TensorLikeType:
    """Convert ``t`` to the dtype chosen by ``_promote_type_fft``.

    ``t``'s own dtype and device drive the choice. The conversion itself is
    delegated to ``_maybe_convert_to_dtype``.
    """
    target_dtype = _promote_type_fft(t.dtype, require_complex, t.device)
    return _maybe_convert_to_dtype(t, target_dtype)  # type: ignore[return-value]
|
|
|
|
|
|
def _resize_fft_input(
|
|
x: TensorLikeType, dims: Tuple[int, ...], sizes: Tuple[int, ...]
|
|
) -> TensorLikeType:
|
|
"""
|
|
Fixes the shape of x such that x.size(dims[i]) == sizes[i],
|
|
either by zero-padding, or by slicing x starting from 0.
|
|
"""
|
|
assert len(dims) == len(sizes)
|
|
must_copy = False
|
|
x_sizes = x.shape
|
|
pad_amount = [0] * len(x_sizes) * 2
|
|
for i in range(len(dims)):
|
|
if sizes[i] == -1:
|
|
continue
|
|
|
|
if x_sizes[dims[i]] < sizes[i]:
|
|
must_copy = True
|
|
pad_idx = len(pad_amount) - 2 * dims[i] - 1
|
|
pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]]
|
|
|
|
if x_sizes[dims[i]] > sizes[i]:
|
|
x = x.narrow(dims[i], 0, sizes[i])
|
|
|
|
return torch.constant_pad_nd(x, pad_amount) if must_copy else x
|
|
|
|
|
|
def _fft_c2r(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Shared 1-D complex-to-real transform behind ``irfft`` and ``hfft``.

    Args:
        func_name: public op name (currently unused here).
        input: one-sided complex input; real inputs are promoted to complex.
        n: length of the real output along ``dim``. When omitted it is
            ``2 * (input.shape[dim] - 1)``. When given, the input is resized
            to ``n // 2 + 1`` along ``dim``.
        dim: dimension to transform.
        norm: normalization mode, applied with ``n`` as the signal length.
        forward: ``True`` for ``hfft``; the input is then conjugated first.

    Raises:
        RuntimeError: if the output length is less than 1.
    """
    input = _maybe_promote_tensor_fft(input, require_complex=True)
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)

    if n is None:
        last_dim_size = 2 * (input.shape[dim] - 1)
    else:
        last_dim_size = n
    torch._check(
        last_dim_size >= 1,
        lambda: f"Invalid number of data points ({last_dim_size}) specified",
    )

    if n is not None:
        # Only the non-redundant half of the Hermitian signal is stored.
        input = _resize_fft_input(input, dims=dims, sizes=(last_dim_size // 2 + 1,))

    # A forward transform of a Hermitian signal equals the (unnormalized)
    # inverse c2r transform of its conjugate.
    source = torch.conj(input) if forward else input
    real_out = prims.fft_c2r(source, dim=dims, last_dim_size=last_dim_size)
    return _apply_norm(real_out, norm=norm, signal_numel=last_dim_size, forward=forward)
|
|
|
|
|
|
def _fft_r2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
    onesided: bool,
) -> TensorLikeType:
    """Shared 1-D real-to-complex transform.

    Used by ``rfft``/``ihfft`` (``onesided=True``) and by ``fft``/``ifft`` for
    real inputs (``onesided=False``).

    Args:
        func_name: public op name, used in the error message.
        input: non-complex input; non-floating dtypes are promoted.
        n: signal length along ``dim``; the input is truncated or
            zero-padded to it. Defaults to ``input.shape[dim]``.
        dim: dimension to transform.
        norm: normalization mode forwarded to ``_apply_norm``.
        forward: ``False`` for the inverse direction; the result is then
            conjugated.
        onesided: passed through to ``prims.fft_r2c``.

    Raises:
        RuntimeError: for complex input or a signal length below 1.
    """
    torch._check(
        not input.dtype.is_complex,
        lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}",
    )
    input = _maybe_promote_tensor_fft(input)
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
    dim_size = input.shape[dim] if n is None else n
    torch._check(
        dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified"
    )

    if n is not None:
        input = _resize_fft_input(input, dims, (n,))

    spectrum = prims.fft_r2c(input, dim=dims, onesided=onesided)
    spectrum = _apply_norm(spectrum, norm, dim_size, forward)
    if forward:
        return spectrum
    # For a real signal the inverse transform is the conjugate of the forward one.
    return torch.conj(spectrum)
|
|
|
|
|
|
def _fft_c2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Shared 1-D complex-to-complex transform behind ``fft`` and ``ifft``.

    Args:
        func_name: public op name, used in the error message.
        input: complex input tensor.
        n: signal length along ``dim``; the input is truncated or
            zero-padded to it. Defaults to ``input.shape[dim]``.
        dim: dimension to transform.
        norm: normalization mode forwarded to ``_apply_norm``.
        forward: direction of the transform.

    Raises:
        RuntimeError: for non-complex input or a signal length below 1.
    """
    torch._check(
        input.dtype.is_complex,
        lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}",
    )
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
    dim_size = input.shape[dim] if n is None else n
    torch._check(
        dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified"
    )

    if n is not None:
        input = _resize_fft_input(input, dims, (n,))

    unnormalized = prims.fft_c2c(input, dim=dims, forward=forward)
    return _apply_norm(unnormalized, norm, dim_size, forward)
|
|
|
|
|
|
@register_decomposition(aten.fft_fft)
@out_wrapper()
def fft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Decomposition of ``aten.fft_fft``: 1-D forward FFT.

    Real inputs take the real-to-complex path with a two-sided output.
    """
    if not input.dtype.is_complex:
        return _fft_r2c("fft", input, n, dim, norm, forward=True, onesided=False)
    return _fft_c2c("fft", input, n, dim, norm, forward=True)
|
|
|
|
|
|
@register_decomposition(aten.fft_ifft)
@out_wrapper()
def ifft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Decomposition of ``aten.fft_ifft``: 1-D inverse FFT.

    Real inputs take the real-to-complex path with a two-sided output.
    """
    if not input.dtype.is_complex:
        return _fft_r2c("ifft", input, n, dim, norm, forward=False, onesided=False)
    return _fft_c2c("ifft", input, n, dim, norm, forward=False)
|
|
|
|
|
|
@register_decomposition(aten.fft_rfft)
@out_wrapper()
def rfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Decomposition of ``aten.fft_rfft``: one-sided forward real-to-complex FFT."""
    return _fft_r2c(
        "rfft", input, n=n, dim=dim, norm=norm, forward=True, onesided=True
    )
|
|
|
|
|
|
@register_decomposition(aten.fft_irfft)
@out_wrapper()
def irfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Decomposition of ``aten.fft_irfft``: inverse of ``rfft`` (complex-to-real)."""
    return _fft_c2r("irfft", input, n=n, dim=dim, norm=norm, forward=False)
|
|
|
|
|
|
@register_decomposition(aten.fft_hfft)
@out_wrapper()
def hfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Decomposition of ``aten.fft_hfft``: forward complex-to-real FFT."""
    return _fft_c2r("hfft", input, n=n, dim=dim, norm=norm, forward=True)
|
|
|
|
|
|
@register_decomposition(aten.fft_ihfft)
@out_wrapper()
def ihfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Decomposition of ``aten.fft_ihfft``: one-sided inverse real-to-complex FFT."""
    return _fft_r2c(
        "ihfft", input, n=n, dim=dim, norm=norm, forward=False, onesided=True
    )
|
|
|
|
|
|
class _ShapeAndDims(NamedTuple):
|
|
shape: Tuple[int, ...]
|
|
dims: Tuple[int, ...]
|
|
|
|
|
|
def _canonicalize_fft_shape_and_dim_args(
    input: TensorLikeType, shape: Optional[ShapeType], dim: Optional[DimsType]
) -> _ShapeAndDims:
    """Resolve the optional ``shape``/``dim`` arguments of an n-D FFT.

    Rules:
      * ``dim`` given: a scalar is wrapped into a tuple, the dims are
        canonicalized, and no dimension may repeat.
      * ``shape`` given: a scalar is wrapped into a tuple. Its length must
        equal ``len(dim)`` when ``dim`` is given and cannot exceed
        ``input.ndim``. Without ``dim`` it applies to the trailing
        dimensions. Entries equal to ``-1`` take the input's current size.
      * neither given: every dimension is transformed at its current size.
      * only ``dim`` given: those dimensions keep their current sizes.

    Every resulting size must be positive.

    Returns:
        ``_ShapeAndDims(shape=..., dims=...)`` with both fields as tuples.

    Raises:
        RuntimeError: when any of the checks above fails.
    """
    input_dim = input.ndim
    current_sizes = input.shape

    if dim is not None:
        if not isinstance(dim, Sequence):
            dim = (dim,)
        ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False)
        # Transforming the same axis twice is rejected.
        torch._check(
            len(set(ret_dims)) == len(ret_dims), lambda: "FFT dims must be unique"
        )

    if shape is None:
        if dim is None:
            ret_dims = tuple(range(input_dim))
            ret_shape = tuple(current_sizes)
        else:
            ret_shape = tuple(current_sizes[d] for d in ret_dims)
    else:
        if not isinstance(shape, Sequence):
            shape = (shape,)
        torch._check(
            dim is None or len(dim) == len(shape),
            lambda: "When given, dim and shape arguments must have the same length",
        )
        transform_ndim = len(shape)
        torch._check(
            transform_ndim <= input_dim,
            lambda: f"Got shape with {transform_ndim} values but input tensor "
            f"only has {input_dim} dimensions.",
        )
        if dim is None:
            # Default to the trailing ``len(shape)`` dimensions.
            ret_dims = tuple(range(input_dim - transform_ndim, input_dim))
        ret_shape = tuple(
            current_sizes[d] if s == -1 else s for s, d in zip(shape, ret_dims)
        )

    for n in ret_shape:
        torch._check(n > 0, lambda: f"Invalid number of data points ({n}) specified")

    return _ShapeAndDims(shape=ret_shape, dims=ret_dims)
|
|
|
|
|
|
def _prod(xs: Iterable[int]) -> int:
|
|
"""Compute product of a list"""
|
|
prod = 1
|
|
for x in xs:
|
|
prod *= x
|
|
return prod
|
|
|
|
|
|
def _fftn_c2c(
    function_name: str,
    input: TensorLikeType,
    shape: Tuple[int, ...],
    dim: Tuple[int, ...],
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Shared n-D complex-to-complex transform behind ``fftn`` and ``ifftn``.

    ``input`` is resized so that ``input.size(dim[i]) == shape[i]``. It is then
    transformed over all of ``dim`` and normalized with ``prod(shape)`` as the
    signal size.

    Raises:
        RuntimeError: if ``input`` is not complex.
    """
    torch._check(
        input.dtype.is_complex,
        lambda: f"{function_name} expects a complex input tensor, "
        f"but got {input.dtype}",
    )
    resized = _resize_fft_input(input, dim, shape)
    unnormalized = prims.fft_c2c(resized, dim=dim, forward=forward)
    return _apply_norm(
        unnormalized, norm=norm, signal_numel=_prod(shape), forward=forward
    )
|
|
|
|
|
|
@register_decomposition(aten.fft_fftn)
@out_wrapper()
def fftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Decomposition of ``aten.fft_fftn``: n-D forward FFT.

    Real inputs are promoted to complex before the transform.
    """
    shape, dims = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    promoted = _maybe_promote_tensor_fft(input, require_complex=True)
    return _fftn_c2c("fftn", promoted, shape, dims, norm, forward=True)
|
|
|
|
|
|
@register_decomposition(aten.fft_ifftn)
@out_wrapper()
def ifftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Decomposition of ``aten.fft_ifftn``: n-D inverse FFT.

    Real inputs are promoted to complex before the transform.
    """
    shape, dims = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    promoted = _maybe_promote_tensor_fft(input, require_complex=True)
    return _fftn_c2c("ifftn", promoted, shape, dims, norm, forward=False)
|
|
|
|
|
|
@register_decomposition(aten.fft_rfftn)
@out_wrapper()
def rfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Decomposition of ``aten.fft_rfftn``: one-sided n-D real-to-complex FFT.

    Raises:
        RuntimeError: if ``input`` is complex.
    """
    torch._check(
        not input.dtype.is_complex,
        lambda: f"rfftn expects a real-valued input tensor, but got {input.dtype}",
    )
    shape, dims = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    x = _maybe_promote_tensor_fft(input, require_complex=False)
    x = _resize_fft_input(x, dims, shape)
    unnormalized = prims.fft_r2c(x, dim=dims, onesided=True)
    return _apply_norm(unnormalized, norm=norm, signal_numel=_prod(shape), forward=True)
|
|
|
|
|
|
@register_decomposition(aten.fft_ihfftn)
@out_wrapper()
def ihfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Decomposition of ``aten.fft_ihfftn``: one-sided n-D inverse real-to-complex FFT.

    First a one-sided real-to-complex transform runs over the last requested
    dimension. Then an inverse complex-to-complex transform runs over the
    remaining ones, if any.

    Raises:
        RuntimeError: for complex input or when no axis is transformed.
    """
    torch._check(
        not input.dtype.is_complex,
        lambda: f"ihfftn expects a real-valued input tensor, but got {input.dtype}",
    )
    shape, dims = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    torch._check(len(shape) > 0, lambda: "ihfftn must transform at least one axis")
    x = _maybe_promote_tensor_fft(input, require_complex=False)
    x = _resize_fft_input(x, dims, shape)

    half_spectrum = prims.fft_r2c(x, dim=dims[-1:], onesided=True)

    if len(dims) > 1:
        # The unnormalized inverse transform of conj(y) equals conj(forward(y)).
        # Conjugating before the remaining c2c pass therefore yields the
        # conjugated result directly, with no final conj needed.
        half_spectrum = prims.conj_physical(half_spectrum)
        result = prims.fft_c2c(half_spectrum, dim=dims[:-1], forward=False)
        return _apply_norm(result, norm=norm, signal_numel=_prod(shape), forward=False)

    half_spectrum = _apply_norm(
        half_spectrum, norm=norm, signal_numel=shape[0], forward=False
    )
    return prims.conj(half_spectrum)
|
|
|
|
|
|
class _CanonicalizeC2rReturn(NamedTuple):
|
|
shape: Tuple[int, ...]
|
|
dim: Tuple[int, ...]
|
|
last_dim_size: int
|
|
|
|
|
|
def _canonicalize_fft_c2r_shape_and_dim_args(
    fname: str,
    input: TensorLikeType,
    s: Optional[ShapeType],
    dim: Optional[DimsType],
) -> _CanonicalizeC2rReturn:
    """Canonicalize ``s``/``dim`` for an n-D complex-to-real transform.

    Besides the usual canonicalization, this computes ``last_dim_size``, the
    length of the *output* along ``dim[-1]``. It defaults to
    ``2 * (input.shape[dim[-1]] - 1)`` when ``s`` does not fix it, i.e. when
    ``s`` is omitted or its last entry is ``-1``. The returned ``shape`` is
    what the *input* must be resized to, so its last entry is
    ``last_dim_size // 2 + 1``.

    Like ``_canonicalize_fft_shape_and_dim_args``, a bare integer ``s`` is
    accepted and treated as a one-element shape.

    Args:
        fname: public op name, used in the error message.
        input: complex (one-sided) input tensor.
        s: requested output sizes, or ``None``.
        dim: dimensions to transform, or ``None``.

    Returns:
        ``_CanonicalizeC2rReturn(shape, dim, last_dim_size)``.

    Raises:
        RuntimeError: if no axis is transformed or ``last_dim_size < 1``.
    """
    # Wrap a scalar ``s`` the same way the generic helper does; without this,
    # ``s[-1]`` below raises a TypeError even though the helper accepted ``s``.
    if s is not None and not isinstance(s, Sequence):
        s = (s,)

    (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    torch._check(len(shape) > 0, lambda: f"{fname} must transform at least one axis")

    if s is None or s[-1] == -1:
        # Output length not given: treat the input as the one-sided half of
        # an even-length Hermitian signal along the last transformed dim.
        last_dim_size = 2 * (input.shape[dim[-1]] - 1)
    else:
        last_dim_size = shape[-1]

    torch._check(
        last_dim_size >= 1,
        lambda: f"Invalid number of data points ({last_dim_size}) specified",
    )

    # The input only stores the non-redundant half along the last dim.
    shape_list = list(shape)
    shape_list[-1] = last_dim_size // 2 + 1
    return _CanonicalizeC2rReturn(
        shape=tuple(shape_list), dim=dim, last_dim_size=last_dim_size
    )
|
|
|
|
|
|
@register_decomposition(aten.fft_irfftn)
@out_wrapper()
def irfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Decomposition of ``aten.fft_irfftn``: n-D inverse of ``rfftn`` (complex-to-real)."""
    shape, dims, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args(
        "irfftn", input, s, dim
    )
    x = _maybe_promote_tensor_fft(input, require_complex=True)
    x = _resize_fft_input(x, dims, shape)
    real_out = prims.fft_c2r(x, dim=dims, last_dim_size=last_dim_size)
    # Normalize by the number of points in the real output signal.
    signal_numel = _prod(real_out.shape[d] for d in dims)
    return _apply_norm(real_out, norm, signal_numel, forward=False)
|
|
|
|
|
|
@register_decomposition(aten.fft_hfftn)
@out_wrapper()
def hfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Decomposition of ``aten.fft_hfftn``: n-D forward complex-to-real FFT.

    A forward c2c transform runs over every dimension but the last. The
    result is then conjugated and passed through a c2r transform over the
    last dimension. Normalization is applied separately to each stage.
    """
    shape, dims, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args(
        "hfftn", input, s, dim
    )
    x = _maybe_promote_tensor_fft(input, require_complex=True)
    x = _resize_fft_input(x, dims, shape)

    leading_dims = dims[:-1]
    if len(dims) > 1:
        x = prims.fft_c2c(x, dim=leading_dims, forward=True)
    x = _apply_norm(x, norm, _prod(shape[:-1]), forward=True)

    x = prims.conj_physical(x)
    real_out = prims.fft_c2r(x, dim=dims[-1:], last_dim_size=last_dim_size)
    return _apply_norm(real_out, norm, last_dim_size, forward=True)
|
|
|
|
|
|
@register_decomposition(aten.fft_fft2)
@out_wrapper()
def fft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """Decomposition of ``aten.fft_fft2``: ``fftn`` defaulting to the last two dims."""
    return torch.fft.fftn(input, s=s, dim=dim, norm=norm)
|
|
|
|
|
|
@register_decomposition(aten.fft_ifft2)
@out_wrapper()
def ifft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """Decomposition of ``aten.fft_ifft2``: ``ifftn`` defaulting to the last two dims."""
    return torch.fft.ifftn(input, s=s, dim=dim, norm=norm)
|
|
|
|
|
|
@register_decomposition(aten.fft_rfft2)
@out_wrapper()
def rfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """Decomposition of ``aten.fft_rfft2``: ``rfftn`` defaulting to the last two dims."""
    return torch.fft.rfftn(input, s=s, dim=dim, norm=norm)
|
|
|
|
|
|
@register_decomposition(aten.fft_irfft2)
@out_wrapper()
def irfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """Decomposition of ``aten.fft_irfft2``: ``irfftn`` defaulting to the last two dims."""
    return torch.fft.irfftn(input, s=s, dim=dim, norm=norm)
|
|
|
|
|
|
@register_decomposition(aten.fft_hfft2)
@out_wrapper()
def hfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """Decomposition of ``aten.fft_hfft2``: ``hfftn`` defaulting to the last two dims."""
    return torch.fft.hfftn(input, s=s, dim=dim, norm=norm)
|
|
|
|
|
|
@register_decomposition(aten.fft_ihfft2)
@out_wrapper()
def ihfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """Decomposition of ``aten.fft_ihfft2``: ``ihfftn`` defaulting to the last two dims."""
    return torch.fft.ihfftn(input, s=s, dim=dim, norm=norm)
|
|
|
|
|
|
def _default_alldims(dim: Optional[DimsType], x: TensorLikeType) -> List[int]:
|
|
"""Convert Optional[DimsType] to a simple list, defaulting to all dimensions"""
|
|
if dim is None:
|
|
return list(range(x.ndim))
|
|
elif not isinstance(dim, Sequence):
|
|
return [dim]
|
|
else:
|
|
return list(dim)
|
|
|
|
|
|
@register_decomposition(aten.fft_fftshift)
def fftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
    """Decomposition of ``aten.fft_fftshift``.

    Rolls each selected dimension (all of them when ``dim`` is ``None``)
    forward by half its length, ``L // 2``.
    """
    dims = _default_alldims(dim, input)
    return torch.roll(input, [input.shape[d] // 2 for d in dims], dims)
|
|
|
|
|
|
@register_decomposition(aten.fft_ifftshift)
def ifftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
    """Decomposition of ``aten.fft_ifftshift``.

    Rolls each selected dimension (all of them when ``dim`` is ``None``)
    forward by ``(L + 1) // 2``. Because ``L // 2 + (L + 1) // 2 == L``, this
    undoes ``fftshift`` for both even and odd lengths.
    """
    dims = _default_alldims(dim, input)
    return torch.roll(input, [(input.shape[d] + 1) // 2 for d in dims], dims)
|