Commit Graph

11 Commits

Author SHA1 Message Date
yhl48
6fcd671574 Complex support for expm1 (#96644)
Fixes #92619

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96644
Approved by: https://github.com/soulitzer
2023-03-24 17:24:50 +00:00
mfkasim1
1588ea0dbf Added log1p for complex in c10 (#89214)
One PR towards #89205.
The content is mostly from PR #38465, but slightly changed the expression to make it faster.

Here are some benchmarking code:
```c++
#include <complex>
#include <iostream>
#include <chrono>

// main.cc

template<typename T> inline std::complex<T> log1p_v0(const std::complex<T> &z) {
    // this PR
    T x = z.real();
    T y = z.imag();
    T theta = std::atan2(y, x + T(1));
    T r = x * (x + T(2)) + y * y;
    return {T(0.5) * std::log1p(r), theta};
}

template<typename T> inline std::complex<T> log1p_v1(const std::complex<T> &z) {
    // PR #38465
    T x = z.real();
    T y = z.imag();
    std::complex<T> p1 = z + T(1);
    T r = std::abs(p1);
    T a = std::arg(p1);
    T rm1 = (x * x + y * y + x * T(2)) / (r + 1);
    return {std::log1p(rm1), a};
}

template<typename T>
inline std::complex<T> log1p_v2(const std::complex<T> &z) {
    // naive, but numerically inaccurate
    return std::log(T(1) + z);
}

int main() {
    int n = 1000000;
    std::complex<float> res(0.0, 0.0);
    std::complex<float> input(0.5, 2.0);
    auto start = std::chrono::system_clock::now();
    for (int i = 0; i < n; i++) {
        res += log1p_v0(input);
    }
    auto end = std::chrono::system_clock::now();
    auto elapsed = end - start;
    std::cout << "time for v0: " << elapsed.count() << '\n';

    start = std::chrono::system_clock::now();
    for (int i = 0; i < n; i++) {
        res += log1p_v1(input);
    }
    end = std::chrono::system_clock::now();
    elapsed = end - start;
    std::cout << "time for v1: " << elapsed.count() << '\n';

    start = std::chrono::system_clock::now();
    for (int i = 0; i < n; i++) {
        res += log1p_v2(input);
    }
    end = std::chrono::system_clock::now();
    elapsed = end - start;
    std::cout << "time for v2: " << elapsed.count() << '\n';
    std::cout << res << '\n';
}
```

Compiling the script with command `g++ main.cc` produces the following results:
```
time for v0: 237812271
time for v1: 414524941
time for v2: 360585994
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89214
Approved by: https://github.com/lezcano
2022-11-24 11:11:51 +00:00
Scott Wolchok
44cc873fba [PyTorch] Autoformat c10 (#56830)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56830

Opt into formatting on GitHub and format everything. This is a trial run before turning on formatting for more and eventually all of the codebase.

Test Plan: CI

Reviewed By: zertosh

Differential Revision: D27979080

fbshipit-source-id: a80f0c48691c08ae8ca0af06377b87e6a2351151
2021-04-30 21:23:28 -07:00
Xiang Gao
c7d79f35e3 Header rename complex_type.h -> complex.h (#39885)
Summary:
This file should have been renamed as `complex.h`, but unfortunately, it was named as `complex_type.h` due to a name clash with FBCode. Is this still the case and is it easy to resolve the name clash? Maybe related to the comment at https://github.com/pytorch/pytorch/pull/39834#issuecomment-642950012
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39885

Differential Revision: D22018575

Pulled By: ezyang

fbshipit-source-id: e237ccedbe2b30c31aca028a5b4c8c063087a30f
2020-06-23 16:27:09 -07:00
Gao, Xiang
dea58a7660 [resubmit] Kill thrust::complex from log kernels (#40079)
Summary:
Use `::log` instead of `std::log` for better ROCm support.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40079

Differential Revision: D22068554

Pulled By: pbelevich

fbshipit-source-id: a458ae34535a641832f816617387a45445e2fa48
2020-06-17 05:57:10 -07:00
Xiang Gao
eb358f49c2 Overload complex math functions on both :: and std:: (#39829)
Summary:
Because ROCm has bug on std:: functions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39829

Differential Revision: D22018430

Pulled By: anjali411

fbshipit-source-id: 671e158d3e3342394d1deaebd7ff011cce94c31a
2020-06-15 16:53:16 -07:00
Pavel Belevich
cf64af1ad2 Revert D22036002: [pytorch][PR] Kill thrust::complex from log kernels
Test Plan: revert-hammer

Differential Revision:
D22036002

Original commit changeset: 8852a833a0c7

fbshipit-source-id: 36d3c8d0e489f8a11a6e3e9d1ae162c192748037
2020-06-14 15:30:48 -07:00
Xiang Gao
4947ee3811 Kill thrust::complex from log kernels (#39902)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/39902

Differential Revision: D22036002

Pulled By: pbelevich

fbshipit-source-id: 8852a833a0c71343ae630754f00da35a66e05917
2020-06-14 11:44:28 -07:00
Gao, Xiang
c5624e831d Add overloads of std:: math functions for c10::complex [resubmit] (#37468)
Summary:
This reverts commit d167a7f654.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37468

Differential Revision: D21305110

Pulled By: anjali411

fbshipit-source-id: d1bdc9d9feac00331fc2b2b905d49f80bef680f9
2020-04-30 10:20:45 -07:00
Lu Fang
d167a7f654 Revert D21256854: [pytorch][PR] Add overloads of std:: math functions for c10::complex
Test Plan: revert-hammer

Differential Revision:
D21256854

Original commit changeset: 2112ba6b7992

fbshipit-source-id: b81c377f9cd33a493a63d1e666cbe6765516fca8
2020-04-27 13:23:34 -07:00
Gao, Xiang
6d409481b3 Add overloads of std:: math functions for c10::complex (#35725)
Summary:
Issue: https://github.com/pytorch/pytorch/issues/35284

~This depends on and contains https://github.com/pytorch/pytorch/pull/35524. Please review after the dependency gets merged and I will rebase to get a clean diff.~

The implementation of most functions follow the pattern

```C++
template<typename T>
C10_HOST_DEVICE c10::complex<T> some_function(c10::complex<T> x) {
#if defined(__CUDACC__) || defined(__HIPCC__)
  return static_cast<c10::complex<T>>(thrust::some_function(static_cast<thrust::complex<T>>(x)));
#else
  return static_cast<c10::complex<T>>(std::some_function(static_cast<std::complex<T>>(x)));
#endif
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35725

Differential Revision: D21256854

Pulled By: ezyang

fbshipit-source-id: 2112ba6b79923450feafd7ebdc7184a3eaecadb6
2020-04-27 10:32:16 -07:00