mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Move complex<Half> from Half.h to complex.h (#140565)
Executing on old TODO on the way to sharing Half.h with ExecuTorch. Differential Revision: [D65888037](https://our.internmc.facebook.com/intern/diff/D65888037/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/140565 Approved by: https://github.com/ezyang, https://github.com/malfet ghstack dependencies: #140564
This commit is contained in:
parent
f630799587
commit
e429a3b72e
|
|
@ -12,7 +12,6 @@
|
|||
#include <c10/macros/Export.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/bit_cast.h>
|
||||
#include <c10/util/complex.h>
|
||||
#include <c10/util/floating_point_utils.h>
|
||||
#include <type_traits>
|
||||
|
||||
|
|
@ -385,56 +384,6 @@ struct alignas(2) Half {
|
|||
#endif
|
||||
};
|
||||
|
||||
// TODO : move to complex.h
|
||||
template <>
|
||||
struct alignas(4) complex<Half> {
|
||||
Half real_;
|
||||
Half imag_;
|
||||
|
||||
// Constructors
|
||||
complex() = default;
|
||||
// Half constructor is not constexpr so the following constructor can't
|
||||
// be constexpr
|
||||
C10_HOST_DEVICE explicit inline complex(const Half& real, const Half& imag)
|
||||
: real_(real), imag_(imag) {}
|
||||
C10_HOST_DEVICE inline complex(const c10::complex<float>& value)
|
||||
: real_(value.real()), imag_(value.imag()) {}
|
||||
|
||||
// Conversion operator
|
||||
inline C10_HOST_DEVICE operator c10::complex<float>() const {
|
||||
return {real_, imag_};
|
||||
}
|
||||
|
||||
constexpr C10_HOST_DEVICE Half real() const {
|
||||
return real_;
|
||||
}
|
||||
constexpr C10_HOST_DEVICE Half imag() const {
|
||||
return imag_;
|
||||
}
|
||||
|
||||
C10_HOST_DEVICE complex<Half>& operator+=(const complex<Half>& other) {
|
||||
real_ = static_cast<float>(real_) + static_cast<float>(other.real_);
|
||||
imag_ = static_cast<float>(imag_) + static_cast<float>(other.imag_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
C10_HOST_DEVICE complex<Half>& operator-=(const complex<Half>& other) {
|
||||
real_ = static_cast<float>(real_) - static_cast<float>(other.real_);
|
||||
imag_ = static_cast<float>(imag_) - static_cast<float>(other.imag_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
C10_HOST_DEVICE complex<Half>& operator*=(const complex<Half>& other) {
|
||||
auto a = static_cast<float>(real_);
|
||||
auto b = static_cast<float>(imag_);
|
||||
auto c = static_cast<float>(other.real());
|
||||
auto d = static_cast<float>(other.imag());
|
||||
real_ = a * c - b * d;
|
||||
imag_ = a * d + b * c;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
C10_API inline std::ostream& operator<<(std::ostream& out, const Half& value) {
|
||||
out << (float)value;
|
||||
return out;
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
#include <complex>
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/Half.h>
|
||||
|
||||
#if defined(__CUDACC__) || defined(__HIPCC__)
|
||||
#include <thrust/complex.h>
|
||||
|
|
@ -606,6 +607,55 @@ C10_HOST_DEVICE complex<T> polar(const T& r, const T& theta = T()) {
|
|||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
struct alignas(4) complex<Half> {
|
||||
Half real_;
|
||||
Half imag_;
|
||||
|
||||
// Constructors
|
||||
complex() = default;
|
||||
// Half constructor is not constexpr so the following constructor can't
|
||||
// be constexpr
|
||||
C10_HOST_DEVICE explicit inline complex(const Half& real, const Half& imag)
|
||||
: real_(real), imag_(imag) {}
|
||||
C10_HOST_DEVICE inline complex(const c10::complex<float>& value)
|
||||
: real_(value.real()), imag_(value.imag()) {}
|
||||
|
||||
// Conversion operator
|
||||
inline C10_HOST_DEVICE operator c10::complex<float>() const {
|
||||
return {real_, imag_};
|
||||
}
|
||||
|
||||
constexpr C10_HOST_DEVICE Half real() const {
|
||||
return real_;
|
||||
}
|
||||
constexpr C10_HOST_DEVICE Half imag() const {
|
||||
return imag_;
|
||||
}
|
||||
|
||||
C10_HOST_DEVICE complex<Half>& operator+=(const complex<Half>& other) {
|
||||
real_ = static_cast<float>(real_) + static_cast<float>(other.real_);
|
||||
imag_ = static_cast<float>(imag_) + static_cast<float>(other.imag_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
C10_HOST_DEVICE complex<Half>& operator-=(const complex<Half>& other) {
|
||||
real_ = static_cast<float>(real_) - static_cast<float>(other.real_);
|
||||
imag_ = static_cast<float>(imag_) - static_cast<float>(other.imag_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
C10_HOST_DEVICE complex<Half>& operator*=(const complex<Half>& other) {
|
||||
auto a = static_cast<float>(real_);
|
||||
auto b = static_cast<float>(imag_);
|
||||
auto c = static_cast<float>(other.real());
|
||||
auto d = static_cast<float>(other.imag());
|
||||
real_ = a * c - b * d;
|
||||
imag_ = a * d + b * c;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace c10
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_POP()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/complex.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <torch/csrc/utils/byte_order.h>
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user