[MPSInductor] Implement argmax/argmin (#146429)

TODOs:
 - Find (or add) a test that exercises NaN inputs
 - Report the internal compiler error hit when running `test_argmax_argmin1` (the actual cause is not enough shared memory)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146429
Approved by: https://github.com/dcci
ghstack dependencies: #146423, #146428
This commit is contained in:
Nikita Shulga 2025-02-04 10:54:08 -08:00 committed by PyTorch MergeBot
parent c591ad0c03
commit 3525b834f0
3 changed files with 28 additions and 1 deletions

View File

@ -51,5 +51,31 @@ T threadgroup_min(threadgroup T* data, unsigned size) {
return rc;
}
// Returns the index of the largest element among data[0] .. data[size - 1].
//
// Only a strictly greater element displaces the current candidate, so when
// several elements compare equal to the maximum, the lowest index is returned.
// When size is 0 or 1 the loop body never runs and the result is 0.
//
// data: pointer to the values to scan (threadgroup address space).
// size: number of elements to scan.
template <typename T>
long threadgroup_argmax(threadgroup T* data, unsigned size) {
  // TODO: This should be moved to the callee
  // NOTE(review): "callee" above is kept from the original comment;
  // presumably "caller" is meant - confirm.
  ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);

  long best = 0;
  unsigned pos = 1;
  while (pos < size) {
    if (data[pos] > data[best]) {
      best = pos;
    }
    ++pos;
  }
  return best;
}
// Returns the index of the smallest element among data[0] .. data[size - 1].
//
// Only a strictly smaller element displaces the current candidate, so when
// several elements compare equal to the minimum, the lowest index is returned.
// When size is 0 or 1 the loop body never runs and the result is 0.
//
// The result is an index, so it is returned as `long` (matching
// threadgroup_argmax) rather than as the element type T: converting the index
// to T would truncate it for narrow element types and lose precision for
// low-precision floating-point ones.
//
// data: pointer to the values to scan (threadgroup address space).
// size: number of elements to scan.
template <typename T>
long threadgroup_argmin(threadgroup T* data, unsigned size) {
  // TODO: This should be moved to the callee
  // NOTE(review): "callee" above is kept from the original comment;
  // presumably "caller" is meant - confirm.
  ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
  long rc = 0;
  // Unsigned index avoids a signed/unsigned comparison against `size`.
  for (unsigned idx = 1; idx < size; ++idx) {
    if (data[idx] < data[rc]) {
      rc = idx;
    }
  }
  return rc;
}
} // namespace metal
} // namespace c10

View File

@ -123,6 +123,7 @@ for test_name in [
"test_any",
"test_arange5",
"test_argmax_min_int32",
"test_argmax_argmin2",
"test_avg_pool2d5",
"test_avg_pool2d8",
"test_builtins_round",

View File

@ -436,7 +436,7 @@ class MetalKernel(SIMDKernel):
f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {reduction_dim.numel})",
dtype=DTYPE_TO_COMPUTATION_DTYPE[dtype],
)
if reduction_type in ["max", "min"]:
if reduction_type in ["max", "min", "argmax", "argmin"]:
acc_buf = self._new_accvar(src_dtype, reduction_dim.numel)
self.body.splice(f"{acc_buf}[{reduction_dim.name}] = {value};")
return self.cse.generate(