`half` and `complex32` support for `torch.fft.{fft, fft2, fftn, hfft, hfft2, hfftn, ifft, ifft2, ifftn, ihfft, ihfft2, ihfftn, irfft, irfft2, irfftn, rfft, rfft2, rfftn}`
* We only add support for `CUDA`, as `cuFFT` supports these precisions (see the usage sketch below).
* We still error out on `CPU` and `ROCm`, as their respective backends don't support these precisions.
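As a rough usage sketch (not part of this PR), half and complex32 inputs are assumed to dispatch to `cuFFT` only when the tensor lives on a CUDA device; the sizes and variable names below are illustrative:

```python
import torch

# Sketch, assuming a CUDA device with compute capability >= SM_53.
x = torch.randn(1024, device="cuda", dtype=torch.half)  # power-of-two length

X = torch.fft.fft(x)                  # complex32 (chalf) output
r = torch.fft.rfft(x)                 # real-to-complex transform in half precision
y = torch.fft.irfft(r, n=x.numel())   # round trip back to a half tensor

# The same call on a CPU tensor is expected to raise, since the CPU
# backend does not implement half-precision transforms.
```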
For `cuFFT`, the following constraints apply to these precisions (see the sketch after the reference link below):
* Minimum GPU architecture is SM_53
* Sizes are restricted to powers of two only
* Strides on the real part of real-to-complex and complex-to-real transforms are not supported
* More than one GPU is not supported
* Transforms spanning more than 4 billion elements are not supported
Ref: https://docs.nvidia.com/cuda/cufft/#half-precision-transforms
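A minimal sketch of the power-of-two size restriction (assuming a CUDA build; the exact error message may differ):

```python
import torch

ok = torch.randn(256, device="cuda", dtype=torch.half)
torch.fft.fft(ok)        # 256 is a power of two, so this should succeed

bad = torch.randn(300, device="cuda", dtype=torch.half)
try:
    torch.fft.fft(bad)   # 300 is not a power of two
except RuntimeError as err:
    print("expected failure:", err)
```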
TODO:
* [x] Update docs about the restrictions
* [x] Check the correct way to detect a `hip` device (it seems `device.is_cuda()` is true for HIP as well). (Thanks @peterbell10)
Ref for second point in TODO: