[MPSInductor] Speedup maximum/minumum ops (#144581)

By relying on the fact that if either `a` or `b` is NaN (or both), than `a + b` would also be NaN.

I.e. it replaces
```metal
auto tmp2 = metal::any(metal::isnan(static_cast<decltype(tmp0+tmp1)>(tmp0))) | metal::any(metal::isnan(static_cast<decltype(tmp0+tmp1)>(tmp1))) ? static_cast<decltype(tmp0+tmp1)>(NAN) : metal::max(static_cast<decltype(tmp0+tmp1)>(tmp0), static_cast<decltype(tmp0+tmp1)>(tmp1));
```
with
```metal
auto tmp2 = metal::isnan(tmp0 + tmp1) ? tmp0 + tmp1 : metal::max(static_cast<decltype(tmp0+tmp1)>(tmp0), static_cast<decltype(tmp0+tmp1)>(tmp1));
```

which according to MetalProfiler takes fewer instructions:
<img width="520" alt="image" src="https://github.com/user-attachments/assets/54659392-012b-453e-9c02-c3c5f332074a" />
vs
<img width="1031" alt="image" src="https://github.com/user-attachments/assets/55fcfa78-1ea5-4b0a-8154-d79b3e3cc400" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144581
Approved by: https://github.com/dcci, https://github.com/jhavukainen
This commit is contained in:
Nikita Shulga 2025-01-10 22:58:00 +00:00 committed by PyTorch MergeBot
parent a94ec0a9a5
commit c7f12a4a7b

View File

@ -114,19 +114,15 @@ class MetalOverrides(OpOverrides):
def maximum(a: CSEVariable, b: CSEVariable) -> str:
typecast_a = f"static_cast<decltype({a}+{b})>({a})"
typecast_b = f"static_cast<decltype({a}+{b})>({b})"
nan_value = f"static_cast<decltype({a}+{b})>(NAN)"
nan_check = f"metal::any(metal::isnan({typecast_a})) | metal::any(metal::isnan({typecast_b}))"
max_res = f"metal::max({typecast_a}, {typecast_b})"
return f"{nan_check} ? {nan_value} : {max_res}"
return f"metal::isnan({a} + {b}) ? {a} + {b} : {max_res}"
@staticmethod
def minimum(a: CSEVariable, b: CSEVariable) -> str:
typecast_a = f"static_cast<decltype({a}+{b})>({a})"
typecast_b = f"static_cast<decltype({a}+{b})>({b})"
nan_value = f"static_cast<decltype({a}+{b})>(NAN)"
nan_check = f"metal::any(metal::isnan({typecast_a})) | metal::any(metal::isnan({typecast_b}))"
min_res = f"metal::min({typecast_a}, {typecast_b})"
return f"{nan_check} ? {nan_value} : {min_res}"
return f"metal::isnan({a} + {b}) ? {a} + {b} : {min_res}"
@staticmethod
def logical_or(a: CSEVariable, b: CSEVariable) -> str: