[mps] Hoist erfinv logic out of the kernel in preparation for moving. (#145568)

The hoisted erfinv function will be used in Inductor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145568
Approved by: https://github.com/malfet, https://github.com/Skylion007
This commit is contained in:
Davide Italiano 2025-01-24 18:51:09 +00:00 committed by PyTorch MergeBot
parent 8eea554332
commit 6cda572c98

View File

@ -3,24 +3,19 @@
using namespace c10::metal;
using namespace metal;
constant float a[4] = {0.886226899, -1.645349621, 0.914624893, -0.140543331};
constant float b[4] = {-2.118377725, 1.442710462, -0.329097515, 0.012229801};
constant float c[4] = {-1.970840454, -1.624906493, 3.429567803, 1.641345311};
constant float d[2] = {3.543889200, 1.637067800};
template <typename T0, typename T1>
kernel void erfinv_kernel(
device T0* output [[buffer(0)]],
constant T1* input [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
float y = input[index];
float x, z, num, dem; /*working variables */
template <typename T>
float erfinv(T y) {
/* coefficients in rational expansion */
constexpr float a[4] = {0.886226899, -1.645349621, 0.914624893, -0.140543331};
constexpr float b[4] = {-2.118377725, 1.442710462, -0.329097515, 0.012229801};
constexpr float c[4] = {-1.970840454, -1.624906493, 3.429567803, 1.641345311};
constexpr float d[2] = {3.543889200, 1.637067800};
float y_abs = abs(y);
float x, z, num, dem; /*working variables */
float y_abs = abs(static_cast<float>(y));
if (y_abs >= 1.0f) {
output[index] = T0(y_abs > 1.0f ? NAN : copysign(INFINITY, y));
return;
return y_abs > 1.0f ? NAN : copysign(INFINITY, static_cast<float>(y));
}
if (y_abs <= 0.7f) {
z = y * y;
@ -31,10 +26,18 @@ kernel void erfinv_kernel(
z = sqrt(-1.0f * log((1.0 - y_abs) / 2.0));
num = ((c[3] * z + c[2]) * z + c[1]) * z + c[0];
dem = (d[1] * z + d[0]) * z + 1.0f;
x = copysign(num, y) / dem;
x = copysign(num, static_cast<float>(y)) / dem;
}
output[index] = T0(x);
return x;
}
// Elementwise inverse-error-function kernel: each thread processes the single
// element selected by its position in the grid.
//   output - destination buffer bound at slot 0, element type T0
//   input  - source buffer bound at slot 1, element type T1
//   index  - this thread's grid position, used as the element offset
template <typename T0, typename T1>
kernel void erfinv_kernel(
    device T0* output [[buffer(0)]],
    constant T1* input [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  // erfinv() returns a float; convert it to the output element type on store.
  const float result = erfinv(input[index]);
  output[index] = static_cast<T0>(result);
}
template <typename T0, typename T1>