[MPS] Move hypot to Metal (#166216)
This also prevents crashes when hypot is invoked for integer types. For example, before this change the following invocation crashed:
```
python -c "import torch; print(torch.hypot(torch.randint(0, 10, (3,), device='mps'), torch.randint(0, 10, (3,), device='mps')))"
*** Terminating app due to uncaught exception 'NSInvalidArgumentException', reason: '*** -[__NSDictionaryM setObject:forKey:]: object cannot be nil (key: squareRoot_i64)'
*** First throw call stack:
(
0  CoreFoundation                0x0000000194d33ae0 __exceptionPreprocess + 176
1  libobjc.A.dylib               0x00000001947f6b90 objc_exception_throw + 88
2  CoreFoundation                0x0000000194c7d884 -[__NSDictionaryM setObject:forKey:] + 1288
3  MPSCore                       0x00000001a1187d0c _ZN12MPSKernelDAG15duodenaryCoreOpEP10BaseTensorS1_S1_S1_S1_S1_S1_S1_S1_S1_S1_S1_RKNSt3__16vectorIlNS2_9allocatorIlEEEE11MPSDataTypePKc + 37044
4  MPSCore                       0x00000001a113fab0 _ZN12MPSKernelDAGD0Ev + 4256
5  MPSCore                       0x00000001a1139f6c _ZN12MPSKernelDAG13getDAGAndHashEPU21objcproto10MTLLibrary11objc_objectP14MPSDAGKernelOpP19NSMutableDictionaryIP8NSStringPU22objcproto11MTLFunction11objc_objectEP14NSMutableArrayIS6_ERDv4_yPb + 8
6  MPSCore                       0x00000001a113c7a4 _ZN12MPSKernelDAG13getDAGAndHashEPU21objcproto10MTLLibrary11objc_objectP14MPSDAGKernelOpP19NSMutableDictionaryIP8NSStringPU22objcproto11MTLFunction11objc_objectEP14NSMutableArrayIS6_ERDv4_yPb + 1
7  MPSCore                       0x00000001a11c03c8 _ZN10MPSLibrary19CreateUberShaderKeyEP8NSStringRK23MPSFunctionConstantListyPFPU22objcproto11MTLFunction11objc_objectPU21objcproto10MTLLibrary11objc_objectPK13MPSKernelInfoS4_RK33MPSFunctionConstr
8  MPSNDArray                    0x00000001a27b546c MPSSetResourcesOnCommandEncoder + 154176
9  MPSNDArray                    0x00000001a27967d8 MPSSetResourcesOnCommandEncoder + 28076
10 MPSNDArray                    0x00000001a2798ec8 MPSSetResourcesOnCommandEncoder + 38044
11 MetalPerformanceShadersGraph  0x00000001f97689ac _ZN3GPU17IdentityOpHandler15encodeNDArrayOpEPNS_16EncodeDescriptorEP7NSArray + 436
12 MetalPerformanceShadersGraph  0x00000001f977f93c _ZN3GPU17StitchedOpHandler8encodeOpEPNS_16EncodeDescriptorE + 924
13 MetalPerformanceShadersGraph  0x00000001f9544898 _ZN16GPURegionRuntime5runOpIN3GPU23AbsoluteSquareOpHandlerEEEvPN4mlir9OperationEPNS1_16EncodeDescriptorE + 120
14 MetalPerformanceShadersGraph  0x00000001f9543894 _ZN16GPURegionRuntime8encodeOpEPN4mlir9OperationEPN3GPU16EncodeDescriptorE + 4700
15 MetalPerformanceShadersGraph  0x00000001f954251c _ZN16GPURegionRuntime29encodeOpWithCommitAndContinueEPN4mlir9OperationEPN3GPU16EncodeDescriptorE + 92
16 MetalPerformanceShadersGraph  0x00000001f954189c _ZN16GPURegionRuntime11evaluateOpsEPN3GPU16EncodeDescriptorEP7NSArrayIP18MPSGraphTensorDataES7_ + 3572
17 MetalPerformanceShadersGraph  0x00000001f953f7b4 _ZN10MPSRuntime11evaluateOpsEN4mlir4func6FuncOpEP21RuntimeSpecializationP7NSArrayIP18MPSGraphTensorDataES9_P37MPSGraphExecutableExecutionDescriptorP16MPSCommandBufferbbbPb + 824
18 MetalPerformanceShadersGraph  0x00000001f988dd38 -[MPSGraphExecutable runInternalWithDevice:commandBuffer:feeds:results:executableExecutionDescriptor:mpsGraphOwnedCommandBuffer:] + 3848
19 MetalPerformanceShadersGraph  0x00000001f988ca04 -[MPSGraphExecutable runInternalWithDevice:commandBuffer:feedsDictionary:resultsDictionary:executableExecutionDescriptor:mpsGraphOwnedCommandBuffer:] + 608
20 MetalPerformanceShadersGraph  0x00000001f9728aa0 -[MPSGraph runInternalWithMPSCommandBuffer:feeds:targetTensors:targetOperations:resultsDictionary:executionDescriptor:mpsGraphOwnedCommandBuffer:] + 320
21 MetalPerformanceShadersGraph  0x00000001f9727b58 -[MPSGraph encodeToCommandBuffer:feeds:targetOperations:resultsDictionary:executionDescriptor:] + 188
22 libtorch_cpu.dylib            0x00000001556c9478 ___ZN2at3mps9MPSStream15executeMPSGraphEP8MPSGraphP12NSDictionaryS5_NS0_8SyncTypeE_block_invoke + 128
23 libdispatch.dylib             0x0000000194a3985c _dispatch_client_callout + 16
24 libdispatch.dylib             0x0000000194a2f7a8 _dispatch_lane_barrier_sync_invoke_and_complete + 56
25 libtorch_cpu.dylib            0x00000001556c93e0 _ZN2at3mps9MPSStream15executeMPSGraphEP8MPSGraphP12NSDictionaryS5_NS0_8SyncTypeE + 160
26 libtorch_cpu.dylib            0x00000001556fd0f4 _ZN2at6native3mpsL14binaryOpTensorERKNS_6TensorES4_S4_NSt3__112basic_stringIcNS5_11char_traitsIcEENS5_9allocatorIcEEEEU13block_pointerFP14MPSGraphTensorPNS1_19BinaryOpCachedGraphESD_SD_E + 3040
27 libtorch_cpu.dylib            0x00000001556ff680 _ZN2at6native24structured_hypot_out_mps4implERKNS_6TensorES4_S4_ + 84
28 libtorch_cpu.dylib            0x00000001522682e4 _ZN2at12_GLOBAL__N_117wrapper_MPS_hypotERKNS_6TensorES3_ + 216
29 libtorch_cpu.dylib            0x0000000153a1378c _ZN3c104impl28wrap_kernel_functor_unboxed_INS0_6detail24WrapFunctionIntoFunctor_INS_26CompileTimeFunctionPointerIFN2at6TensorENS_14DispatchKeySetERKS6_S9_EXadL_ZN5torch8autograd12VariableType12_G
30 libtorch_cpu.dylib            0x0000000151241714 _ZN2at4_ops5hypot4callERKNS_6TensorES4_ + 304
31 libtorch_python.dylib         0x0000000105d9a848 _ZN5torch8autogradL17THPVariable_hypotEP7_objectS2_S2_ + 752
32 Python                        0x00000001036afa7c cfunction_call + 72
33 Python                        0x000000010365db08 _PyObject_MakeTpCall + 124
34 Python                        0x0000000103750f40 _PyEval_EvalFrameDefault + 23304
35 Python                        0x000000010374b1c8 PyEval_EvalCode + 184
36 Python                        0x00000001037ab8bc run_eval_code_obj + 88
37 Python                        0x00000001037a9994 run_mod + 132
38 Python                        0x00000001037a8fdc PyRun_StringFlags + 124
39 Python                        0x00000001037a8f08 PyRun_SimpleStringFlags + 64
40 Python                        0x00000001037cd464 Py_RunMain + 716
41 Python                        0x00000001037cd950 pymain_main + 304
42 Python                        0x00000001037cd9f0 Py_BytesMain + 40
43 dyld                          0x0000000194836b98 start + 6076
)
libc++abi: terminating due to uncaught exception of type NSException
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166216
Approved by: https://github.com/Skylion007
ghstack dependencies: #166210
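
For reference, a minimal sanity check of the behavior this change enables (not part of the commit; it assumes an MPS-capable machine with this patch applied):

```python
import torch

# The integer-input call from the crash log above: it should now run instead of
# throwing NSInvalidArgumentException, returning a floating-point result.
a = torch.randint(0, 10, (3,), device="mps")
b = torch.randint(0, 10, (3,), device="mps")
print(torch.hypot(a, b))

# Floating-point inputs should agree with the CPU kernel.
x = torch.rand(1024, device="mps")
y = torch.rand(1024, device="mps")
torch.testing.assert_close(torch.hypot(x, y).cpu(), torch.hypot(x.cpu(), y.cpu()))
```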
This commit is contained in:
parent 661a56002f
commit b0e9c86971
@@ -222,6 +222,13 @@ struct nextafter_functor {
   }
 };
 
+struct hypot_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b) {
+    return static_cast<T>(precise::sqrt(float(a) * a + float(b) * b));
+  }
+};
+
 // Complex binary functors
 struct polar_functor {
   template <typename U>
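
The functor above squares its operands in `float` rather than in the input type. A host-side illustration of why that promotion matters for half precision (NumPy stands in for Metal's `half` here; this is only a sketch of the same arithmetic, not the kernel itself):

```python
import numpy as np

a, b = np.float16(300.0), np.float16(400.0)

# Squaring directly in half precision overflows: 300**2 = 90000 exceeds the
# float16 maximum of 65504, so the intermediate becomes inf (NumPy warns).
naive = np.sqrt(a * a + b * b)

# Promoting to float32 before squaring, as float(a) * a does in the functor,
# keeps the intermediate finite and recovers the exact hypotenuse.
promoted = np.sqrt(np.float32(a) * a + np.float32(b) * b)

print(naive, promoted, np.hypot(np.float32(a), np.float32(b)))  # inf 500.0 500.0
```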
@@ -362,6 +369,7 @@ struct igammac_functor {
   REGISTER_OPMATH_BINARY_OP(NAME, half, half);   \
   REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat)
 
+REGISTER_FLOAT_BINARY_OP(hypot);
 REGISTER_FLOAT_BINARY_OP(copysign);
 REGISTER_INT2FLOAT_BINARY_OP(copysign);
 REGISTER_FLOAT_BINARY_OP(fmax);
@@ -202,6 +202,10 @@ static void igammac_mps_kernel(TensorIteratorBase& iter) {
   lib.exec_binary_kernel(iter, "igammac");
 }
 
+static void hypot_mps_kernel(TensorIteratorBase& iter) {
+  lib.exec_binary_kernel(iter, "hypot");
+}
+
 REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel)
 REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel)
 REGISTER_DISPATCH(copysign_stub, &copysign_mps_kernel)
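
Since `hypot_mps_kernel` is registered on the shared `hypot_stub`, the structured `out=` overload goes through the same Metal kernel. A quick usage sketch exercising that path (again assuming an MPS device):

```python
import torch

x = torch.rand(8, device="mps")
y = torch.rand(8, device="mps")
out = torch.empty_like(x)

torch.hypot(x, y, out=out)  # routed through hypot_stub to the Metal kernel
torch.testing.assert_close(out.cpu(), torch.hypot(x.cpu(), y.cpu()))
```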
@@ -229,4 +233,5 @@ REGISTER_DISPATCH(fmod_stub, &fmod_mps_kernel)
 REGISTER_DISPATCH(remainder_stub, &remainder_mps_kernel)
 REGISTER_DISPATCH(igamma_stub, &igamma_mps_kernel)
 REGISTER_DISPATCH(igammac_stub, &igammac_mps_kernel)
+REGISTER_DISPATCH(hypot_stub, &hypot_mps_kernel)
 } // namespace at::native
@@ -16,7 +16,6 @@
 #include <ATen/ops/eq_native.h>
 #include <ATen/ops/ge_native.h>
 #include <ATen/ops/gt_native.h>
-#include <ATen/ops/hypot_native.h>
 #include <ATen/ops/le_native.h>
 #include <ATen/ops/logaddexp2_native.h>
 #include <ATen/ops/logaddexp_native.h>
@@ -278,22 +277,6 @@ TORCH_IMPL_FUNC(pow_Scalar_out_mps)(const Scalar& base, const Tensor& exp, const
   }
 }
 
-TORCH_IMPL_FUNC(hypot_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
-  mps::BinaryOpBlock hypot_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
-    MPSGraph* mpsGraph = cachedGraph->graph();
-    MPSGraphTensor* twoTensor = [mpsGraph constantWithScalar:2.0 shape:@[ @1 ] dataType:primaryCastTensor.dataType];
-    MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph powerWithPrimaryTensor:primaryCastTensor
-                                                                                     secondaryTensor:twoTensor
-                                                                                                name:nil]
-                                                    secondaryTensor:[mpsGraph powerWithPrimaryTensor:secondaryCastTensor
-                                                                                     secondaryTensor:twoTensor
-                                                                                                name:nil]
-                                                               name:nil];
-    return [mpsGraph squareRootWithTensor:sumTensor name:nil];
-  };
-  mps::binaryOpTensor(self, other, output, "hypot_out_mps", hypot_op_block);
-}
-
 TORCH_IMPL_FUNC(logaddexp_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
   mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
     MPSGraph* mpsGraph = cachedGraph->graph();
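
The removed MPSGraph block computed x^2 + y^2 with two `power` ops and then took a `squareRoot`; the Metal functor computes the same quantity with plain multiplies. A host-side sketch (NumPy, purely illustrative rather than the actual kernels) showing the two formulations agree for ordinary float inputs:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random(1000, dtype=np.float32)
y = rng.random(1000, dtype=np.float32)

graph_style = np.sqrt(x**2 + y**2)      # old MPSGraph formulation: power, add, squareRoot
functor_style = np.sqrt(x * x + y * y)  # new Metal functor: multiply, add, sqrt

np.testing.assert_allclose(graph_style, functor_style, rtol=1e-6)
```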
@@ -10040,8 +10040,7 @@
   structured: True
   structured_inherits: TensorIteratorBase
   dispatch:
-    CPU, CUDA: hypot_out
-    MPS: hypot_out_mps
+    CPU, CUDA, MPS: hypot_out
   tags: pointwise
 
 - func: hypot(Tensor self, Tensor other) -> Tensor