From 1e3600b5287346b29a835ad67f8b33945e0ec698 Mon Sep 17 00:00:00 2001
From: Kurt Mohler
Date: Thu, 30 Oct 2025 17:52:53 -0500
Subject: [PATCH] [MPS] Move `logaddexp/logaddexp2` to Metal and support complex (#166670)

NOTE: Complex inputs are only supported in `logaddexp`. Since `logaddexp2`
does not support complex inputs for CPU, it is not enabled for MPS in this
PR either.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166670
Approved by: https://github.com/malfet
---
 .../native/mps/kernels/BinaryKernel.metal   | 28 +++++++++
 .../native/mps/operations/BinaryKernel.mm   | 10 ++++
 .../ATen/native/mps/operations/BinaryOps.mm | 26 ---------
 aten/src/ATen/native/native_functions.yaml  |  6 +-
 c10/metal/special_math.h                    | 58 +++++++++++++++++++
 c10/metal/utils.h                           | 18 ++++++
 torch/testing/_internal/common_mps.py       |  2 +
 7 files changed, 118 insertions(+), 30 deletions(-)

diff --git a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal
index 0764b9d5e12..5cb6dd38822 100644
--- a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal
+++ b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal
@@ -86,6 +86,28 @@ struct zeta_functor {
   }
 };
 
+struct logaddexp_functor {
+  template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
+  inline T operator()(const T a, const T b) {
+    return c10::metal::logaddexp(a, b);
+  }
+  template <typename T, enable_if_t<is_integral_v<T>, bool> = true>
+  inline float operator()(const T a, const T b) {
+    return c10::metal::logaddexp(float(a), float(b));
+  }
+};
+
+struct logaddexp2_functor {
+  template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
+  inline T operator()(const T a, const T b) {
+    return c10::metal::logaddexp2(a, b);
+  }
+  template <typename T, enable_if_t<is_integral_v<T>, bool> = true>
+  inline float operator()(const T a, const T b) {
+    return c10::metal::logaddexp2(float(a), float(b));
+  }
+};
+
 struct xlog1py_functor {
   template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
   inline T operator()(const T a, const T b) {
@@ -377,6 +399,10 @@ REGISTER_FLOAT_BINARY_OP(fmin);
 REGISTER_FLOAT_BINARY_OP(nextafter);
 REGISTER_FLOAT_BINARY_OP(zeta);
 REGISTER_INT2FLOAT_BINARY_OP(zeta);
+REGISTER_FLOAT_BINARY_OP(logaddexp);
+REGISTER_INT2FLOAT_BINARY_OP(logaddexp);
+REGISTER_FLOAT_BINARY_OP(logaddexp2);
+REGISTER_INT2FLOAT_BINARY_OP(logaddexp2);
 REGISTER_FLOAT_BINARY_OP(xlog1py);
 REGISTER_INT2FLOAT_BINARY_OP(xlog1py);
 REGISTER_FLOAT_BINARY_OP(chebyshev_polynomial_t);
@@ -463,6 +489,8 @@ REGISTER_BINARY_OP(add, float2, float2);
 REGISTER_BINARY_OP(add, half2, half2);
 REGISTER_BINARY_OP(sub, float2, float2);
 REGISTER_BINARY_OP(sub, half2, half2);
+REGISTER_BINARY_OP(logaddexp, float2, float2);
+REGISTER_BINARY_OP(logaddexp, half2, half2);
 REGISTER_BINARY_ALPHA_OP(add_alpha, float2, float2, float2);
 REGISTER_BINARY_ALPHA_OP(add_alpha, half2, half2, half2);
 REGISTER_BINARY_ALPHA_OP(sub_alpha, float2, float2, float2);
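
For reference, the new functors forward to `c10::metal::logaddexp`/`logaddexp2` (added to `c10/metal/special_math.h` further down), with integral inputs promoted to `float`. The real-valued form both kernels compute is the overflow-safe rewrite `max(a, b) + log1p(exp(-|a - b|))` rather than the naive `log(exp(a) + exp(b))`. An illustrative Python sketch of that formula (not part of the diff):

```python
import math

def logaddexp_ref(a: float, b: float) -> float:
    # log(exp(a) + exp(b)) computed without overflow by factoring out max(a, b).
    if math.isinf(a) and a == b:
        return a  # covers (+inf, +inf) and (-inf, -inf), where a - b would be NaN
    m = max(a, b)
    return m + math.log1p(math.exp(-abs(a - b)))

def logaddexp2_ref(a: float, b: float) -> float:
    # Base-2 variant: log2(2**a + 2**b).
    inv_log2 = 1.0 / math.log(2.0)
    if math.isinf(a) and a == b:
        return a
    m = max(a, b)
    return m + math.log1p(2.0 ** (-abs(a - b))) * inv_log2

# The naive form would overflow here; the stable form does not.
assert abs(logaddexp_ref(1000.0, 1000.0) - (1000.0 + math.log(2.0))) < 1e-9
assert abs(logaddexp2_ref(1000.0, 1000.0) - 1001.0) < 1e-9
```
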
diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.mm b/aten/src/ATen/native/mps/operations/BinaryKernel.mm
index 70211ceef07..f8baf2e7f11 100644
--- a/aten/src/ATen/native/mps/operations/BinaryKernel.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryKernel.mm
@@ -89,6 +89,14 @@ static void zeta_mps_kernel(TensorIteratorBase& iter) {
   lib.exec_binary_kernel(iter, "zeta");
 }
 
+static void logaddexp_mps_kernel(TensorIteratorBase& iter) {
+  lib.exec_binary_kernel(iter, "logaddexp");
+}
+
+static void logaddexp2_mps_kernel(TensorIteratorBase& iter) {
+  lib.exec_binary_kernel(iter, "logaddexp2");
+}
+
 static void xlog1py_mps_kernel(TensorIteratorBase& iter) {
   TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), "xlog1py_mps not implemented for non-floating types");
   lib.exec_binary_kernel(iter, "xlog1py");
@@ -211,6 +219,8 @@ REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel)
 REGISTER_DISPATCH(copysign_stub, &copysign_mps_kernel)
 REGISTER_DISPATCH(nextafter_stub, &nextafter_mps_kernel)
 REGISTER_DISPATCH(zeta_stub, &zeta_mps_kernel)
+REGISTER_DISPATCH(logaddexp_stub, &logaddexp_mps_kernel);
+REGISTER_DISPATCH(logaddexp2_stub, &logaddexp2_mps_kernel);
 REGISTER_DISPATCH(xlog1py_stub, &xlog1py_mps_kernel)
 REGISTER_DISPATCH(chebyshev_polynomial_t_stub, &chebyshev_polynomial_t_mps_kernel)
 REGISTER_DISPATCH(chebyshev_polynomial_u_stub, &chebyshev_polynomial_u_mps_kernel)
diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm
index bffd7924326..d450a3ed8fe 100644
--- a/aten/src/ATen/native/mps/operations/BinaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -17,8 +17,6 @@
 #include
 #include
 #include
-#include <ATen/ops/logaddexp2_native.h>
-#include <ATen/ops/logaddexp_native.h>
 #include
 #include
 #include
@@ -277,30 +275,6 @@ TORCH_IMPL_FUNC(pow_Scalar_out_mps)(const Scalar& base, const Tensor& exp, const
   }
 }
 
-TORCH_IMPL_FUNC(logaddexp_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
-  mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
-    MPSGraph* mpsGraph = cachedGraph->graph();
-    MPSGraphTensor* sumTensor =
-        [mpsGraph additionWithPrimaryTensor:[mpsGraph exponentWithTensor:primaryCastTensor name:nil]
-                            secondaryTensor:[mpsGraph exponentWithTensor:secondaryCastTensor name:nil]
-                                       name:nil];
-    return [mpsGraph logarithmWithTensor:sumTensor name:nil];
-  };
-  mps::binaryOpTensor(self, other, output, "logaddexp_out_mps", logaddexp_op_block);
-}
-
-TORCH_IMPL_FUNC(logaddexp2_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
-  mps::BinaryOpBlock logaddexp2_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
-    MPSGraph* mpsGraph = cachedGraph->graph();
-    MPSGraphTensor* sumTensor =
-        [mpsGraph additionWithPrimaryTensor:[mpsGraph exponentBase2WithTensor:primaryCastTensor name:nil]
-                            secondaryTensor:[mpsGraph exponentBase2WithTensor:secondaryCastTensor name:nil]
-                                       name:nil];
-    return [mpsGraph logarithmBase2WithTensor:sumTensor name:nil];
-  };
-  mps::binaryOpTensor(self, other, output, "logaddexp2_out_mps", logaddexp2_op_block);
-}
-
 TORCH_IMPL_FUNC(xlogy_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
   mps::BinaryOpBlock xlogy_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
     MPSGraph* mpsGraph = cachedGraph->graph();
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index e039241c796..ad3c75f2097 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3622,8 +3622,7 @@
   structured: True
   structured_inherits: TensorIteratorBase
   dispatch:
-    CPU, CUDA: logaddexp_out
-    MPS: logaddexp_out_mps
+    CPU, CUDA, MPS: logaddexp_out
   tags: pointwise
 
 - func: logaddexp(Tensor self, Tensor other) -> Tensor
@@ -3635,8 +3634,7 @@
   structured: True
   structured_inherits: TensorIteratorBase
   dispatch:
-    CPU, CUDA: logaddexp2_out
-    MPS: logaddexp2_out_mps
+    CPU, CUDA, MPS: logaddexp2_out
   tags: pointwise
 
 - func: logaddexp2(Tensor self, Tensor other) -> Tensor
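
With the stubs registered above and the YAML dispatch collapsed onto the shared structured kernels `logaddexp_out`/`logaddexp2_out`, `torch.logaddexp` and `torch.logaddexp2` on MPS now route through the Metal path. An illustrative sanity check, not part of the patch (assumes an MPS-capable machine):

```python
import torch

if torch.backends.mps.is_available():
    a = torch.randn(1024)
    b = torch.randn(1024)
    # Real inputs: the MPS kernels should match the CPU reference for both ops.
    torch.testing.assert_close(
        torch.logaddexp(a.to("mps"), b.to("mps")).cpu(), torch.logaddexp(a, b)
    )
    torch.testing.assert_close(
        torch.logaddexp2(a.to("mps"), b.to("mps")).cpu(), torch.logaddexp2(a, b)
    )
    # Complex inputs are newly supported on MPS, but only for logaddexp.
    ac = torch.randn(1024, dtype=torch.complex64)
    bc = torch.randn(1024, dtype=torch.complex64)
    torch.testing.assert_close(
        torch.logaddexp(ac.to("mps"), bc.to("mps")).cpu(), torch.logaddexp(ac, bc)
    )
```
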
diff --git a/c10/metal/special_math.h b/c10/metal/special_math.h
index d80dfea9f03..defce910d7d 100644
--- a/c10/metal/special_math.h
+++ b/c10/metal/special_math.h
@@ -624,6 +624,64 @@ inline T spherical_bessel_j0(T x) {
   return static_cast<T>(::metal::sin(x) / x);
 }
 
+template <typename T>
+inline ::metal::enable_if_t<!is_complex_v<T>, T> logaddexp(
+    T a,
+    T b) {
+  float a0 = static_cast<float>(a);
+  float b0 = static_cast<float>(b);
+  if (::metal::isinf(a0) && a0 == b0) {
+    return static_cast<T>(a0);
+  } else {
+    float m0 = ::metal::max(a0, b0);
+    return static_cast<T>(
+        m0 + ::c10::metal::log1p(::metal::exp(-::metal::abs(a0 - b0))));
+  }
+}
+
+// The function is ported from mlx
+template <typename T>
+inline ::metal::enable_if_t<is_complex_v<T>, T> logaddexp(T a, T b) {
+  if (::metal::isnan(a.x) || ::metal::isnan(a.y) || ::metal::isnan(b.x) ||
+      ::metal::isnan(b.y)) {
+    return T(NAN, NAN);
+  }
+
+  T maxval = a.x > b.x ? a : b;
+  T minval = a.x < b.x ? a : b;
+  constexpr auto inf = ::metal::numeric_limits<T>::infinity().x;
+
+  if (minval.x == -inf || maxval.x == inf) {
+    return maxval;
+  }
+
+  float2 maxval_ = static_cast<float2>(maxval);
+  float2 minval_ = static_cast<float2>(minval);
+  float m = ::metal::exp(minval_.x - maxval_.x);
+  float2 dexp{
+      m * ::metal::cos(minval_.y - maxval_.y),
+      m * ::metal::sin(minval_.y - maxval_.y),
+  };
+  return static_cast<T>(maxval_ + ::c10::metal::log1p(dexp));
+}
+
+template <typename T>
+inline T logaddexp2(T a, T b) {
+  constexpr auto log_2 = float(0.693147180559945309417232121458176);
+  constexpr auto inv_log_2 = float(1) / log_2;
+  float a0 = static_cast<float>(a);
+  float b0 = static_cast<float>(b);
+  if (::metal::isinf(a0) && a0 == b0) {
+    return static_cast<T>(a0);
+  } else {
+    float m0 = ::metal::max(a0, b0);
+    return static_cast<T>(
+        m0 +
+        ::c10::metal::log1p(::metal::pow(float(2), -::metal::abs(a0 - b0))) *
+            inv_log_2);
+  }
+}
+
 template <typename T>
 inline float xlog1py(T x, T y) {
   if (::metal::isnan(y)) {
diff --git a/c10/metal/utils.h b/c10/metal/utils.h
index 43d0eff27b8..51e04174e32 100644
--- a/c10/metal/utils.h
+++ b/c10/metal/utils.h
@@ -322,6 +322,24 @@ inline float log1p(float x) {
   return rc;
 }
 
+// The function is ported from mlx
+inline float2 log1p(float2 in) {
+  float x = in.x;
+  float y = in.y;
+  float zabs = ::metal::precise::sqrt(x * x + y * y);
+  float theta = ::metal::atan2(y, x + 1);
+  if (zabs < 0.5f) {
+    float r = x * (2 + x) + y * y;
+    if (r == 0) { // handle underflow
+      return {x, theta};
+    }
+    return {0.5f * log1p(r), theta};
+  } else {
+    auto z0 = ::metal::sqrt((x + 1) * (x + 1) + y * y);
+    return {::metal::log(z0), theta};
+  }
+}
+
 template <typename T1, typename T2>
 struct pair {
   T1 first;
diff --git a/torch/testing/_internal/common_mps.py b/torch/testing/_internal/common_mps.py
index 45ede2d5e43..b3289853192 100644
--- a/torch/testing/_internal/common_mps.py
+++ b/torch/testing/_internal/common_mps.py
@@ -92,6 +92,8 @@ if torch.backends.mps.is_available():
         "log1p",
         "log2",
         "log",
+        "logaddexp",
+        "logaddexp2",
         "mH",
         "mT",
         "masked_fill",
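
For reference, the complex `logaddexp` added in `c10/metal/special_math.h` (ported from mlx) orders the operands by real part and evaluates `maxval + log1p(exp(minval - maxval))` in complex arithmetic, with the complex `log1p` overload added to `c10/metal/utils.h` keeping the small-argument case accurate. An illustrative Python sketch of the same algorithm, not part of the patch (it uses `cmath.log(1 + z)` in place of a dedicated complex `log1p`):

```python
import cmath
import math

def logaddexp_complex_ref(a: complex, b: complex) -> complex:
    # Mirrors the kernel's structure: NaN propagation, ordering by real part,
    # early exit on infinities, then maxval + log1p(exp(minval - maxval)).
    if any(math.isnan(v) for v in (a.real, a.imag, b.real, b.imag)):
        return complex(float("nan"), float("nan"))
    maxval = a if a.real > b.real else b
    minval = a if a.real < b.real else b
    if minval.real == float("-inf") or maxval.real == float("inf"):
        return maxval
    m = math.exp(minval.real - maxval.real)
    dexp = complex(m * math.cos(minval.imag - maxval.imag),
                   m * math.sin(minval.imag - maxval.imag))
    return maxval + cmath.log(1 + dexp)

# For purely real inputs this reduces to the real formula, e.g. log(2) at (0, 0).
assert abs(logaddexp_complex_ref(0j, 0j) - math.log(2.0)) < 1e-12
```
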