Move complex<Half> from Half.h to complex.h (#140565)

Executing on old TODO on the way to sharing Half.h with ExecuTorch.

Differential Revision: [D65888037](https://our.internmc.facebook.com/intern/diff/D65888037/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140565
Approved by: https://github.com/ezyang, https://github.com/malfet
ghstack dependencies: #140564
This commit is contained in:
Scott Wolchok 2024-11-14 10:33:57 -08:00 committed by PyTorch MergeBot
parent f630799587
commit e429a3b72e
3 changed files with 51 additions and 51 deletions

View File

@ -12,7 +12,6 @@
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/bit_cast.h>
#include <c10/util/complex.h>
#include <c10/util/floating_point_utils.h>
#include <type_traits>
@ -385,56 +384,6 @@ struct alignas(2) Half {
#endif
};
// TODO : move to complex.h
template <>
struct alignas(4) complex<Half> {
Half real_;
Half imag_;
// Constructors
complex() = default;
// Half constructor is not constexpr so the following constructor can't
// be constexpr
C10_HOST_DEVICE explicit inline complex(const Half& real, const Half& imag)
: real_(real), imag_(imag) {}
C10_HOST_DEVICE inline complex(const c10::complex<float>& value)
: real_(value.real()), imag_(value.imag()) {}
// Conversion operator
inline C10_HOST_DEVICE operator c10::complex<float>() const {
return {real_, imag_};
}
constexpr C10_HOST_DEVICE Half real() const {
return real_;
}
constexpr C10_HOST_DEVICE Half imag() const {
return imag_;
}
C10_HOST_DEVICE complex<Half>& operator+=(const complex<Half>& other) {
real_ = static_cast<float>(real_) + static_cast<float>(other.real_);
imag_ = static_cast<float>(imag_) + static_cast<float>(other.imag_);
return *this;
}
C10_HOST_DEVICE complex<Half>& operator-=(const complex<Half>& other) {
real_ = static_cast<float>(real_) - static_cast<float>(other.real_);
imag_ = static_cast<float>(imag_) - static_cast<float>(other.imag_);
return *this;
}
C10_HOST_DEVICE complex<Half>& operator*=(const complex<Half>& other) {
auto a = static_cast<float>(real_);
auto b = static_cast<float>(imag_);
auto c = static_cast<float>(other.real());
auto d = static_cast<float>(other.imag());
real_ = a * c - b * d;
imag_ = a * d + b * c;
return *this;
}
};
C10_API inline std::ostream& operator<<(std::ostream& out, const Half& value) {
out << (float)value;
return out;

View File

@ -3,6 +3,7 @@
#include <complex>
#include <c10/macros/Macros.h>
#include <c10/util/Half.h>
#if defined(__CUDACC__) || defined(__HIPCC__)
#include <thrust/complex.h>
@ -606,6 +607,55 @@ C10_HOST_DEVICE complex<T> polar(const T& r, const T& theta = T()) {
#endif
}
template <>
struct alignas(4) complex<Half> {
Half real_;
Half imag_;
// Constructors
complex() = default;
// Half constructor is not constexpr so the following constructor can't
// be constexpr
C10_HOST_DEVICE explicit inline complex(const Half& real, const Half& imag)
: real_(real), imag_(imag) {}
C10_HOST_DEVICE inline complex(const c10::complex<float>& value)
: real_(value.real()), imag_(value.imag()) {}
// Conversion operator
inline C10_HOST_DEVICE operator c10::complex<float>() const {
return {real_, imag_};
}
constexpr C10_HOST_DEVICE Half real() const {
return real_;
}
constexpr C10_HOST_DEVICE Half imag() const {
return imag_;
}
C10_HOST_DEVICE complex<Half>& operator+=(const complex<Half>& other) {
real_ = static_cast<float>(real_) + static_cast<float>(other.real_);
imag_ = static_cast<float>(imag_) + static_cast<float>(other.imag_);
return *this;
}
C10_HOST_DEVICE complex<Half>& operator-=(const complex<Half>& other) {
real_ = static_cast<float>(real_) - static_cast<float>(other.real_);
imag_ = static_cast<float>(imag_) - static_cast<float>(other.imag_);
return *this;
}
C10_HOST_DEVICE complex<Half>& operator*=(const complex<Half>& other) {
auto a = static_cast<float>(real_);
auto b = static_cast<float>(imag_);
auto c = static_cast<float>(other.real());
auto d = static_cast<float>(other.imag());
real_ = a * c - b * d;
imag_ = a * d + b * c;
return *this;
}
};
} // namespace c10
C10_CLANG_DIAGNOSTIC_POP()

View File

@ -1,4 +1,5 @@
#include <c10/util/BFloat16.h>
#include <c10/util/complex.h>
#include <c10/util/irange.h>
#include <torch/csrc/utils/byte_order.h>