mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add error message with assert to topK if ndims() - dim > 4 (#155475)
Addressing #154890 Not really a proper fix but at least it's more informative than the current crash. For a more long term solution I'm testing if we can use the TopK API released in MacOS14 as it does not have the same MPSScan op issue that the Sort and ArgSort are hitting. Pull Request resolved: https://github.com/pytorch/pytorch/pull/155475 Approved by: https://github.com/kulinseth
This commit is contained in:
parent
049dc48d1e
commit
20a74c370b
|
|
@ -87,6 +87,10 @@ TORCH_IMPL_FUNC(topk_out_mps)
|
|||
return;
|
||||
}
|
||||
|
||||
// issue #154890, raising error to prevent crash within MPSGraph until
|
||||
// workaround is implemented.
|
||||
TORCH_CHECK(self.dim() - dim <= 4, "On-going issue on MPSGraph topk when ndims() - axis > 4, see issue #154890");
|
||||
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
|
|
|
|||
|
|
@ -8415,6 +8415,14 @@ class TestTopK(TestCase):
|
|||
with self.subTest(shape=shape, largest_val=largest_val):
|
||||
self._test_topk(shape, largest_val)
|
||||
|
||||
def test_topk_gt_4d(self):
|
||||
a = torch.ones(5, 4, 3, 2, 1, dtype=torch.float).to('mps')
|
||||
try:
|
||||
t_mps = torch.ops.aten.topk(a, k=5, dim=0)
|
||||
except Exception as e:
|
||||
e_string = str(e)
|
||||
self.assertEqual(e_string, "On-going issue on MPSGraph topk when ndims() - axis > 4, see issue #154890")
|
||||
|
||||
class TestNNMPS(NNTestCase):
|
||||
|
||||
def _create_basic_net(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user