From 12577064dddfc6f5daf66c5b5a73cb418a588f20 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 27 Oct 2025 18:07:20 -0700 Subject: [PATCH] [MPS] Fix crash when max/min ops called for complex types (#166214) Raise an exception, as it's meaningless and results in segfault otherwise: ``` % python -c "import torch;torch.rand(10, dtype=torch.cfloat, device='mps').amax()" (mpsFileLoc): /AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:176:0: error: 'mps.reduction_max' op operand #0 must be tensor of mps native type values, but got 'tensor<10xcomplex>' (mpsFileLoc): /AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:176:0: note: see current operation: %2 = "mps.reduction_max"(%arg0, %1) <{keep_dims, propagate_nans}> : (tensor<10xcomplex>, tensor<1xsi32>) -> tensor<1xcomplex> (mpsFileLoc): /AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:176:0: error: 'mps.reduction_max' op operand #0 must be tensor of mps native type values, but got 'tensor<10xcomplex>' (mpsFileLoc): /AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:176:0: note: see current operation: %2 = "mps.reduction_max"(%arg0, %1) <{keep_dims, propagate_nans}> : (tensor<10xcomplex>, tensor<1xsi32>) -> tensor<1xcomplex> /AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:1347: failed assertion `original module failed verification' zsh: abort python -c ``` To be tested by `test_ops.py` Pull Request resolved: https://github.com/pytorch/pytorch/pull/166214 Approved by: https://github.com/dcci, https://github.com/kulinseth, https://github.com/Skylion007 ghstack dependencies: #166272 --- aten/src/ATen/native/mps/operations/ReduceOps.mm | 3 +++ aten/src/ATen/native/mps/operations/Sort.mm | 1 + 2 files changed, 4 insertions(+) diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm index f4e469b79cb..3747f314adf 100644 --- a/aten/src/ATen/native/mps/operations/ReduceOps.mm +++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm @@ -1028,15 +1028,18 @@ TORCH_IMPL_FUNC(prod_out_mps) } TORCH_IMPL_FUNC(amax_out_mps)(const Tensor& input_t, IntArrayRef dim, bool keepdim, const Tensor& output_t) { + TORCH_CHECK(!c10::isComplexType(input_t.scalar_type()), "amax is not defined for complex types"); reduction_out_mps(input_t, dim, keepdim, std::nullopt, output_t, MPSReductionType::AMAX, "amax_out_mps"); } TORCH_IMPL_FUNC(amin_out_mps)(const Tensor& input_t, IntArrayRef dim, bool keepdim, const Tensor& output_t) { + TORCH_CHECK(!c10::isComplexType(input_t.scalar_type()), "amin is not defined for complex types"); reduction_out_mps(input_t, dim, keepdim, std::nullopt, output_t, MPSReductionType::AMIN, "amin_out_mps"); } TORCH_IMPL_FUNC(aminmax_out_mps) (const Tensor& input_t, std::optional dim_opt, bool keepdim, const Tensor& min_t, const Tensor& max_t) { + TORCH_CHECK(!c10::isComplexType(input_t.scalar_type()), "aminmax is not defined for complex types"); reduction_out_mps(input_t, dim_opt.has_value() ? OptionalIntArrayRef({*dim_opt}) : std::nullopt, keepdim, diff --git a/aten/src/ATen/native/mps/operations/Sort.mm b/aten/src/ATen/native/mps/operations/Sort.mm index b6a07f14704..898acacdb76 100644 --- a/aten/src/ATen/native/mps/operations/Sort.mm +++ b/aten/src/ATen/native/mps/operations/Sort.mm @@ -31,6 +31,7 @@ void kthvalue_out_mps_impl(const Tensor& self, int64_t k, int64_t dim, Tensor& v indices.copy_(values.toType(at::ScalarType::Long)); return; } + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()), "kthvalue is not implemented for complex types"); // issue #154890, raising error to prevent crash within MPSGraph until // workaround is implemented. TORCH_CHECK(self.dim() - dim <= 4, "On-going issue on MPSGraph topk when ndims() - axis > 4, see issue #154890");