mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPSInductor] Implement argmax/argmin (#146429)
TODOs: (1) find a test with NaN; (2) report the internal compiler error hit when running `test_argmax_argmin1` (which is actually caused by not enough shared memory). Pull Request resolved: https://github.com/pytorch/pytorch/pull/146429 Approved by: https://github.com/dcci ghstack dependencies: #146423, #146428
This commit is contained in:
parent
c591ad0c03
commit
3525b834f0
|
|
@ -51,5 +51,31 @@ T threadgroup_min(threadgroup T* data, unsigned size) {
|
|||
return rc;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
long threadgroup_argmax(threadgroup T* data, unsigned size) {
|
||||
// TODO: This should be moved to the callee
|
||||
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
|
||||
long rc = 0;
|
||||
for (auto idx = 1; idx < size; ++idx) {
|
||||
if (data[idx] > data[rc]) {
|
||||
rc = idx;
|
||||
}
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T threadgroup_argmin(threadgroup T* data, unsigned size) {
|
||||
// TODO: This should be moved to the callee
|
||||
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
|
||||
long rc = 0;
|
||||
for (auto idx = 1; idx < size; ++idx) {
|
||||
if (data[idx] < data[rc]) {
|
||||
rc = idx;
|
||||
}
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
} // namespace c10
|
||||
|
|
|
|||
|
|
@ -123,6 +123,7 @@ for test_name in [
|
|||
"test_any",
|
||||
"test_arange5",
|
||||
"test_argmax_min_int32",
|
||||
"test_argmax_argmin2",
|
||||
"test_avg_pool2d5",
|
||||
"test_avg_pool2d8",
|
||||
"test_builtins_round",
|
||||
|
|
|
|||
|
|
@ -436,7 +436,7 @@ class MetalKernel(SIMDKernel):
|
|||
f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {reduction_dim.numel})",
|
||||
dtype=DTYPE_TO_COMPUTATION_DTYPE[dtype],
|
||||
)
|
||||
if reduction_type in ["max", "min"]:
|
||||
if reduction_type in ["max", "min", "argmax", "argmin"]:
|
||||
acc_buf = self._new_accvar(src_dtype, reduction_dim.numel)
|
||||
self.body.splice(f"{acc_buf}[{reduction_dim.name}] = {value};")
|
||||
return self.cse.generate(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user