Fix meta impl for topk (#147017)

Topk in this context is always size-like so we should use torch._check_is_size. Fixes some issue in https://github.com/pytorch/pytorch/issues/146990

Differential Revision: [D69545983](https://our.internmc.facebook.com/intern/diff/D69545983)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147017
Approved by: https://github.com/ydwu4
This commit is contained in:
Tugsbayasgalan Manlaibaatar 2025-02-12 12:14:14 -08:00 committed by PyTorch MergeBot
parent 821422018a
commit c159723c39

View File

@ -6507,7 +6507,8 @@ def topk_meta(self, k, dim=-1, largest=True, sorted=True):
# From aten/src/ATen/native/Sorting.cpp
dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True)
sliceSize = 1 if self.dim() == 0 else self.size(dim)
torch._check(k >= 0 and k <= sliceSize, lambda: "k not in range for dimension")
torch._check_is_size(k)
torch._check(k <= sliceSize, lambda: "k not in range for dimension")
topKSize = list(self.shape)
if len(topKSize) > 0: