[MPS] Move logaddexp/logaddexp2 to Metal and support complex (#166670)
NOTE: Complex inputs are only supported in `logaddexp`. Since `logaddexp2` does not support complex inputs for CPU, it is not enabled for MPS in this PR either.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166670
Approved by: https://github.com/malfet
Parent: fee7624bd6
Commit: 1e3600b528
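For context, a minimal usage sketch of what this change enables (illustrative only; assumes a macOS build of PyTorch with MPS available and recent enough for complex support). `logaddexp` now runs through a Metal kernel and accepts complex inputs, while `logaddexp2` stays real-only:

import torch

if torch.backends.mps.is_available():
    a = torch.randn(4, device="mps")
    b = torch.randn(4, device="mps")

    # Real inputs: both ops now go through the Metal binary kernels.
    print(torch.logaddexp(a, b))   # log(exp(a) + exp(b))
    print(torch.logaddexp2(a, b))  # log2(2**a + 2**b)

    # Complex inputs are supported for logaddexp only; logaddexp2 has no
    # complex kernel on CPU either, so it stays disabled on MPS.
    ac = torch.complex(torch.randn(4, device="mps"), torch.randn(4, device="mps"))
    bc = torch.complex(torch.randn(4, device="mps"), torch.randn(4, device="mps"))
    print(torch.logaddexp(ac, bc))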
@@ -86,6 +86,28 @@ struct zeta_functor {
   }
 };
 
+struct logaddexp_functor {
+  template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
+  inline T operator()(const T a, const T b) {
+    return c10::metal::logaddexp(a, b);
+  }
+  template <typename T, enable_if_t<is_integral_v<T>, bool> = true>
+  inline float operator()(const T a, const T b) {
+    return c10::metal::logaddexp(float(a), float(b));
+  }
+};
+
+struct logaddexp2_functor {
+  template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
+  inline T operator()(const T a, const T b) {
+    return c10::metal::logaddexp2(a, b);
+  }
+  template <typename T, enable_if_t<is_integral_v<T>, bool> = true>
+  inline float operator()(const T a, const T b) {
+    return c10::metal::logaddexp2(float(a), float(b));
+  }
+};
+
 struct xlog1py_functor {
   template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
   inline T operator()(const T a, const T b) {

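The functors above have two overloads each: floating-point types call `c10::metal::logaddexp`/`logaddexp2` directly, while integral types are promoted to `float` (pairing with the `REGISTER_INT2FLOAT_BINARY_OP` registrations in the next hunk). A quick, hypothetical check of the user-visible effect, assuming an MPS device:

import torch

# Integral inputs come back as float32, mirroring the INT2FLOAT overload.
a = torch.arange(4, dtype=torch.int32, device="mps")
b = torch.ones(4, dtype=torch.int32, device="mps")
print(torch.logaddexp(a, b).dtype)   # torch.float32
print(torch.logaddexp2(a, b).dtype)  # torch.float32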
@@ -377,6 +399,10 @@ REGISTER_FLOAT_BINARY_OP(fmin);
 REGISTER_FLOAT_BINARY_OP(nextafter);
 REGISTER_FLOAT_BINARY_OP(zeta);
 REGISTER_INT2FLOAT_BINARY_OP(zeta);
+REGISTER_FLOAT_BINARY_OP(logaddexp);
+REGISTER_INT2FLOAT_BINARY_OP(logaddexp);
+REGISTER_FLOAT_BINARY_OP(logaddexp2);
+REGISTER_INT2FLOAT_BINARY_OP(logaddexp2);
 REGISTER_FLOAT_BINARY_OP(xlog1py);
 REGISTER_INT2FLOAT_BINARY_OP(xlog1py);
 REGISTER_FLOAT_BINARY_OP(chebyshev_polynomial_t);

@@ -463,6 +489,8 @@ REGISTER_BINARY_OP(add, float2, float2);
 REGISTER_BINARY_OP(add, half2, half2);
 REGISTER_BINARY_OP(sub, float2, float2);
 REGISTER_BINARY_OP(sub, half2, half2);
+REGISTER_BINARY_OP(logaddexp, float2, float2);
+REGISTER_BINARY_OP(logaddexp, half2, half2);
 REGISTER_BINARY_ALPHA_OP(add_alpha, float2, float2, float2);
 REGISTER_BINARY_ALPHA_OP(add_alpha, half2, half2, half2);
 REGISTER_BINARY_ALPHA_OP(sub_alpha, float2, float2, float2);

@@ -89,6 +89,14 @@ static void zeta_mps_kernel(TensorIteratorBase& iter) {
   lib.exec_binary_kernel(iter, "zeta");
 }
 
+static void logaddexp_mps_kernel(TensorIteratorBase& iter) {
+  lib.exec_binary_kernel(iter, "logaddexp");
+}
+
+static void logaddexp2_mps_kernel(TensorIteratorBase& iter) {
+  lib.exec_binary_kernel(iter, "logaddexp2");
+}
+
 static void xlog1py_mps_kernel(TensorIteratorBase& iter) {
   TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), "xlog1py_mps not implemented for non-floating types");
   lib.exec_binary_kernel(iter, "xlog1py");

@@ -211,6 +219,8 @@ REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel)
 REGISTER_DISPATCH(copysign_stub, &copysign_mps_kernel)
 REGISTER_DISPATCH(nextafter_stub, &nextafter_mps_kernel)
 REGISTER_DISPATCH(zeta_stub, &zeta_mps_kernel)
+REGISTER_DISPATCH(logaddexp_stub, &logaddexp_mps_kernel);
+REGISTER_DISPATCH(logaddexp2_stub, &logaddexp2_mps_kernel);
 REGISTER_DISPATCH(xlog1py_stub, &xlog1py_mps_kernel)
 REGISTER_DISPATCH(chebyshev_polynomial_t_stub, &chebyshev_polynomial_t_mps_kernel)
 REGISTER_DISPATCH(chebyshev_polynomial_u_stub, &chebyshev_polynomial_u_mps_kernel)

@@ -17,8 +17,6 @@
 #include <ATen/ops/ge_native.h>
 #include <ATen/ops/gt_native.h>
 #include <ATen/ops/le_native.h>
-#include <ATen/ops/logaddexp2_native.h>
-#include <ATen/ops/logaddexp_native.h>
 #include <ATen/ops/logical_and_native.h>
 #include <ATen/ops/logical_or_native.h>
 #include <ATen/ops/logical_xor_native.h>

@@ -277,30 +275,6 @@ TORCH_IMPL_FUNC(pow_Scalar_out_mps)(const Scalar& base, const Tensor& exp, const
   }
 }
 
-TORCH_IMPL_FUNC(logaddexp_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
-  mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
-    MPSGraph* mpsGraph = cachedGraph->graph();
-    MPSGraphTensor* sumTensor =
-        [mpsGraph additionWithPrimaryTensor:[mpsGraph exponentWithTensor:primaryCastTensor name:nil]
-                            secondaryTensor:[mpsGraph exponentWithTensor:secondaryCastTensor name:nil]
-                                       name:nil];
-    return [mpsGraph logarithmWithTensor:sumTensor name:nil];
-  };
-  mps::binaryOpTensor(self, other, output, "logaddexp_out_mps", logaddexp_op_block);
-}
-
-TORCH_IMPL_FUNC(logaddexp2_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
-  mps::BinaryOpBlock logaddexp2_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
-    MPSGraph* mpsGraph = cachedGraph->graph();
-    MPSGraphTensor* sumTensor =
-        [mpsGraph additionWithPrimaryTensor:[mpsGraph exponentBase2WithTensor:primaryCastTensor name:nil]
-                            secondaryTensor:[mpsGraph exponentBase2WithTensor:secondaryCastTensor name:nil]
-                                       name:nil];
-    return [mpsGraph logarithmBase2WithTensor:sumTensor name:nil];
-  };
-  mps::binaryOpTensor(self, other, output, "logaddexp2_out_mps", logaddexp2_op_block);
-}
-
 TORCH_IMPL_FUNC(xlogy_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
   mps::BinaryOpBlock xlogy_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
     MPSGraph* mpsGraph = cachedGraph->graph();

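The removed blocks evaluate log(exp(a) + exp(b)) literally with MPSGraph ops; the new Metal helper (see the c10/metal hunk further below) uses the usual max-plus-log1p rewrite instead, which does not blow up for large inputs. A small Python sketch of the two formulations, for illustration only (neither is the shipped code):

import math

def naive_logaddexp(a, b):
    # Shape of the removed MPSGraph block: log(exp(a) + exp(b)).
    return math.log(math.exp(a) + math.exp(b))

def stable_logaddexp(a, b):
    # Shape of the new Metal helper: max(a, b) + log1p(exp(-|a - b|)).
    if math.isinf(a) and a == b:
        return a
    m = max(a, b)
    return m + math.log1p(math.exp(-abs(a - b)))

print(stable_logaddexp(1000.0, 1000.0))  # 1000.6931471805599
# naive_logaddexp(1000.0, 1000.0) raises OverflowError (exp overflows);
# an fp32 graph version would similarly saturate to inf.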
@@ -3622,8 +3622,7 @@
   structured: True
   structured_inherits: TensorIteratorBase
   dispatch:
-    CPU, CUDA: logaddexp_out
-    MPS: logaddexp_out_mps
+    CPU, CUDA, MPS: logaddexp_out
   tags: pointwise
 
 - func: logaddexp(Tensor self, Tensor other) -> Tensor

@@ -3635,8 +3634,7 @@
   structured: True
   structured_inherits: TensorIteratorBase
   dispatch:
-    CPU, CUDA: logaddexp2_out
-    MPS: logaddexp2_out_mps
+    CPU, CUDA, MPS: logaddexp2_out
   tags: pointwise
 
 - func: logaddexp2(Tensor self, Tensor other) -> Tensor

@@ -624,6 +624,64 @@ inline T spherical_bessel_j0(T x) {
   return static_cast<T>(::metal::sin(x) / x);
 }
 
+template <typename T>
+inline ::metal::enable_if_t<is_scalar_floating_point_v<T>, T> logaddexp(
+    T a,
+    T b) {
+  float a0 = static_cast<float>(a);
+  float b0 = static_cast<float>(b);
+  if (::metal::isinf(a0) && a0 == b0) {
+    return static_cast<T>(a0);
+  } else {
+    float m0 = ::metal::max(a0, b0);
+    return static_cast<T>(
+        m0 + ::c10::metal::log1p(::metal::exp(-::metal::abs(a0 - b0))));
+  }
+}
+
+// The function is ported from mlx
+template <typename T>
+inline ::metal::enable_if_t<is_complex_v<T>, T> logaddexp(T a, T b) {
+  if (::metal::isnan(a.x) || ::metal::isnan(a.y) || ::metal::isnan(b.x) ||
+      ::metal::isnan(b.y)) {
+    return T(NAN, NAN);
+  }
+
+  T maxval = a.x > b.x ? a : b;
+  T minval = a.x < b.x ? a : b;
+  constexpr auto inf = ::metal::numeric_limits<T>::infinity().x;
+
+  if (minval.x == -inf || maxval.x == inf) {
+    return maxval;
+  }
+
+  float2 maxval_ = static_cast<float2>(maxval);
+  float2 minval_ = static_cast<float2>(minval);
+  float m = ::metal::exp(minval_.x - maxval_.x);
+  float2 dexp{
+      m * ::metal::cos(minval_.y - maxval_.y),
+      m * ::metal::sin(minval_.y - maxval_.y),
+  };
+  return static_cast<T>(maxval_ + ::c10::metal::log1p(dexp));
+}
+
+template <typename T>
+inline T logaddexp2(T a, T b) {
+  constexpr auto log_2 = float(0.693147180559945309417232121458176);
+  constexpr auto inv_log_2 = float(1) / log_2;
+  float a0 = static_cast<float>(a);
+  float b0 = static_cast<float>(b);
+  if (::metal::isinf(a0) && a0 == b0) {
+    return static_cast<T>(a0);
+  } else {
+    float m0 = ::metal::max(a0, b0);
+    return static_cast<T>(
+        m0 +
+        ::c10::metal::log1p(::metal::pow(float(2), -::metal::abs(a0 - b0))) *
+            inv_log_2);
+  }
+}
+
 template <typename T>
 inline float xlog1py(T x, T y) {
   if (::metal::isnan(y)) {

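The complex overload above (ported from mlx) picks the operand with the larger real part, short-circuits NaNs and infinite real parts, and folds the smaller term in via exp(minval - maxval), expressed as magnitude and phase, before a complex log1p. A reference sketch in Python, using a plain log(1 + z) where the kernel calls its dedicated complex log1p:

import cmath

def complex_logaddexp_ref(a: complex, b: complex) -> complex:
    # Illustrative reference for the complex path, not the shipped code.
    if cmath.isnan(a) or cmath.isnan(b):
        return complex(float("nan"), float("nan"))
    maxval, minval = (a, b) if a.real > b.real else (b, a)
    if minval.real == float("-inf") or maxval.real == float("inf"):
        return maxval
    # exp(a) + exp(b) = exp(maxval) * (1 + exp(minval - maxval))
    return maxval + cmath.log(1 + cmath.exp(minval - maxval))

a, b = 1 + 1j, 2 - 1j
print(complex_logaddexp_ref(a, b))
print(cmath.log(cmath.exp(a) + cmath.exp(b)))  # agrees for moderate inputs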
@@ -322,6 +322,24 @@ inline float log1p(float x) {
   return rc;
 }
 
+// The function is ported from mlx
+inline float2 log1p(float2 in) {
+  float x = in.x;
+  float y = in.y;
+  float zabs = ::metal::precise::sqrt(x * x + y * y);
+  float theta = ::metal::atan2(y, x + 1);
+  if (zabs < 0.5f) {
+    float r = x * (2 + x) + y * y;
+    if (r == 0) { // handle underflow
+      return {x, theta};
+    }
+    return {0.5f * log1p(r), theta};
+  } else {
+    auto z0 = ::metal::sqrt((x + 1) * (x + 1) + y * y);
+    return {::metal::log(z0), theta};
+  }
+}
+
 template <typename T1, typename T2 = T1>
 struct pair {
   T1 first;

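The float2 overload above computes log(1 + z) for z = x + iy without forming 1 + z when |z| is small: the real part is 0.5 * log1p(x * (2 + x) + y * y), using |1 + z|^2 = 1 + (x * (2 + x) + y * y), and the imaginary part is atan2(y, x + 1). A Python sketch of the same identity (illustrative only):

import cmath
import math

def log1p_complex_ref(x: float, y: float) -> complex:
    # Real and imaginary parts handled separately, as in the Metal code.
    theta = math.atan2(y, x + 1.0)
    if math.hypot(x, y) < 0.5:
        r = x * (2.0 + x) + y * y  # |1 + z|^2 - 1, without cancellation
        if r == 0.0:               # underflow guard
            return complex(x, theta)
        return complex(0.5 * math.log1p(r), theta)
    return complex(math.log(math.hypot(x + 1.0, y)), theta)

z = complex(1e-18, 1e-18)
print(log1p_complex_ref(z.real, z.imag))  # real part ~1e-18 is preserved
print(cmath.log(1 + z))                   # real part collapses to 0.0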
@@ -92,6 +92,8 @@ if torch.backends.mps.is_available():
         "log1p",
         "log2",
         "log",
+        "logaddexp",
+        "logaddexp2",
         "mH",
         "mT",
         "masked_fill",

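The two entries added above opt `logaddexp` and `logaddexp2` into the MPS op coverage of the test suite. A standalone sanity check in the same spirit (not the actual test harness; assumes an MPS device):

import torch

if torch.backends.mps.is_available():
    a_cpu, b_cpu = torch.randn(1024), torch.randn(1024)
    a_mps, b_mps = a_cpu.to("mps"), b_cpu.to("mps")
    for fn in (torch.logaddexp, torch.logaddexp2):
        torch.testing.assert_close(fn(a_mps, b_mps).cpu(), fn(a_cpu, b_cpu))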