mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
RFC: https://github.com/pytorch/rfcs/pull/54 First commit is the contents of https://github.com/Quansight-Labs/numpy_pytorch_interop/ We have already been using this in core for the last few months as an external dependency. This PR pulls all of it into core. In the next commits, I do a number of things in this order - Fix a few small issues - Make the tests that this PR adds pass - Bend backwards until lintrunner passes - Remove the optional dependency on `torch_np` and simply rely on the upstreamed code - Fix a number of dynamo tests that were passing before (they were not testing anything, I think) and are not passing now. Missing from this PR (but not blocking): - Have a flag that deactivates tracing NumPy functions and simply breaks. There used to be one, but after the merge it stopped working and I removed it. @lezcano to investigate. - https://github.com/pytorch/pytorch/pull/106431#issuecomment-1667079543. @voznesenskym to submit a fix after we merge. All the tests in `tests/torch_np` take about 75s to run. This was work by @ev-br, @rgommers, @honno and me. I did not create this PR via ghstack (which would have been convenient) as this is a collaboration, and ghstack doesn't allow for shared contributions. Pull Request resolved: https://github.com/pytorch/pytorch/pull/106211 Approved by: https://github.com/ezyang
129 lines
2.7 KiB
Python
129 lines
2.7 KiB
Python
from __future__ import annotations
|
|
|
|
import functools
|
|
|
|
import torch
|
|
|
|
from . import _dtypes_impl, _util
|
|
from ._normalizations import ArrayLike, normalizer
|
|
|
|
|
|
def upcast(func):
    """NumPy fft casts inputs to 64 bit and *returns 64-bit results*."""

    @functools.wraps(func)
    def wrapped(tensor, *args, **kwds):
        # Choose the wide dtype matching the input domain: complex inputs go
        # to the default complex dtype, real inputs to the default float dtype.
        defaults = _dtypes_impl.default_dtypes()
        if tensor.is_complex():
            wide_dtype = defaults.complex_dtype
        else:
            wide_dtype = defaults.float_dtype
        return func(_util.cast_if_needed(tensor, wide_dtype), *args, **kwds)

    return wrapped
|
|
|
|
|
|
@normalizer
@upcast
def fft(a: ArrayLike, n=None, axis=-1, norm=None):
    """One-dimensional discrete Fourier transform, via ``torch.fft.fft``."""
    result = torch.fft.fft(a, n=n, dim=axis, norm=norm)
    return result
|
|
|
|
|
|
@normalizer
@upcast
def ifft(a: ArrayLike, n=None, axis=-1, norm=None):
    """Inverse one-dimensional discrete Fourier transform, via ``torch.fft.ifft``."""
    result = torch.fft.ifft(a, n=n, dim=axis, norm=norm)
    return result
|
|
|
|
|
|
@normalizer
@upcast
def rfft(a: ArrayLike, n=None, axis=-1, norm=None):
    """One-dimensional FFT of real input, via ``torch.fft.rfft``."""
    result = torch.fft.rfft(a, n=n, dim=axis, norm=norm)
    return result
|
|
|
|
|
|
@normalizer
@upcast
def irfft(a: ArrayLike, n=None, axis=-1, norm=None):
    """Inverse of ``rfft``, via ``torch.fft.irfft``."""
    result = torch.fft.irfft(a, n=n, dim=axis, norm=norm)
    return result
|
|
|
|
|
|
@normalizer
@upcast
def fftn(a: ArrayLike, s=None, axes=None, norm=None):
    """N-dimensional discrete Fourier transform, via ``torch.fft.fftn``."""
    result = torch.fft.fftn(a, s=s, dim=axes, norm=norm)
    return result
|
|
|
|
|
|
@normalizer
@upcast
def ifftn(a: ArrayLike, s=None, axes=None, norm=None):
    """Inverse N-dimensional discrete Fourier transform, via ``torch.fft.ifftn``."""
    result = torch.fft.ifftn(a, s=s, dim=axes, norm=norm)
    return result
|
|
|
|
|
|
@normalizer
@upcast
def rfftn(a: ArrayLike, s=None, axes=None, norm=None):
    """N-dimensional FFT of real input, via ``torch.fft.rfftn``."""
    result = torch.fft.rfftn(a, s=s, dim=axes, norm=norm)
    return result
|
|
|
|
|
|
@normalizer
@upcast
def irfftn(a: ArrayLike, s=None, axes=None, norm=None):
    """Inverse of ``rfftn``, via ``torch.fft.irfftn``."""
    result = torch.fft.irfftn(a, s=s, dim=axes, norm=norm)
    return result
|
|
|
|
|
|
@normalizer
@upcast
def fft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
    """Two-dimensional discrete Fourier transform, via ``torch.fft.fft2``."""
    result = torch.fft.fft2(a, s=s, dim=axes, norm=norm)
    return result
|
|
|
|
|
|
@normalizer
@upcast
def ifft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
    """Inverse two-dimensional discrete Fourier transform, via ``torch.fft.ifft2``."""
    result = torch.fft.ifft2(a, s=s, dim=axes, norm=norm)
    return result
|
|
|
|
|
|
@normalizer
@upcast
def rfft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
    """Two-dimensional FFT of real input, via ``torch.fft.rfft2``."""
    result = torch.fft.rfft2(a, s=s, dim=axes, norm=norm)
    return result
|
|
|
|
|
|
@normalizer
@upcast
def irfft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
    """Inverse of ``rfft2``, via ``torch.fft.irfft2``."""
    result = torch.fft.irfft2(a, s=s, dim=axes, norm=norm)
    return result
|
|
|
|
|
|
@normalizer
@upcast
def hfft(a: ArrayLike, n=None, axis=-1, norm=None):
    """FFT of a signal with Hermitian symmetry, via ``torch.fft.hfft``."""
    result = torch.fft.hfft(a, n=n, dim=axis, norm=norm)
    return result
|
|
|
|
|
|
@normalizer
@upcast
def ihfft(a: ArrayLike, n=None, axis=-1, norm=None):
    """Inverse of ``hfft``, via ``torch.fft.ihfft``."""
    result = torch.fft.ihfft(a, n=n, dim=axis, norm=norm)
    return result
|
|
|
|
|
|
@normalizer
def fftfreq(n, d=1.0):
    """Sample frequencies for an FFT of length ``n`` with spacing ``d``."""
    return torch.fft.fftfreq(n, d=d)
|
|
|
|
|
|
@normalizer
def rfftfreq(n, d=1.0):
    """Sample frequencies for an ``rfft`` of length ``n`` with spacing ``d``."""
    return torch.fft.rfftfreq(n, d=d)
|
|
|
|
|
|
@normalizer
def fftshift(x: ArrayLike, axes=None):
    """Shift the zero-frequency component to the center of the spectrum."""
    return torch.fft.fftshift(x, dim=axes)
|
|
|
|
|
|
@normalizer
def ifftshift(x: ArrayLike, axes=None):
    """Undo ``fftshift``."""
    return torch.fft.ifftshift(x, dim=axes)
|