mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Move complex<Half> from Half.h to complex.h (#140565)
Executing on old TODO on the way to sharing Half.h with ExecuTorch. Differential Revision: [D65888037](https://our.internmc.facebook.com/intern/diff/D65888037/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/140565 Approved by: https://github.com/ezyang, https://github.com/malfet ghstack dependencies: #140564
This commit is contained in:
parent
f630799587
commit
e429a3b72e
|
|
@ -12,7 +12,6 @@
|
||||||
#include <c10/macros/Export.h>
|
#include <c10/macros/Export.h>
|
||||||
#include <c10/macros/Macros.h>
|
#include <c10/macros/Macros.h>
|
||||||
#include <c10/util/bit_cast.h>
|
#include <c10/util/bit_cast.h>
|
||||||
#include <c10/util/complex.h>
|
|
||||||
#include <c10/util/floating_point_utils.h>
|
#include <c10/util/floating_point_utils.h>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
|
||||||
|
|
@ -385,56 +384,6 @@ struct alignas(2) Half {
|
||||||
#endif
|
#endif
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO : move to complex.h
|
|
||||||
template <>
|
|
||||||
struct alignas(4) complex<Half> {
|
|
||||||
Half real_;
|
|
||||||
Half imag_;
|
|
||||||
|
|
||||||
// Constructors
|
|
||||||
complex() = default;
|
|
||||||
// Half constructor is not constexpr so the following constructor can't
|
|
||||||
// be constexpr
|
|
||||||
C10_HOST_DEVICE explicit inline complex(const Half& real, const Half& imag)
|
|
||||||
: real_(real), imag_(imag) {}
|
|
||||||
C10_HOST_DEVICE inline complex(const c10::complex<float>& value)
|
|
||||||
: real_(value.real()), imag_(value.imag()) {}
|
|
||||||
|
|
||||||
// Conversion operator
|
|
||||||
inline C10_HOST_DEVICE operator c10::complex<float>() const {
|
|
||||||
return {real_, imag_};
|
|
||||||
}
|
|
||||||
|
|
||||||
constexpr C10_HOST_DEVICE Half real() const {
|
|
||||||
return real_;
|
|
||||||
}
|
|
||||||
constexpr C10_HOST_DEVICE Half imag() const {
|
|
||||||
return imag_;
|
|
||||||
}
|
|
||||||
|
|
||||||
C10_HOST_DEVICE complex<Half>& operator+=(const complex<Half>& other) {
|
|
||||||
real_ = static_cast<float>(real_) + static_cast<float>(other.real_);
|
|
||||||
imag_ = static_cast<float>(imag_) + static_cast<float>(other.imag_);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
C10_HOST_DEVICE complex<Half>& operator-=(const complex<Half>& other) {
|
|
||||||
real_ = static_cast<float>(real_) - static_cast<float>(other.real_);
|
|
||||||
imag_ = static_cast<float>(imag_) - static_cast<float>(other.imag_);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
C10_HOST_DEVICE complex<Half>& operator*=(const complex<Half>& other) {
|
|
||||||
auto a = static_cast<float>(real_);
|
|
||||||
auto b = static_cast<float>(imag_);
|
|
||||||
auto c = static_cast<float>(other.real());
|
|
||||||
auto d = static_cast<float>(other.imag());
|
|
||||||
real_ = a * c - b * d;
|
|
||||||
imag_ = a * d + b * c;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
C10_API inline std::ostream& operator<<(std::ostream& out, const Half& value) {
|
C10_API inline std::ostream& operator<<(std::ostream& out, const Half& value) {
|
||||||
out << (float)value;
|
out << (float)value;
|
||||||
return out;
|
return out;
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
#include <complex>
|
#include <complex>
|
||||||
|
|
||||||
#include <c10/macros/Macros.h>
|
#include <c10/macros/Macros.h>
|
||||||
|
#include <c10/util/Half.h>
|
||||||
|
|
||||||
#if defined(__CUDACC__) || defined(__HIPCC__)
|
#if defined(__CUDACC__) || defined(__HIPCC__)
|
||||||
#include <thrust/complex.h>
|
#include <thrust/complex.h>
|
||||||
|
|
@ -606,6 +607,55 @@ C10_HOST_DEVICE complex<T> polar(const T& r, const T& theta = T()) {
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct alignas(4) complex<Half> {
|
||||||
|
Half real_;
|
||||||
|
Half imag_;
|
||||||
|
|
||||||
|
// Constructors
|
||||||
|
complex() = default;
|
||||||
|
// Half constructor is not constexpr so the following constructor can't
|
||||||
|
// be constexpr
|
||||||
|
C10_HOST_DEVICE explicit inline complex(const Half& real, const Half& imag)
|
||||||
|
: real_(real), imag_(imag) {}
|
||||||
|
C10_HOST_DEVICE inline complex(const c10::complex<float>& value)
|
||||||
|
: real_(value.real()), imag_(value.imag()) {}
|
||||||
|
|
||||||
|
// Conversion operator
|
||||||
|
inline C10_HOST_DEVICE operator c10::complex<float>() const {
|
||||||
|
return {real_, imag_};
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr C10_HOST_DEVICE Half real() const {
|
||||||
|
return real_;
|
||||||
|
}
|
||||||
|
constexpr C10_HOST_DEVICE Half imag() const {
|
||||||
|
return imag_;
|
||||||
|
}
|
||||||
|
|
||||||
|
C10_HOST_DEVICE complex<Half>& operator+=(const complex<Half>& other) {
|
||||||
|
real_ = static_cast<float>(real_) + static_cast<float>(other.real_);
|
||||||
|
imag_ = static_cast<float>(imag_) + static_cast<float>(other.imag_);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
C10_HOST_DEVICE complex<Half>& operator-=(const complex<Half>& other) {
|
||||||
|
real_ = static_cast<float>(real_) - static_cast<float>(other.real_);
|
||||||
|
imag_ = static_cast<float>(imag_) - static_cast<float>(other.imag_);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
C10_HOST_DEVICE complex<Half>& operator*=(const complex<Half>& other) {
|
||||||
|
auto a = static_cast<float>(real_);
|
||||||
|
auto b = static_cast<float>(imag_);
|
||||||
|
auto c = static_cast<float>(other.real());
|
||||||
|
auto d = static_cast<float>(other.imag());
|
||||||
|
real_ = a * c - b * d;
|
||||||
|
imag_ = a * d + b * c;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace c10
|
} // namespace c10
|
||||||
|
|
||||||
C10_CLANG_DIAGNOSTIC_POP()
|
C10_CLANG_DIAGNOSTIC_POP()
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
#include <c10/util/BFloat16.h>
|
#include <c10/util/BFloat16.h>
|
||||||
|
#include <c10/util/complex.h>
|
||||||
#include <c10/util/irange.h>
|
#include <c10/util/irange.h>
|
||||||
#include <torch/csrc/utils/byte_order.h>
|
#include <torch/csrc/utils/byte_order.h>
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user