Add error message with assert to topK if ndims() - dim > 4 (#155475)

Addressing #154890

Not really a proper fix but at least it's more informative than the current crash.

For a more long term solution I'm testing if we can use the TopK API released in MacOS14 as it does not have the same MPSScan op issue that the Sort and ArgSort are hitting.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155475
Approved by: https://github.com/kulinseth
This commit is contained in:
Joona Havukainen 2025-06-13 21:10:00 +00:00 committed by PyTorch MergeBot
parent 049dc48d1e
commit 20a74c370b
2 changed files with 12 additions and 0 deletions

View File

@ -87,6 +87,10 @@ TORCH_IMPL_FUNC(topk_out_mps)
return;
}
// issue #154890, raising error to prevent crash within MPSGraph until
// workaround is implemented.
TORCH_CHECK(self.dim() - dim <= 4, "On-going issue on MPSGraph topk when ndims() - axis > 4, see issue #154890");
MPSStream* stream = getCurrentMPSStream();
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}

View File

@ -8415,6 +8415,14 @@ class TestTopK(TestCase):
with self.subTest(shape=shape, largest_val=largest_val):
self._test_topk(shape, largest_val)
def test_topk_gt_4d(self):
a = torch.ones(5, 4, 3, 2, 1, dtype=torch.float).to('mps')
try:
t_mps = torch.ops.aten.topk(a, k=5, dim=0)
except Exception as e:
e_string = str(e)
self.assertEqual(e_string, "On-going issue on MPSGraph topk when ndims() - axis > 4, see issue #154890")
class TestNNMPS(NNTestCase):
def _create_basic_net(self):