mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert D21021677: [pytorch][PR] Add core of c10::complex
Test Plan: revert-hammer Differential Revision: D21021677 Original commit changeset: 9e144e581fa4 fbshipit-source-id: ce6a88fc71ec0134d0fc6ecdddc4c4db35f89b1f
This commit is contained in:
parent
5150334c1d
commit
9216c67c9e
|
|
@ -8,7 +8,7 @@
|
|||
#include <complex>
|
||||
#include <type_traits>
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/LegacyComplex.h>
|
||||
#include <c10/util/Complex.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
namespace at {
|
||||
|
|
|
|||
|
|
@ -40,7 +40,6 @@ list(APPEND ATen_CPU_TEST_SRCS
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/type_test.cpp)
|
||||
|
||||
list(APPEND ATen_CUDA_TEST_SRCS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_test.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_integer_divider_test.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_apply_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_stream_test.cpp
|
||||
|
|
|
|||
|
|
@ -1,68 +0,0 @@
|
|||
#include <c10/test/util/complex_test_common.h>
|
||||
|
||||
__global__ void test_thrust_kernel() {
|
||||
// thrust conversion
|
||||
{
|
||||
constexpr float num1 = float(1.23);
|
||||
constexpr float num2 = float(4.56);
|
||||
assert(c10::complex<float>(thrust::complex<float>(num1, num2)).real() == num1);
|
||||
assert(c10::complex<float>(thrust::complex<float>(num1, num2)).imag() == num2);
|
||||
}
|
||||
{
|
||||
constexpr double num1 = double(1.23);
|
||||
constexpr double num2 = double(4.56);
|
||||
assert(c10::complex<double>(thrust::complex<double>(num1, num2)).real() == num1);
|
||||
assert(c10::complex<double>(thrust::complex<double>(num1, num2)).imag() == num2);
|
||||
}
|
||||
// thrust assignment
|
||||
auto tup = assignment::one_two_thrust();
|
||||
assert(std::get<c10::complex<double>>(tup).real() == double(1));
|
||||
assert(std::get<c10::complex<double>>(tup).imag() == double(2));
|
||||
assert(std::get<c10::complex<float>>(tup).real() == float(1));
|
||||
assert(std::get<c10::complex<float>>(tup).imag() == float(2));
|
||||
}
|
||||
|
||||
__global__ void test_std_functions_kernel() {
|
||||
assert(std::abs(c10::complex<float>(3, 4)) == float(5));
|
||||
assert(std::abs(c10::complex<double>(3, 4)) == double(5));
|
||||
|
||||
assert(std::abs(std::arg(c10::complex<float>(0, 1)) - PI / 2) < 1e-6);
|
||||
assert(std::abs(std::arg(c10::complex<double>(0, 1)) - PI / 2) < 1e-6);
|
||||
|
||||
assert(std::abs(c10::polar(float(1), float(PI / 2)) - c10::complex<float>(0, 1)) < 1e-6);
|
||||
assert(std::abs(c10::polar(double(1), double(PI / 2)) - c10::complex<double>(0, 1)) < 1e-6);
|
||||
}
|
||||
|
||||
__global__ void test_reinterpret_cast() {
|
||||
std::complex<float> z(1, 2);
|
||||
c10::complex<float> zz = *reinterpret_cast<c10::complex<float>*>(&z);
|
||||
assert(zz.real() == float(1));
|
||||
assert(zz.imag() == float(2));
|
||||
|
||||
std::complex<double> zzz(1, 2);
|
||||
c10::complex<double> zzzz = *reinterpret_cast<c10::complex<double>*>(&zzz);
|
||||
assert(zzzz.real() == double(1));
|
||||
assert(zzzz.imag() == double(2));
|
||||
}
|
||||
|
||||
TEST(DeviceTests, ThrustConversion) {
|
||||
cudaDeviceSynchronize();
|
||||
test_thrust_kernel<<<1, 1>>>();
|
||||
cudaDeviceSynchronize();
|
||||
ASSERT_EQ(cudaGetLastError(), cudaSuccess);
|
||||
}
|
||||
|
||||
TEST(DeviceTests, StdFunctions) {
|
||||
cudaDeviceSynchronize();
|
||||
test_std_functions_kernel<<<1, 1>>>();
|
||||
cudaDeviceSynchronize();
|
||||
ASSERT_EQ(cudaGetLastError(), cudaSuccess);
|
||||
}
|
||||
|
||||
TEST(DeviceTests, ReinterpretCast) {
|
||||
cudaDeviceSynchronize();
|
||||
test_reinterpret_cast<<<1, 1>>>();
|
||||
cudaDeviceSynchronize();
|
||||
ASSERT_EQ(cudaGetLastError(), cudaSuccess);
|
||||
}
|
||||
|
||||
|
|
@ -51,9 +51,6 @@ fi
|
|||
if [[ -x ./cuda_tensor_interop_test ]]; then
|
||||
./cuda_tensor_interop_test
|
||||
fi
|
||||
if [[ -x ./cuda_complex_test ]]; then
|
||||
./cuda_complex_test
|
||||
fi
|
||||
if [ "$VALGRIND" == "ON" ]
|
||||
then
|
||||
valgrind --suppressions="$VALGRIND_SUP" --error-exitcode=1 ./basic --gtest_filter='-*CUDA'
|
||||
|
|
|
|||
|
|
@ -1 +0,0 @@
|
|||
#include <c10/test/util/complex_test_common.h>
|
||||
|
|
@ -1,464 +0,0 @@
|
|||
#include <type_traits>
|
||||
#include <tuple>
|
||||
#include <sstream>
|
||||
#include <c10/util/complex.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#if (defined(__CUDACC__) || defined(__HIPCC__))
|
||||
#define MAYBE_GLOBAL __global__
|
||||
#else
|
||||
#define MAYBE_GLOBAL
|
||||
#endif
|
||||
|
||||
#define PI 3.141592653589793238463
|
||||
|
||||
namespace memory {
|
||||
|
||||
MAYBE_GLOBAL void test_size() {
|
||||
static_assert(sizeof(c10::complex<float>) == 2 * sizeof(float), "");
|
||||
static_assert(sizeof(c10::complex<double>) == 2 * sizeof(double), "");
|
||||
}
|
||||
|
||||
MAYBE_GLOBAL void test_align() {
|
||||
static_assert(alignof(c10::complex<float>) == 2 * sizeof(float), "");
|
||||
static_assert(alignof(c10::complex<double>) == 2 * sizeof(double), "");
|
||||
}
|
||||
|
||||
MAYBE_GLOBAL void test_pod() {
|
||||
static_assert(std::is_standard_layout<c10::complex<float>>::value, "");
|
||||
static_assert(std::is_standard_layout<c10::complex<double>>::value, "");
|
||||
}
|
||||
|
||||
TEST(TestMemory, ReinterpretCast) {
|
||||
std::complex<float> z(1, 2);
|
||||
c10::complex<float> zz = *reinterpret_cast<c10::complex<float>*>(&z);
|
||||
ASSERT_EQ(zz.real(), float(1));
|
||||
ASSERT_EQ(zz.imag(), float(2));
|
||||
|
||||
std::complex<double> zzz(1, 2);
|
||||
c10::complex<double> zzzz = *reinterpret_cast<c10::complex<double>*>(&zzz);
|
||||
ASSERT_EQ(zzzz.real(), double(1));
|
||||
ASSERT_EQ(zzzz.imag(), double(2));
|
||||
}
|
||||
|
||||
} // memory
|
||||
|
||||
namespace constructors {
|
||||
|
||||
template<typename scalar_t>
|
||||
C10_HOST_DEVICE void test_construct_from_scalar() {
|
||||
constexpr scalar_t num1 = scalar_t(1.23);
|
||||
constexpr scalar_t num2 = scalar_t(4.56);
|
||||
constexpr scalar_t zero = scalar_t();
|
||||
static_assert(c10::complex<scalar_t>(num1, num2).real() == num1, "");
|
||||
static_assert(c10::complex<scalar_t>(num1, num2).imag() == num2, "");
|
||||
static_assert(c10::complex<scalar_t>(num1).real() == num1, "");
|
||||
static_assert(c10::complex<scalar_t>(num1).imag() == zero, "");
|
||||
static_assert(c10::complex<scalar_t>().real() == zero, "");
|
||||
static_assert(c10::complex<scalar_t>().imag() == zero, "");
|
||||
}
|
||||
|
||||
template<typename scalar_t, typename other_t>
|
||||
C10_HOST_DEVICE void test_construct_from_other() {
|
||||
constexpr other_t num1 = other_t(1.23);
|
||||
constexpr other_t num2 = other_t(4.56);
|
||||
constexpr scalar_t num3 = scalar_t(num1);
|
||||
constexpr scalar_t num4 = scalar_t(num2);
|
||||
static_assert(c10::complex<scalar_t>(c10::complex<other_t>(num1, num2)).real() == num3, "");
|
||||
static_assert(c10::complex<scalar_t>(c10::complex<other_t>(num1, num2)).imag() == num4, "");
|
||||
}
|
||||
|
||||
MAYBE_GLOBAL void test_convert_constructors() {
|
||||
test_construct_from_scalar<float>();
|
||||
test_construct_from_scalar<double>();
|
||||
|
||||
static_assert(std::is_convertible<c10::complex<float>, c10::complex<float>>::value, "");
|
||||
static_assert(!std::is_convertible<c10::complex<double>, c10::complex<float>>::value, "");
|
||||
static_assert(std::is_convertible<c10::complex<float>, c10::complex<double>>::value, "");
|
||||
static_assert(std::is_convertible<c10::complex<double>, c10::complex<double>>::value, "");
|
||||
|
||||
static_assert(std::is_constructible<c10::complex<float>, c10::complex<float>>::value, "");
|
||||
static_assert(std::is_constructible<c10::complex<double>, c10::complex<float>>::value, "");
|
||||
static_assert(std::is_constructible<c10::complex<float>, c10::complex<double>>::value, "");
|
||||
static_assert(std::is_constructible<c10::complex<double>, c10::complex<double>>::value, "");
|
||||
|
||||
test_construct_from_other<float, float>();
|
||||
test_construct_from_other<float, double>();
|
||||
test_construct_from_other<double, float>();
|
||||
test_construct_from_other<double, double>();
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
C10_HOST_DEVICE void test_construct_from_std() {
|
||||
constexpr scalar_t num1 = scalar_t(1.23);
|
||||
constexpr scalar_t num2 = scalar_t(4.56);
|
||||
static_assert(c10::complex<scalar_t>(std::complex<scalar_t>(num1, num2)).real() == num1, "");
|
||||
static_assert(c10::complex<scalar_t>(std::complex<scalar_t>(num1, num2)).imag() == num2, "");
|
||||
}
|
||||
|
||||
MAYBE_GLOBAL void test_std_conversion() {
|
||||
test_construct_from_std<float>();
|
||||
test_construct_from_std<double>();
|
||||
}
|
||||
|
||||
#if defined(__CUDACC__) || defined(__HIPCC__)
|
||||
template<typename scalar_t>
|
||||
void test_construct_from_thrust() {
|
||||
constexpr scalar_t num1 = scalar_t(1.23);
|
||||
constexpr scalar_t num2 = scalar_t(4.56);
|
||||
ASSERT_EQ(c10::complex<scalar_t>(thrust::complex<scalar_t>(num1, num2)).real(), num1);
|
||||
ASSERT_EQ(c10::complex<scalar_t>(thrust::complex<scalar_t>(num1, num2)).imag(), num2);
|
||||
}
|
||||
|
||||
TEST(TestConstructors, FromThrust) {
|
||||
test_construct_from_thrust<float>();
|
||||
test_construct_from_thrust<double>();
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
} // constructors
|
||||
|
||||
namespace assignment {
|
||||
|
||||
template<typename scalar_t>
|
||||
constexpr c10::complex<scalar_t> one() {
|
||||
c10::complex<scalar_t> result(3, 4);
|
||||
result = scalar_t(1);
|
||||
return result;
|
||||
}
|
||||
|
||||
MAYBE_GLOBAL void test_assign_real() {
|
||||
static_assert(one<float>().real() == float(1), "");
|
||||
static_assert(one<float>().imag() == float(), "");
|
||||
static_assert(one<double>().real() == double(1), "");
|
||||
static_assert(one<double>().imag() == double(), "");
|
||||
}
|
||||
|
||||
constexpr std::tuple<c10::complex<double>, c10::complex<float>> one_two() {
|
||||
constexpr c10::complex<float> src(1, 2);
|
||||
c10::complex<double> ret0;
|
||||
c10::complex<float> ret1;
|
||||
ret0 = ret1 = src;
|
||||
return std::make_tuple(ret0, ret1);
|
||||
}
|
||||
|
||||
MAYBE_GLOBAL void test_assign_other() {
|
||||
constexpr auto tup = one_two();
|
||||
static_assert(std::get<c10::complex<double>>(tup).real() == double(1), "");
|
||||
static_assert(std::get<c10::complex<double>>(tup).imag() == double(2), "");
|
||||
static_assert(std::get<c10::complex<float>>(tup).real() == float(1), "");
|
||||
static_assert(std::get<c10::complex<float>>(tup).imag() == float(2), "");
|
||||
}
|
||||
|
||||
constexpr std::tuple<c10::complex<double>, c10::complex<float>> one_two_std() {
|
||||
constexpr std::complex<float> src(1, 1);
|
||||
c10::complex<double> ret0;
|
||||
c10::complex<float> ret1;
|
||||
ret0 = ret1 = src;
|
||||
return std::make_tuple(ret0, ret1);
|
||||
}
|
||||
|
||||
MAYBE_GLOBAL void test_assign_std() {
|
||||
constexpr auto tup = one_two();
|
||||
static_assert(std::get<c10::complex<double>>(tup).real() == double(1), "");
|
||||
static_assert(std::get<c10::complex<double>>(tup).imag() == double(2), "");
|
||||
static_assert(std::get<c10::complex<float>>(tup).real() == float(1), "");
|
||||
static_assert(std::get<c10::complex<float>>(tup).imag() == float(2), "");
|
||||
}
|
||||
|
||||
#if defined(__CUDACC__) || defined(__HIPCC__)
|
||||
C10_HOST_DEVICE std::tuple<c10::complex<double>, c10::complex<float>> one_two_thrust() {
|
||||
thrust::complex<float> src(1, 2);
|
||||
c10::complex<double> ret0;
|
||||
c10::complex<float> ret1;
|
||||
ret0 = ret1 = src;
|
||||
return std::make_tuple(ret0, ret1);
|
||||
}
|
||||
|
||||
TEST(TestAssignment, FromThrust) {
|
||||
auto tup = one_two_thrust();
|
||||
ASSERT_EQ(std::get<c10::complex<double>>(tup).real(), double(1));
|
||||
ASSERT_EQ(std::get<c10::complex<double>>(tup).imag(), double(2));
|
||||
ASSERT_EQ(std::get<c10::complex<float>>(tup).real(), float(1));
|
||||
ASSERT_EQ(std::get<c10::complex<float>>(tup).imag(), float(2));
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace assignment
|
||||
|
||||
namespace literals {
|
||||
|
||||
MAYBE_GLOBAL void test_complex_literals() {
|
||||
using namespace c10::complex_literals;
|
||||
static_assert(std::is_same<decltype(0.5_if), c10::complex<float>>::value, "");
|
||||
static_assert((0.5_if).real() == float(), "");
|
||||
static_assert((0.5_if).imag() == float(0.5), "");
|
||||
static_assert(std::is_same<decltype(0.5_id), c10::complex<double>>::value, "");
|
||||
static_assert((0.5_id).real() == float(), "");
|
||||
static_assert((0.5_id).imag() == float(0.5), "");
|
||||
|
||||
static_assert(std::is_same<decltype(1_if), c10::complex<float>>::value, "");
|
||||
static_assert((1_if).real() == float(), "");
|
||||
static_assert((1_if).imag() == float(1), "");
|
||||
static_assert(std::is_same<decltype(1_id), c10::complex<double>>::value, "");
|
||||
static_assert((1_id).real() == double(), "");
|
||||
static_assert((1_id).imag() == double(1), "");
|
||||
}
|
||||
|
||||
} // namespace literals
|
||||
|
||||
namespace real_imag {
|
||||
|
||||
template<typename scalar_t>
|
||||
constexpr c10::complex<scalar_t> zero_one() {
|
||||
c10::complex<scalar_t> result;
|
||||
result.imag(scalar_t(1));
|
||||
return result;
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
constexpr c10::complex<scalar_t> one_zero() {
|
||||
c10::complex<scalar_t> result;
|
||||
result.real(scalar_t(1));
|
||||
return result;
|
||||
}
|
||||
|
||||
MAYBE_GLOBAL void test_real_imag_modify() {
|
||||
static_assert(zero_one<float>().real() == float(0), "");
|
||||
static_assert(zero_one<float>().imag() == float(1), "");
|
||||
static_assert(zero_one<double>().real() == double(0), "");
|
||||
static_assert(zero_one<double>().imag() == double(1), "");
|
||||
|
||||
static_assert(one_zero<float>().real() == float(1), "");
|
||||
static_assert(one_zero<float>().imag() == float(0), "");
|
||||
static_assert(one_zero<double>().real() == double(1), "");
|
||||
static_assert(one_zero<double>().imag() == double(0), "");
|
||||
}
|
||||
|
||||
} // namespace real_imag
|
||||
|
||||
namespace arithmetic_assign {
|
||||
|
||||
template<typename scalar_t>
|
||||
constexpr c10::complex<scalar_t> p(scalar_t value) {
|
||||
c10::complex<scalar_t> result(scalar_t(2), scalar_t(2));
|
||||
result += value;
|
||||
return result;
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
constexpr c10::complex<scalar_t> m(scalar_t value) {
|
||||
c10::complex<scalar_t> result(scalar_t(2), scalar_t(2));
|
||||
result -= value;
|
||||
return result;
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
constexpr c10::complex<scalar_t> t(scalar_t value) {
|
||||
c10::complex<scalar_t> result(scalar_t(2), scalar_t(2));
|
||||
result *= value;
|
||||
return result;
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
constexpr c10::complex<scalar_t> d(scalar_t value) {
|
||||
c10::complex<scalar_t> result(scalar_t(2), scalar_t(2));
|
||||
result /= value;
|
||||
return result;
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
C10_HOST_DEVICE void test_arithmetic_assign_scalar() {
|
||||
constexpr c10::complex<scalar_t> x = p(scalar_t(1));
|
||||
static_assert(x.real() == scalar_t(3), "");
|
||||
static_assert(x.imag() == scalar_t(2), "");
|
||||
constexpr c10::complex<scalar_t> y = m(scalar_t(1));
|
||||
static_assert(y.real() == scalar_t(1), "");
|
||||
static_assert(y.imag() == scalar_t(2), "");
|
||||
constexpr c10::complex<scalar_t> z = t(scalar_t(2));
|
||||
static_assert(z.real() == scalar_t(4), "");
|
||||
static_assert(z.imag() == scalar_t(4), "");
|
||||
constexpr c10::complex<scalar_t> t = d(scalar_t(2));
|
||||
static_assert(t.real() == scalar_t(1), "");
|
||||
static_assert(t.imag() == scalar_t(1), "");
|
||||
}
|
||||
|
||||
template<typename scalar_t, typename rhs_t>
|
||||
constexpr c10::complex<scalar_t> p(scalar_t real, scalar_t imag, c10::complex<rhs_t> rhs) {
|
||||
c10::complex<scalar_t> result(real, imag);
|
||||
result += rhs;
|
||||
return result;
|
||||
}
|
||||
|
||||
template<typename scalar_t, typename rhs_t>
|
||||
constexpr c10::complex<scalar_t> m(scalar_t real, scalar_t imag, c10::complex<rhs_t> rhs) {
|
||||
c10::complex<scalar_t> result(real, imag);
|
||||
result -= rhs;
|
||||
return result;
|
||||
}
|
||||
|
||||
template<typename scalar_t, typename rhs_t>
|
||||
constexpr c10::complex<scalar_t> t(scalar_t real, scalar_t imag, c10::complex<rhs_t> rhs) {
|
||||
c10::complex<scalar_t> result(real, imag);
|
||||
result *= rhs;
|
||||
return result;
|
||||
}
|
||||
|
||||
template<typename scalar_t, typename rhs_t>
|
||||
constexpr c10::complex<scalar_t> d(scalar_t real, scalar_t imag, c10::complex<rhs_t> rhs) {
|
||||
c10::complex<scalar_t> result(real, imag);
|
||||
result /= rhs;
|
||||
return result;
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
C10_HOST_DEVICE void test_arithmetic_assign_complex() {
|
||||
using namespace c10::complex_literals;
|
||||
constexpr c10::complex<scalar_t> x2 = p(scalar_t(2), scalar_t(2), 1.0_if);
|
||||
static_assert(x2.real() == scalar_t(2), "");
|
||||
static_assert(x2.imag() == scalar_t(3), "");
|
||||
constexpr c10::complex<scalar_t> x3 = p(scalar_t(2), scalar_t(2), 1.0_id);
|
||||
static_assert(x3.real() == scalar_t(2), "");
|
||||
#if !defined(__CUDACC__)
|
||||
// The following is flaky on nvcc
|
||||
static_assert(x3.imag() == scalar_t(3), "");
|
||||
#endif
|
||||
|
||||
constexpr c10::complex<scalar_t> y2 = m(scalar_t(2), scalar_t(2), 1.0_if);
|
||||
static_assert(y2.real() == scalar_t(2), "");
|
||||
static_assert(y2.imag() == scalar_t(1), "");
|
||||
constexpr c10::complex<scalar_t> y3 = m(scalar_t(2), scalar_t(2), 1.0_id);
|
||||
static_assert(y3.real() == scalar_t(2), "");
|
||||
#if !defined(__CUDACC__)
|
||||
// The following is flaky on nvcc
|
||||
static_assert(y3.imag() == scalar_t(1), "");
|
||||
#endif
|
||||
|
||||
constexpr c10::complex<scalar_t> z2 = t(scalar_t(1), scalar_t(-2), 1.0_if);
|
||||
static_assert(z2.real() == scalar_t(2), "");
|
||||
static_assert(z2.imag() == scalar_t(1), "");
|
||||
constexpr c10::complex<scalar_t> z3 = t(scalar_t(1), scalar_t(-2), 1.0_id);
|
||||
static_assert(z3.real() == scalar_t(2), "");
|
||||
static_assert(z3.imag() == scalar_t(1), "");
|
||||
|
||||
constexpr c10::complex<scalar_t> t2 = d(scalar_t(-1), scalar_t(2), 1.0_if);
|
||||
static_assert(t2.real() == scalar_t(2), "");
|
||||
static_assert(t2.imag() == scalar_t(1), "");
|
||||
constexpr c10::complex<scalar_t> t3 = d(scalar_t(-1), scalar_t(2), 1.0_id);
|
||||
static_assert(t3.real() == scalar_t(2), "");
|
||||
static_assert(t3.imag() == scalar_t(1), "");
|
||||
}
|
||||
|
||||
MAYBE_GLOBAL void test_arithmetic_assign() {
|
||||
test_arithmetic_assign_scalar<float>();
|
||||
test_arithmetic_assign_scalar<double>();
|
||||
test_arithmetic_assign_complex<float>();
|
||||
test_arithmetic_assign_complex<double>();
|
||||
}
|
||||
|
||||
} // namespace arithmetic_assign
|
||||
|
||||
namespace arithmetic {
|
||||
|
||||
template<typename scalar_t>
|
||||
C10_HOST_DEVICE void test_arithmetic_() {
|
||||
static_assert(c10::complex<scalar_t>(1, 2) == +c10::complex<scalar_t>(1, 2), "");
|
||||
static_assert(c10::complex<scalar_t>(-1, -2) == -c10::complex<scalar_t>(1, 2), "");
|
||||
|
||||
static_assert(c10::complex<scalar_t>(1, 2) + c10::complex<scalar_t>(3, 4) == c10::complex<scalar_t>(4, 6), "");
|
||||
static_assert(c10::complex<scalar_t>(1, 2) + scalar_t(3) == c10::complex<scalar_t>(4, 2), "");
|
||||
static_assert(scalar_t(3) + c10::complex<scalar_t>(1, 2) == c10::complex<scalar_t>(4, 2), "");
|
||||
|
||||
static_assert(c10::complex<scalar_t>(1, 2) - c10::complex<scalar_t>(3, 4) == c10::complex<scalar_t>(-2, -2), "");
|
||||
static_assert(c10::complex<scalar_t>(1, 2) - scalar_t(3) == c10::complex<scalar_t>(-2, 2), "");
|
||||
static_assert(scalar_t(3) - c10::complex<scalar_t>(1, 2) == c10::complex<scalar_t>(2, -2), "");
|
||||
|
||||
static_assert(c10::complex<scalar_t>(1, 2) * c10::complex<scalar_t>(3, 4) == c10::complex<scalar_t>(-5, 10), "");
|
||||
static_assert(c10::complex<scalar_t>(1, 2) * scalar_t(3) == c10::complex<scalar_t>(3, 6), "");
|
||||
static_assert(scalar_t(3) * c10::complex<scalar_t>(1, 2) == c10::complex<scalar_t>(3, 6), "");
|
||||
|
||||
static_assert(c10::complex<scalar_t>(-5, 10) / c10::complex<scalar_t>(3, 4) == c10::complex<scalar_t>(1, 2), "");
|
||||
static_assert(c10::complex<scalar_t>(5, 10) / scalar_t(5) == c10::complex<scalar_t>(1, 2), "");
|
||||
static_assert(scalar_t(25) / c10::complex<scalar_t>(3, 4) == c10::complex<scalar_t>(3, -4), "");
|
||||
}
|
||||
|
||||
MAYBE_GLOBAL void test_arithmetic() {
|
||||
test_arithmetic_<float>();
|
||||
test_arithmetic_<double>();
|
||||
}
|
||||
|
||||
} // namespace arithmetic
|
||||
|
||||
namespace equality {
|
||||
|
||||
template<typename scalar_t>
|
||||
C10_HOST_DEVICE void test_equality_() {
|
||||
static_assert(c10::complex<scalar_t>(1, 2) == c10::complex<scalar_t>(1, 2), "");
|
||||
static_assert(c10::complex<scalar_t>(1, 0) == scalar_t(1), "");
|
||||
static_assert(scalar_t(1) == c10::complex<scalar_t>(1, 0), "");
|
||||
static_assert(c10::complex<scalar_t>(1, 2) != c10::complex<scalar_t>(3, 4), "");
|
||||
static_assert(c10::complex<scalar_t>(1, 2) != scalar_t(1), "");
|
||||
static_assert(scalar_t(1) != c10::complex<scalar_t>(1, 2), "");
|
||||
}
|
||||
|
||||
MAYBE_GLOBAL void test_equality() {
|
||||
test_equality_<float>();
|
||||
test_equality_<double>();
|
||||
}
|
||||
|
||||
} // namespace equality
|
||||
|
||||
namespace io {
|
||||
|
||||
template<typename scalar_t>
|
||||
void test_io_() {
|
||||
std::stringstream ss;
|
||||
c10::complex<scalar_t> a(1, 2);
|
||||
ss << a;
|
||||
ASSERT_EQ(ss.str(), "(1,2)");
|
||||
ss.str("(3,4)");
|
||||
ss >> a;
|
||||
ASSERT_TRUE(a == c10::complex<scalar_t>(3, 4));
|
||||
}
|
||||
|
||||
TEST(TestIO, All) {
|
||||
test_io_<float>();
|
||||
test_io_<double>();
|
||||
}
|
||||
|
||||
} // namespace io
|
||||
|
||||
namespace test_std {
|
||||
|
||||
template<typename scalar_t>
|
||||
C10_HOST_DEVICE void test_callable_() {
|
||||
static_assert(std::real(c10::complex<scalar_t>(1, 2)) == scalar_t(1), "");
|
||||
static_assert(std::imag(c10::complex<scalar_t>(1, 2)) == scalar_t(2), "");
|
||||
std::abs(c10::complex<scalar_t>(1, 2));
|
||||
std::arg(c10::complex<scalar_t>(1, 2));
|
||||
static_assert(std::norm(c10::complex<scalar_t>(3, 4)) == scalar_t(25), "");
|
||||
static_assert(std::conj(c10::complex<scalar_t>(3, 4)) == c10::complex<scalar_t>(3, -4), "");
|
||||
c10::polar(float(1), float(PI / 2));
|
||||
c10::polar(double(1), double(PI / 2));
|
||||
}
|
||||
|
||||
MAYBE_GLOBAL void test_callable() {
|
||||
test_callable_<float>();
|
||||
test_callable_<double>();
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
void test_values_() {
|
||||
ASSERT_EQ(std::abs(c10::complex<scalar_t>(3, 4)), scalar_t(5));
|
||||
ASSERT_LT(std::abs(std::arg(c10::complex<scalar_t>(0, 1)) - PI / 2), 1e-6);
|
||||
ASSERT_LT(std::abs(c10::polar(scalar_t(1), scalar_t(PI / 2)) - c10::complex<scalar_t>(0, 1)), 1e-6);
|
||||
}
|
||||
|
||||
TEST(TestStd, BasicFunctions) {
|
||||
test_values_<float>();
|
||||
test_values_<double>();
|
||||
}
|
||||
|
||||
} // namespace test_std
|
||||
|
|
@ -1,485 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <complex>
|
||||
#include <iostream>
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
#if defined(__CUDACC__) || defined(__HIPCC__)
|
||||
#include <thrust/complex.h>
|
||||
#endif
|
||||
|
||||
namespace c10 {
|
||||
|
||||
// c10::complex is an implementation of complex numbers that aims
|
||||
// to work on all devices supported by PyTorch
|
||||
//
|
||||
// Most of the APIs duplicates std::complex
|
||||
// Reference: https://en.cppreference.com/w/cpp/numeric/complex
|
||||
//
|
||||
// [Note on Constructors]
|
||||
//
|
||||
// The APIs of constructors are mostly copied from C++ standard:
|
||||
// https://en.cppreference.com/w/cpp/numeric/complex/complex
|
||||
//
|
||||
// Since C++14, all constructors are constexpr in std::complex
|
||||
//
|
||||
// There are three types of constructors:
|
||||
// - initializing from real and imag:
|
||||
// `constexpr complex( const T& re = T(), const T& im = T() );`
|
||||
// - implicitly-declared copy constructor
|
||||
// - converting constructors
|
||||
//
|
||||
// Converting constructors:
|
||||
// - std::complex defines converting constructor between float/double/long double,
|
||||
// while we define converting constructor between float/double.
|
||||
// - For these converting constructors, upcasting is implicit, downcasting is
|
||||
// explicit.
|
||||
// - We also define explicit casting from std::complex/thrust::complex
|
||||
// - Note that the conversion from thrust is not constexpr, because
|
||||
// thrust does not define them as constexpr ????
|
||||
//
|
||||
//
|
||||
// [Operator =]
|
||||
//
|
||||
// The APIs of operator = are mostly copied from C++ standard:
|
||||
// https://en.cppreference.com/w/cpp/numeric/complex/operator%3D
|
||||
//
|
||||
// Since C++20, all operator= are constexpr. Although we are not building with
|
||||
// C++20, we also obey this behavior.
|
||||
//
|
||||
// There are three types of assign operator:
|
||||
// - Assign a real value from the same scalar type
|
||||
// - In std, this is templated as complex& operator=(const T& x)
|
||||
// with specialization `complex& operator=(T x)` for float/double/long double
|
||||
// Since we only support float and double, on will use `complex& operator=(T x)`
|
||||
// - Copy assignment operator and converting assignment operator
|
||||
// - There is no specialization of converting assignment operators, which type is
|
||||
// convertible is soly depend on whether the scalar type is convertable
|
||||
//
|
||||
// In addition to the standard assignment, we also provide assignment operators with std and thrust
|
||||
//
|
||||
//
|
||||
// [Casting operators]
|
||||
//
|
||||
// std::complex does not have casting operators. We define casting operators casting to std::complex and thrust::complex
|
||||
//
|
||||
//
|
||||
// [Operator ""]
|
||||
//
|
||||
// std::complex has custom literals `i`, `if` and `il` defined in namespace `std::literals::complex_literals`.
|
||||
// We define our own custom literals in the namespace `c10::complex_literals`. Our custom literals does not
|
||||
// follow the same behavior as in std::complex, instead, we define _if, _id to construct float/double
|
||||
// complex literals.
|
||||
//
|
||||
//
|
||||
// [real() and imag()]
|
||||
//
|
||||
// In C++20, there are two overload of these functions, one it to return the real/imag, another is to set real/imag,
|
||||
// they are both constexpr. We follow this design.
|
||||
//
|
||||
//
|
||||
// [Operator +=,-=,*=,/=]
|
||||
//
|
||||
// Since C++20, these operators become constexpr. In our implementation, they are also constexpr.
|
||||
//
|
||||
// There are two types of such operators: operating with a real number, or operating with another complex number.
|
||||
// For the operating with a real number, the generic template form has argument type `const T &`, while the overload
|
||||
// for float/double/long double has `T`. We will follow the same type as float/double/long double in std.
|
||||
//
|
||||
// [Unary operator +-]
|
||||
//
|
||||
// Since C++20, they are constexpr. We also make them expr
|
||||
//
|
||||
// [Binary operators +-*/]
|
||||
//
|
||||
// Each operator has three versions (taking + as example):
|
||||
// - complex + complex
|
||||
// - complex + real
|
||||
// - real + complex
|
||||
//
|
||||
// [Operator ==, !=]
|
||||
//
|
||||
// Each operator has three versions (taking == as example):
|
||||
// - complex == complex
|
||||
// - complex == real
|
||||
// - real == complex
|
||||
//
|
||||
// Some of them are removed on C++20, but we decide to keep them
|
||||
//
|
||||
// [Operator <<, >>]
|
||||
//
|
||||
// These are implemented by casting to std::complex
|
||||
//
|
||||
//
|
||||
//
|
||||
//
|
||||
// TODO(@zasdfgbnm): c10::complex<c10::Half> is not currently supported, because:
|
||||
// - lots of members and functions of c10::Half are not constexpr
|
||||
// - thrust::complex only support float and double
|
||||
|
||||
template<typename T>
|
||||
struct complex;
|
||||
|
||||
template<typename T>
|
||||
struct alignas(sizeof(T) * 2) complex_common {
|
||||
T storage[2];
|
||||
|
||||
constexpr complex_common(): storage{T(), T()} {}
|
||||
constexpr complex_common(const T& re, const T& im = T()): storage{re, im} {}
|
||||
template<typename U>
|
||||
explicit constexpr complex_common(const std::complex<U> &other): complex_common(other.real(), other.imag()) {}
|
||||
#if defined(__CUDACC__) || defined(__HIPCC__)
|
||||
template<typename U>
|
||||
explicit C10_HOST_DEVICE complex_common(const thrust::complex<U> &other): complex_common(other.real(), other.imag()) {}
|
||||
#endif
|
||||
|
||||
constexpr complex<T> &operator =(T re) {
|
||||
storage[0] = re;
|
||||
storage[1] = 0;
|
||||
return static_cast<complex<T> &>(*this);
|
||||
}
|
||||
|
||||
constexpr complex<T> &operator +=(T re) {
|
||||
storage[0] += re;
|
||||
return static_cast<complex<T> &>(*this);
|
||||
}
|
||||
|
||||
constexpr complex<T> &operator -=(T re) {
|
||||
storage[0] -= re;
|
||||
return static_cast<complex<T> &>(*this);
|
||||
}
|
||||
|
||||
constexpr complex<T> &operator *=(T re) {
|
||||
storage[0] *= re;
|
||||
storage[1] *= re;
|
||||
return static_cast<complex<T> &>(*this);
|
||||
}
|
||||
|
||||
constexpr complex<T> &operator /=(T re) {
|
||||
storage[0] /= re;
|
||||
storage[1] /= re;
|
||||
return static_cast<complex<T> &>(*this);
|
||||
}
|
||||
|
||||
template<typename U>
|
||||
constexpr complex<T> &operator =(const complex<U> &rhs) {
|
||||
storage[0] = rhs.real();
|
||||
storage[1] = rhs.imag();
|
||||
return static_cast<complex<T> &>(*this);
|
||||
}
|
||||
|
||||
template<typename U>
|
||||
constexpr complex<T> &operator +=(const complex<U> &rhs) {
|
||||
storage[0] += rhs.real();
|
||||
storage[1] += rhs.imag();
|
||||
return static_cast<complex<T> &>(*this);
|
||||
}
|
||||
|
||||
template<typename U>
|
||||
constexpr complex<T> &operator -=(const complex<U> &rhs) {
|
||||
storage[0] -= rhs.real();
|
||||
storage[1] -= rhs.imag();
|
||||
return static_cast<complex<T> &>(*this);
|
||||
}
|
||||
|
||||
template<typename U>
|
||||
constexpr complex<T> &operator *=(const complex<U> &rhs) {
|
||||
// (a + bi) * (c + di) = (a*c - b*d) + (a * d + b * c) i
|
||||
T a = storage[0];
|
||||
T b = storage[1];
|
||||
U c = rhs.real();
|
||||
U d = rhs.imag();
|
||||
storage[0] = a * c - b * d;
|
||||
storage[1] = a * d + b * c;
|
||||
return static_cast<complex<T> &>(*this);
|
||||
}
|
||||
|
||||
template<typename U>
|
||||
constexpr complex<T> &operator /=(const complex<U> &rhs) {
|
||||
// (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i
|
||||
T a = storage[0];
|
||||
T b = storage[1];
|
||||
U c = rhs.real();
|
||||
U d = rhs.imag();
|
||||
auto denominator = c * c + d * d;
|
||||
storage[0] = (a * c + b * d) / denominator;
|
||||
storage[1] = (b * c - a * d) / denominator;
|
||||
return static_cast<complex<T> &>(*this);
|
||||
}
|
||||
|
||||
template<typename U>
|
||||
constexpr complex<T> &operator =(const std::complex<U> &rhs) {
|
||||
storage[0] = rhs.real();
|
||||
storage[1] = rhs.imag();
|
||||
return static_cast<complex<T> &>(*this);
|
||||
}
|
||||
|
||||
#if defined(__CUDACC__) || defined(__HIPCC__)
|
||||
template<typename U>
|
||||
C10_HOST_DEVICE complex<T> &operator =(const thrust::complex<U> &rhs) {
|
||||
storage[0] = rhs.real();
|
||||
storage[1] = rhs.imag();
|
||||
return static_cast<complex<T> &>(*this);
|
||||
}
|
||||
#endif
|
||||
|
||||
template<typename U>
|
||||
explicit constexpr operator std::complex<U>() const {
|
||||
return std::complex<U>(std::complex<T>(real(), imag()));
|
||||
}
|
||||
|
||||
#if defined(__CUDACC__) || defined(__HIPCC__)
|
||||
template<typename U>
|
||||
explicit operator thrust::complex<U>() const {
|
||||
return thrust::complex<U>(thrust::complex<T>(real(), imag()));
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr T real() const {
|
||||
return storage[0];
|
||||
}
|
||||
constexpr void real(T value) {
|
||||
storage[0] = value;
|
||||
}
|
||||
constexpr T imag() const {
|
||||
return storage[1];
|
||||
}
|
||||
constexpr void imag(T value) {
|
||||
storage[1] = value;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template<>
|
||||
struct alignas(2*sizeof(float)) complex<float>: public complex_common<float> {
|
||||
using complex_common<float>::complex_common;
|
||||
constexpr complex(): complex_common() {}; // needed by CUDA 9.x
|
||||
explicit constexpr complex(const complex<double> &other);
|
||||
using complex_common<float>::operator=;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct alignas(2*sizeof(double)) complex<double>: public complex_common<double> {
|
||||
using complex_common<double>::complex_common;
|
||||
constexpr complex(): complex_common() {}; // needed by CUDA 9.x
|
||||
constexpr complex(const complex<float> &other);
|
||||
using complex_common<double>::operator=;
|
||||
};
|
||||
|
||||
constexpr complex<float>::complex(const complex<double> &other): complex_common(other.real(), other.imag()) {}
|
||||
constexpr complex<double>::complex(const complex<float> &other): complex_common(other.real(), other.imag()) {}
|
||||
|
||||
namespace complex_literals {
|
||||
|
||||
constexpr complex<float> operator"" _if(long double imag) {
|
||||
return complex<float>(0.0f, static_cast<float>(imag));
|
||||
}
|
||||
|
||||
constexpr complex<double> operator"" _id(long double imag) {
|
||||
return complex<double>(0.0, static_cast<double>(imag));
|
||||
}
|
||||
|
||||
constexpr complex<float> operator"" _if(unsigned long long imag) {
|
||||
return complex<float>(0.0f, static_cast<float>(imag));
|
||||
}
|
||||
|
||||
constexpr complex<double> operator"" _id(unsigned long long imag) {
|
||||
return complex<double>(0.0, static_cast<double>(imag));
|
||||
}
|
||||
|
||||
} // namespace complex_literals
|
||||
|
||||
} // namespace c10
|
||||
|
||||
template<typename T>
|
||||
constexpr c10::complex<T> operator+(const c10::complex<T>& val) {
|
||||
return val;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr c10::complex<T> operator-(const c10::complex<T>& val) {
|
||||
return c10::complex<T>(-val.real(), -val.imag());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr c10::complex<T> operator+(const c10::complex<T>& lhs, const c10::complex<T>& rhs) {
|
||||
c10::complex<T> result = lhs;
|
||||
return result += rhs;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr c10::complex<T> operator+(const c10::complex<T>& lhs, const T& rhs) {
|
||||
c10::complex<T> result = lhs;
|
||||
return result += rhs;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr c10::complex<T> operator+(const T& lhs, const c10::complex<T>& rhs) {
|
||||
return c10::complex<T>(lhs + rhs.real(), rhs.imag());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr c10::complex<T> operator-(const c10::complex<T>& lhs, const c10::complex<T>& rhs) {
|
||||
c10::complex<T> result = lhs;
|
||||
return result -= rhs;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr c10::complex<T> operator-(const c10::complex<T>& lhs, const T& rhs) {
|
||||
c10::complex<T> result = lhs;
|
||||
return result -= rhs;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr c10::complex<T> operator-(const T& lhs, const c10::complex<T>& rhs) {
|
||||
c10::complex<T> result = -rhs;
|
||||
return result += lhs;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr c10::complex<T> operator*(const c10::complex<T>& lhs, const c10::complex<T>& rhs) {
|
||||
c10::complex<T> result = lhs;
|
||||
return result *= rhs;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr c10::complex<T> operator*(const c10::complex<T>& lhs, const T& rhs) {
|
||||
c10::complex<T> result = lhs;
|
||||
return result *= rhs;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr c10::complex<T> operator*(const T& lhs, const c10::complex<T>& rhs) {
|
||||
c10::complex<T> result = rhs;
|
||||
return result *= lhs;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr c10::complex<T> operator/(const c10::complex<T>& lhs, const c10::complex<T>& rhs) {
|
||||
c10::complex<T> result = lhs;
|
||||
return result /= rhs;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr c10::complex<T> operator/(const c10::complex<T>& lhs, const T& rhs) {
|
||||
c10::complex<T> result = lhs;
|
||||
return result /= rhs;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr c10::complex<T> operator/(const T& lhs, const c10::complex<T>& rhs) {
|
||||
c10::complex<T> result(lhs, T());
|
||||
return result /= rhs;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr bool operator==(const c10::complex<T>& lhs, const c10::complex<T>& rhs) {
|
||||
return (lhs.real() == rhs.real()) && (lhs.imag() == rhs.imag());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr bool operator==(const c10::complex<T>& lhs, const T& rhs) {
|
||||
return (lhs.real() == rhs) && (lhs.imag() == T());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr bool operator==(const T& lhs, const c10::complex<T>& rhs) {
|
||||
return (lhs == rhs.real()) && (T() == rhs.imag());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr bool operator!=(const c10::complex<T>& lhs, const c10::complex<T>& rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr bool operator!=(const c10::complex<T>& lhs, const T& rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr bool operator!=(const T& lhs, const c10::complex<T>& rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
template <typename T, typename CharT, typename Traits>
|
||||
std::basic_ostream<CharT, Traits>& operator<<(std::basic_ostream<CharT, Traits>& os, const c10::complex<T>& x) {
|
||||
return (os << static_cast<std::complex<T>>(x));
|
||||
}
|
||||
|
||||
template <typename T, typename CharT, typename Traits>
|
||||
std::basic_istream<CharT, Traits>& operator>>(std::basic_istream<CharT, Traits>& is, c10::complex<T>& x) {
|
||||
std::complex<T> tmp;
|
||||
is >> tmp;
|
||||
x = tmp;
|
||||
return is;
|
||||
}
|
||||
|
||||
// std functions
|
||||
//
|
||||
// The implementation of these functions also follow the design of C++20
|
||||
|
||||
namespace std {
|
||||
|
||||
template<typename T>
|
||||
constexpr T real(const c10::complex<T>& z) {
|
||||
return z.real();
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T imag(const c10::complex<T>& z) {
|
||||
return z.imag();
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
C10_HOST_DEVICE T abs(const c10::complex<T>& z) {
|
||||
return std::hypot(std::real(z), std::imag(z));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
C10_HOST_DEVICE T arg(const c10::complex<T>& z) {
|
||||
return std::atan2(std::imag(z), std::real(z));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T norm(const c10::complex<T>& z) {
|
||||
return z.real() * z.real() + z.imag() * z.imag();
|
||||
}
|
||||
|
||||
// For std::conj, there are other versions of it:
|
||||
// constexpr std::complex<float> conj( float z );
|
||||
// template< class DoubleOrInteger >
|
||||
// constexpr std::complex<double> conj( DoubleOrInteger z );
|
||||
// constexpr std::complex<long double> conj( long double z );
|
||||
// These are not implemented
|
||||
// TODO(@zasdfgbnm): implement them as c10::conj
|
||||
template<typename T>
|
||||
constexpr c10::complex<T> conj(const c10::complex<T>& z) {
|
||||
return c10::complex<T>(z.real(), -z.imag());
|
||||
}
|
||||
|
||||
// Thrust does not have complex --> complex version of thrust::proj,
|
||||
// so this function is not implemented at c10 right now.
|
||||
// TODO(@zasdfgbnm): implement it by ourselves
|
||||
|
||||
// There is no c10 version of std::polar, because std::polar always
|
||||
// returns std::complex. Use c10::polar instead;
|
||||
|
||||
} // namespace std
|
||||
|
||||
namespace c10 {
|
||||
|
||||
template<typename T>
|
||||
C10_HOST_DEVICE c10::complex<T> polar(const T& r, const T& theta = T()) {
|
||||
#if defined(__CUDACC__) || defined(__HIPCC__)
|
||||
return static_cast<c10::complex<T>>(thrust::polar(r, theta));
|
||||
#else
|
||||
return static_cast<c10::complex<T>>(std::polar(r, theta));
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
// math functions are included in a separate file
|
||||
#include <c10/util/complex_math.h>
|
||||
|
|
@ -1 +0,0 @@
|
|||
// coming soon...
|
||||
Loading…
Reference in New Issue
Block a user