From b0e9c86971634f4a109377278d9acd66a5e3da10 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 24 Oct 2025 21:56:15 -0700 Subject: [PATCH] [MPS] Move hypot to Metal (#166216) This also prevents crashes when hypot is invoked on integer types. For example, before this change the following command crashed: ``` python -c "import torch; print(torch.hypot(torch.randint(0, 10, (3,), device='mps'), torch.randint(0, 10, (3,), device='mps')))" *** Terminating app due to uncaught exception 'NSInvalidArgumentException', reason: '*** -[__NSDictionaryM setObject:forKey:]: object cannot be nil (key: squareRoot_i64)' *** First throw call stack: ( 0 CoreFoundation 0x0000000194d33ae0 __exceptionPreprocess + 176 1 libobjc.A.dylib 0x00000001947f6b90 objc_exception_throw + 88 2 CoreFoundation 0x0000000194c7d884 -[__NSDictionaryM setObject:forKey:] + 1288 3 MPSCore 0x00000001a1187d0c _ZN12MPSKernelDAG15duodenaryCoreOpEP10BaseTensorS1_S1_S1_S1_S1_S1_S1_S1_S1_S1_S1_RKNSt3__16vectorIlNS2_9allocatorIlEEEE11MPSDataTypePKc + 37044 4 MPSCore 0x00000001a113fab0 _ZN12MPSKernelDAGD0Ev + 4256 5 MPSCore 0x00000001a1139f6c _ZN12MPSKernelDAG13getDAGAndHashEPU21objcproto10MTLLibrary11objc_objectP14MPSDAGKernelOpP19NSMutableDictionaryIP8NSStringPU22objcproto11MTLFunction11objc_objectEP14NSMutableArrayIS6_ERDv4_yPb + 8 6 MPSCore 0x00000001a113c7a4 _ZN12MPSKernelDAG13getDAGAndHashEPU21objcproto10MTLLibrary11objc_objectP14MPSDAGKernelOpP19NSMutableDictionaryIP8NSStringPU22objcproto11MTLFunction11objc_objectEP14NSMutableArrayIS6_ERDv4_yPb + 1 7 MPSCore 0x00000001a11c03c8 _ZN10MPSLibrary19CreateUberShaderKeyEP8NSStringRK23MPSFunctionConstantListyPFPU22objcproto11MTLFunction11objc_objectPU21objcproto10MTLLibrary11objc_objectPK13MPSKernelInfoS4_RK33MPSFunctionConstr 8 MPSNDArray 0x00000001a27b546c MPSSetResourcesOnCommandEncoder + 154176 9 MPSNDArray 0x00000001a27967d8 MPSSetResourcesOnCommandEncoder + 28076 10 MPSNDArray 0x00000001a2798ec8 MPSSetResourcesOnCommandEncoder + 38044 11 MetalPerformanceShadersGraph 
0x00000001f97689ac _ZN3GPU17IdentityOpHandler15encodeNDArrayOpEPNS_16EncodeDescriptorEP7NSArray + 436 12 MetalPerformanceShadersGraph 0x00000001f977f93c _ZN3GPU17StitchedOpHandler8encodeOpEPNS_16EncodeDescriptorE + 924 13 MetalPerformanceShadersGraph 0x00000001f9544898 _ZN16GPURegionRuntime5runOpIN3GPU23AbsoluteSquareOpHandlerEEEvPN4mlir9OperationEPNS1_16EncodeDescriptorE + 120 14 MetalPerformanceShadersGraph 0x00000001f9543894 _ZN16GPURegionRuntime8encodeOpEPN4mlir9OperationEPN3GPU16EncodeDescriptorE + 4700 15 MetalPerformanceShadersGraph 0x00000001f954251c _ZN16GPURegionRuntime29encodeOpWithCommitAndContinueEPN4mlir9OperationEPN3GPU16EncodeDescriptorE + 92 16 MetalPerformanceShadersGraph 0x00000001f954189c _ZN16GPURegionRuntime11evaluateOpsEPN3GPU16EncodeDescriptorEP7NSArrayIP18MPSGraphTensorDataES7_ + 3572 17 MetalPerformanceShadersGraph 0x00000001f953f7b4 _ZN10MPSRuntime11evaluateOpsEN4mlir4func6FuncOpEP21RuntimeSpecializationP7NSArrayIP18MPSGraphTensorDataES9_P37MPSGraphExecutableExecutionDescriptorP16MPSCommandBufferbbbPb + 824 18 MetalPerformanceShadersGraph 0x00000001f988dd38 -[MPSGraphExecutable runInternalWithDevice:commandBuffer:feeds:results:executableExecutionDescriptor:mpsGraphOwnedCommandBuffer:] + 3848 19 MetalPerformanceShadersGraph 0x00000001f988ca04 -[MPSGraphExecutable runInternalWithDevice:commandBuffer:feedsDictionary:resultsDictionary:executableExecutionDescriptor:mpsGraphOwnedCommandBuffer:] + 608 20 MetalPerformanceShadersGraph 0x00000001f9728aa0 -[MPSGraph runInternalWithMPSCommandBuffer:feeds:targetTensors:targetOperations:resultsDictionary:executionDescriptor:mpsGraphOwnedCommandBuffer:] + 320 21 MetalPerformanceShadersGraph 0x00000001f9727b58 -[MPSGraph encodeToCommandBuffer:feeds:targetOperations:resultsDictionary:executionDescriptor:] + 188 22 libtorch_cpu.dylib 0x00000001556c9478 ___ZN2at3mps9MPSStream15executeMPSGraphEP8MPSGraphP12NSDictionaryS5_NS0_8SyncTypeE_block_invoke + 128 23 libdispatch.dylib 0x0000000194a3985c 
_dispatch_client_callout + 16 24 libdispatch.dylib 0x0000000194a2f7a8 _dispatch_lane_barrier_sync_invoke_and_complete + 56 25 libtorch_cpu.dylib 0x00000001556c93e0 _ZN2at3mps9MPSStream15executeMPSGraphEP8MPSGraphP12NSDictionaryS5_NS0_8SyncTypeE + 160 26 libtorch_cpu.dylib 0x00000001556fd0f4 _ZN2at6native3mpsL14binaryOpTensorERKNS_6TensorES4_S4_NSt3__112basic_stringIcNS5_11char_traitsIcEENS5_9allocatorIcEEEEU13block_pointerFP14MPSGraphTensorPNS1_19BinaryOpCachedGraphESD_SD_E + 3040 27 libtorch_cpu.dylib 0x00000001556ff680 _ZN2at6native24structured_hypot_out_mps4implERKNS_6TensorES4_S4_ + 84 28 libtorch_cpu.dylib 0x00000001522682e4 _ZN2at12_GLOBAL__N_117wrapper_MPS_hypotERKNS_6TensorES3_ + 216 29 libtorch_cpu.dylib 0x0000000153a1378c _ZN3c104impl28wrap_kernel_functor_unboxed_INS0_6detail24WrapFunctionIntoFunctor_INS_26CompileTimeFunctionPointerIFN2at6TensorENS_14DispatchKeySetERKS6_S9_EXadL_ZN5torch8autograd12VariableType12_G 30 libtorch_cpu.dylib 0x0000000151241714 _ZN2at4_ops5hypot4callERKNS_6TensorES4_ + 304 31 libtorch_python.dylib 0x0000000105d9a848 _ZN5torch8autogradL17THPVariable_hypotEP7_objectS2_S2_ + 752 32 Python 0x00000001036afa7c cfunction_call + 72 33 Python 0x000000010365db08 _PyObject_MakeTpCall + 124 34 Python 0x0000000103750f40 _PyEval_EvalFrameDefault + 23304 35 Python 0x000000010374b1c8 PyEval_EvalCode + 184 36 Python 0x00000001037ab8bc run_eval_code_obj + 88 37 Python 0x00000001037a9994 run_mod + 132 38 Python 0x00000001037a8fdc PyRun_StringFlags + 124 39 Python 0x00000001037a8f08 PyRun_SimpleStringFlags + 64 40 Python 0x00000001037cd464 Py_RunMain + 716 41 Python 0x00000001037cd950 pymain_main + 304 42 Python 0x00000001037cd9f0 Py_BytesMain + 40 43 dyld 0x0000000194836b98 start + 6076 ) libc++abi: terminating due to uncaught exception of type NSException ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/166216 Approved by: https://github.com/Skylion007 ghstack dependencies: #166210 --- 
.../ATen/native/mps/kernels/BinaryKernel.metal | 8 ++++++++ .../ATen/native/mps/operations/BinaryKernel.mm | 5 +++++ .../src/ATen/native/mps/operations/BinaryOps.mm | 17 ----------------- aten/src/ATen/native/native_functions.yaml | 3 +-- 4 files changed, 14 insertions(+), 19 deletions(-) diff --git a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal index 0539eab7950..0764b9d5e12 100644 --- a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal +++ b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal @@ -222,6 +222,13 @@ struct nextafter_functor { } }; +struct hypot_functor { + template <typename T> + inline T operator()(const T a, const T b) { + return static_cast<T>(precise::sqrt(float(a) * a + float(b) * b)); + } +}; + // Complex binary functors struct polar_functor { template @@ -362,6 +369,7 @@ struct igammac_functor { REGISTER_OPMATH_BINARY_OP(NAME, half, half); \ REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat) +REGISTER_FLOAT_BINARY_OP(hypot); REGISTER_FLOAT_BINARY_OP(copysign); REGISTER_INT2FLOAT_BINARY_OP(copysign); REGISTER_FLOAT_BINARY_OP(fmax); diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.mm b/aten/src/ATen/native/mps/operations/BinaryKernel.mm index 32b0fff8081..70211ceef07 100644 --- a/aten/src/ATen/native/mps/operations/BinaryKernel.mm +++ b/aten/src/ATen/native/mps/operations/BinaryKernel.mm @@ -202,6 +202,10 @@ static void igammac_mps_kernel(TensorIteratorBase& iter) { lib.exec_binary_kernel(iter, "igammac"); } +static void hypot_mps_kernel(TensorIteratorBase& iter) { + lib.exec_binary_kernel(iter, "hypot"); +} + REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel) REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel) REGISTER_DISPATCH(copysign_stub, &copysign_mps_kernel) @@ -229,4 +233,5 @@ REGISTER_DISPATCH(fmod_stub, &fmod_mps_kernel) REGISTER_DISPATCH(remainder_stub, &remainder_mps_kernel) REGISTER_DISPATCH(igamma_stub, &igamma_mps_kernel) REGISTER_DISPATCH(igammac_stub, &igammac_mps_kernel) 
+REGISTER_DISPATCH(hypot_stub, &hypot_mps_kernel) } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm index 06b6edcff94..bffd7924326 100644 --- a/aten/src/ATen/native/mps/operations/BinaryOps.mm +++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm @@ -16,7 +16,6 @@ #include #include #include -#include #include #include #include @@ -278,22 +277,6 @@ TORCH_IMPL_FUNC(pow_Scalar_out_mps)(const Scalar& base, const Tensor& exp, const } } -TORCH_IMPL_FUNC(hypot_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) { - mps::BinaryOpBlock hypot_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { - MPSGraph* mpsGraph = cachedGraph->graph(); - MPSGraphTensor* twoTensor = [mpsGraph constantWithScalar:2.0 shape:@[ @1 ] dataType:primaryCastTensor.dataType]; - MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph powerWithPrimaryTensor:primaryCastTensor - secondaryTensor:twoTensor - name:nil] - secondaryTensor:[mpsGraph powerWithPrimaryTensor:secondaryCastTensor - secondaryTensor:twoTensor - name:nil] - name:nil]; - return [mpsGraph squareRootWithTensor:sumTensor name:nil]; - }; - mps::binaryOpTensor(self, other, output, "hypot_out_mps", hypot_op_block); -} - TORCH_IMPL_FUNC(logaddexp_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) { mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { MPSGraph* mpsGraph = cachedGraph->graph(); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 9b31ee06f90..0bc89ef493d 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -10040,8 +10040,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: hypot_out - MPS: hypot_out_mps + CPU, CUDA, MPS: hypot_out tags: pointwise - func: 
hypot(Tensor self, Tensor other) -> Tensor