diff --git a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal index 0539eab7950..0764b9d5e12 100644 --- a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal +++ b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal @@ -222,6 +222,13 @@ struct nextafter_functor { } }; +struct hypot_functor { + template <typename T> + inline T operator()(const T a, const T b) { + return static_cast<T>(precise::sqrt(float(a) * a + float(b) * b)); + } +}; + // Complex binary functors struct polar_functor { template @@ -362,6 +369,7 @@ struct igammac_functor { REGISTER_OPMATH_BINARY_OP(NAME, half, half); \ REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat) +REGISTER_FLOAT_BINARY_OP(hypot); REGISTER_FLOAT_BINARY_OP(copysign); REGISTER_INT2FLOAT_BINARY_OP(copysign); REGISTER_FLOAT_BINARY_OP(fmax); diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.mm b/aten/src/ATen/native/mps/operations/BinaryKernel.mm index 32b0fff8081..70211ceef07 100644 --- a/aten/src/ATen/native/mps/operations/BinaryKernel.mm +++ b/aten/src/ATen/native/mps/operations/BinaryKernel.mm @@ -202,6 +202,10 @@ static void igammac_mps_kernel(TensorIteratorBase& iter) { lib.exec_binary_kernel(iter, "igammac"); } +static void hypot_mps_kernel(TensorIteratorBase& iter) { + lib.exec_binary_kernel(iter, "hypot"); +} + REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel) REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel) REGISTER_DISPATCH(copysign_stub, &copysign_mps_kernel) @@ -229,4 +233,5 @@ REGISTER_DISPATCH(fmod_stub, &fmod_mps_kernel) REGISTER_DISPATCH(remainder_stub, &remainder_mps_kernel) REGISTER_DISPATCH(igamma_stub, &igamma_mps_kernel) REGISTER_DISPATCH(igammac_stub, &igammac_mps_kernel) +REGISTER_DISPATCH(hypot_stub, &hypot_mps_kernel) } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm index 06b6edcff94..bffd7924326 100644 --- a/aten/src/ATen/native/mps/operations/BinaryOps.mm +++ 
b/aten/src/ATen/native/mps/operations/BinaryOps.mm @@ -16,7 +16,6 @@ #include #include #include -#include <ATen/ops/hypot_native.h> #include #include #include @@ -278,22 +277,6 @@ TORCH_IMPL_FUNC(pow_Scalar_out_mps)(const Scalar& base, const Tensor& exp, const } } -TORCH_IMPL_FUNC(hypot_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) { - mps::BinaryOpBlock hypot_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { - MPSGraph* mpsGraph = cachedGraph->graph(); - MPSGraphTensor* twoTensor = [mpsGraph constantWithScalar:2.0 shape:@[ @1 ] dataType:primaryCastTensor.dataType]; - MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph powerWithPrimaryTensor:primaryCastTensor -                                                                                  secondaryTensor:twoTensor -                                                                                             name:nil] -                                                  secondaryTensor:[mpsGraph powerWithPrimaryTensor:secondaryCastTensor -                                                                                  secondaryTensor:twoTensor -                                                                                             name:nil] -                                                             name:nil]; - return [mpsGraph squareRootWithTensor:sumTensor name:nil]; - }; - mps::binaryOpTensor(self, other, output, "hypot_out_mps", hypot_op_block); -} - TORCH_IMPL_FUNC(logaddexp_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) { mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { MPSGraph* mpsGraph = cachedGraph->graph(); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 9b31ee06f90..0bc89ef493d 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -10040,8 +10040,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: hypot_out - MPS: hypot_out_mps + CPU, CUDA, MPS: hypot_out tags: pointwise - func: hypot(Tensor self, Tensor other) -> Tensor