Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00.
Summary: In-tree changes to PyTorch to support complex numbers are being submitted here. Out-of-tree support for CUDA complex numbers is here: [pytorch-cuda-strided-complex extension](https://gitlab.com/pytorch-complex/pytorch-cuda-strided-complex).

Changes so far:

- [x] Added complex support for `torch.empty` and `torch.fill()`
- [x] Added complex support for CopyKernels
  - The `static_cast_with_inter_type` template function is specialized for the following cases:
    - `dest_t = thrust::complex<dest_value_t>`, `src_t = std::complex<src_value_t>`
    - `dest_t = std::complex<dest_value_t>`, `src_t = thrust::complex<src_value_t>`
  - This handles the compile-time case where `dest_value_t=double` and `src_value_t=float`.
- [x] Added complex support for BinaryOp kernels
  - `using thrust_t = typename ztype_cuda<scalar_t>::thrust_t;` converts `std::complex<T>` ScalarTypes to thrust types and is a no-op for other ScalarTypes.
  - The operator is performed using the complex number support defined in `thrust/complex.h`.
  - This could be extended to work with ROCm by using `rocm/complex.h`.
- [x] Added complex support for UnaryOp kernels
  - Added CUDA support for `angle()`, `real()`, `imag()` and `conj()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/30295

Differential Revision: D18781954

Pulled By: ezyang

fbshipit-source-id: 25d204c0b8143ee27fda345a5d6a82f095da92a7

| Name |
|---|
| core |
| cuda |
| hip |
| macros |
| test |
| util |
| CMakeLists.txt |