mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPSInductor] Speed up maximum/minimum ops (#144581)
By relying on the fact that if either `a` or `b` is NaN (or both), then `a + b` is also NaN. I.e. it replaces

```metal
auto tmp2 = metal::any(metal::isnan(static_cast<decltype(tmp0+tmp1)>(tmp0))) | metal::any(metal::isnan(static_cast<decltype(tmp0+tmp1)>(tmp1))) ? static_cast<decltype(tmp0+tmp1)>(NAN) : metal::max(static_cast<decltype(tmp0+tmp1)>(tmp0), static_cast<decltype(tmp0+tmp1)>(tmp1));
```

with

```metal
auto tmp2 = metal::isnan(tmp0 + tmp1) ? tmp0 + tmp1 : metal::max(static_cast<decltype(tmp0+tmp1)>(tmp0), static_cast<decltype(tmp0+tmp1)>(tmp1));
```

which, according to MetalProfiler, takes fewer instructions:

<img width="520" alt="image" src="https://github.com/user-attachments/assets/54659392-012b-453e-9c02-c3c5f332074a" />

vs

<img width="1031" alt="image" src="https://github.com/user-attachments/assets/55fcfa78-1ea5-4b0a-8154-d79b3e3cc400" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144581
Approved by: https://github.com/dcci, https://github.com/jhavukainen
This commit is contained in:
parent
a94ec0a9a5
commit
c7f12a4a7b
|
|
@ -114,19 +114,15 @@ class MetalOverrides(OpOverrides):
|
|||
def maximum(a: CSEVariable, b: CSEVariable) -> str:
    """Build a Metal expression for the NaN-propagating maximum of ``a`` and ``b``.

    Both operands are cast to ``decltype(a+b)`` so that ``metal::max`` sees
    two arguments of the same type.

    The NaN test is done on the operands themselves via
    ``metal::isunordered`` rather than on their sum: ``isnan(a + b)`` is also
    true for ``+inf`` and ``-inf`` (whose sum is NaN although neither operand
    is), which would wrongly turn ``maximum(inf, -inf)`` into NaN.

    Args:
        a: First operand; interpolated into the expression via ``str()``.
        b: Second operand; interpolated into the expression via ``str()``.

    Returns:
        A Metal ternary expression string: ``a + b`` (which is NaN whenever
        either operand is NaN) if the operands are unordered, otherwise
        ``metal::max`` of the two casted operands.
    """
    # NOTE(review): like the isnan-based form, this presumably is only
    # emitted for floating-point operands -- confirm against callers.
    typecast_a = f"static_cast<decltype({a}+{b})>({a})"
    typecast_b = f"static_cast<decltype({a}+{b})>({b})"
    max_res = f"metal::max({typecast_a}, {typecast_b})"
    # Unordered <=> at least one operand is NaN; in that case a + b is NaN
    # of the promoted type, so it doubles as the NaN result.
    nan_check = f"metal::isunordered({typecast_a}, {typecast_b})"
    return f"{nan_check} ? {a} + {b} : {max_res}"
|
||||
|
||||
@staticmethod
|
||||
def minimum(a: CSEVariable, b: CSEVariable) -> str:
    """Build a Metal expression for the NaN-propagating minimum of ``a`` and ``b``.

    Both operands are cast to ``decltype(a+b)`` so that ``metal::min`` sees
    two arguments of the same type.

    The NaN test is done on the operands themselves via
    ``metal::isunordered`` rather than on their sum: ``isnan(a + b)`` is also
    true for ``+inf`` and ``-inf`` (whose sum is NaN although neither operand
    is), which would wrongly turn ``minimum(inf, -inf)`` into NaN.

    Args:
        a: First operand; interpolated into the expression via ``str()``.
        b: Second operand; interpolated into the expression via ``str()``.

    Returns:
        A Metal ternary expression string: ``a + b`` (which is NaN whenever
        either operand is NaN) if the operands are unordered, otherwise
        ``metal::min`` of the two casted operands.
    """
    # NOTE(review): like the isnan-based form, this presumably is only
    # emitted for floating-point operands -- confirm against callers.
    typecast_a = f"static_cast<decltype({a}+{b})>({a})"
    typecast_b = f"static_cast<decltype({a}+{b})>({b})"
    min_res = f"metal::min({typecast_a}, {typecast_b})"
    # Unordered <=> at least one operand is NaN; in that case a + b is NaN
    # of the promoted type, so it doubles as the NaN result.
    nan_check = f"metal::isunordered({typecast_a}, {typecast_b})"
    return f"{nan_check} ? {a} + {b} : {min_res}"
|
||||
|
||||
@staticmethod
|
||||
def logical_or(a: CSEVariable, b: CSEVariable) -> str:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user