mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[mps] Hoist erfinv logic out of the kernel in preparation for moving. (#145568)
Will be used in inductor. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145568 Approved by: https://github.com/malfet, https://github.com/Skylion007
This commit is contained in:
parent
8eea554332
commit
6cda572c98
|
|
@ -3,24 +3,19 @@
|
|||
using namespace c10::metal;
|
||||
using namespace metal;
|
||||
|
||||
constant float a[4] = {0.886226899, -1.645349621, 0.914624893, -0.140543331};
|
||||
constant float b[4] = {-2.118377725, 1.442710462, -0.329097515, 0.012229801};
|
||||
constant float c[4] = {-1.970840454, -1.624906493, 3.429567803, 1.641345311};
|
||||
constant float d[2] = {3.543889200, 1.637067800};
|
||||
|
||||
template <typename T0, typename T1>
|
||||
kernel void erfinv_kernel(
|
||||
device T0* output [[buffer(0)]],
|
||||
constant T1* input [[buffer(1)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
float y = input[index];
|
||||
float x, z, num, dem; /*working variables */
|
||||
template <typename T>
|
||||
float erfinv(T y) {
|
||||
/* coefficients in rational expansion */
|
||||
constexpr float a[4] = {0.886226899, -1.645349621, 0.914624893, -0.140543331};
|
||||
constexpr float b[4] = {-2.118377725, 1.442710462, -0.329097515, 0.012229801};
|
||||
constexpr float c[4] = {-1.970840454, -1.624906493, 3.429567803, 1.641345311};
|
||||
constexpr float d[2] = {3.543889200, 1.637067800};
|
||||
|
||||
float y_abs = abs(y);
|
||||
float x, z, num, dem; /*working variables */
|
||||
|
||||
float y_abs = abs(static_cast<float>(y));
|
||||
if (y_abs >= 1.0f) {
|
||||
output[index] = T0(y_abs > 1.0f ? NAN : copysign(INFINITY, y));
|
||||
return;
|
||||
return y_abs > 1.0f ? NAN : copysign(INFINITY, static_cast<float>(y));
|
||||
}
|
||||
if (y_abs <= 0.7f) {
|
||||
z = y * y;
|
||||
|
|
@ -31,10 +26,18 @@ kernel void erfinv_kernel(
|
|||
z = sqrt(-1.0f * log((1.0 - y_abs) / 2.0));
|
||||
num = ((c[3] * z + c[2]) * z + c[1]) * z + c[0];
|
||||
dem = (d[1] * z + d[0]) * z + 1.0f;
|
||||
x = copysign(num, y) / dem;
|
||||
x = copysign(num, static_cast<float>(y)) / dem;
|
||||
}
|
||||
|
||||
output[index] = T0(x);
|
||||
return x;
|
||||
}
|
||||
|
||||
template <typename T0, typename T1>
|
||||
kernel void erfinv_kernel(
|
||||
device T0* output [[buffer(0)]],
|
||||
constant T1* input [[buffer(1)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
output[index] = T0(erfinv(input[index]));
|
||||
}
|
||||
|
||||
template <typename T0, typename T1>
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user