[MPSInductor][BE] Only include headers when needed (#152266)

Store headers used by shader in `MetalKernel.headers`
Add headers when function depending on it gets invoked
Generate majority of a special ops from template
Delete two unused functors: `entr` and `xlog1py` as they are decomposed by inductor anyway

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152266
Approved by: https://github.com/Skylion007, https://github.com/jansel, https://github.com/dcci, https://github.com/cyyever
This commit is contained in:
Nikita Shulga 2025-04-26 15:16:34 -07:00 committed by PyTorch MergeBot
parent a0d440a26a
commit cbcc03c2ad

View File

@ -2,6 +2,7 @@
# Just an early prototype that shows that one can compile elementwise ops into a Metal shader
from __future__ import annotations
import functools
import itertools
import logging
import math
@ -11,6 +12,7 @@ import sympy
from sympy.printing.precedence import PRECEDENCE
import torch
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.printers import ExprPrinter as ExprPrinter_
from torch.utils._sympy.value_ranges import ValueRanges
@ -274,42 +276,6 @@ class MetalOverrides(OpOverrides):
def cos(x: CSEVariable) -> str:
return f"metal::precise::cos({x})"
@staticmethod
def i0(x: CSEVariable) -> str:
return f"c10::metal::i0({x})"
@staticmethod
def i0e(x: CSEVariable) -> str:
return f"c10::metal::i0e({x})"
@staticmethod
def i1(x: CSEVariable) -> str:
return f"c10::metal::i1({x})"
@staticmethod
def i1e(x: CSEVariable) -> str:
return f"c10::metal::i1e({x})"
@staticmethod
def erf(x: CSEVariable) -> str:
return f"c10::metal::erf({x})"
@staticmethod
def erfinv(x: CSEVariable) -> str:
return f"c10::metal::erfinv({x})"
@staticmethod
def lgamma(x: CSEVariable) -> str:
return f"c10::metal::log_gamma({x})"
@staticmethod
def polygamma(x: CSEVariable, y: CSEVariable) -> str:
return f"c10::metal::polygamma({x}, {y})"
@staticmethod
def digamma(x: CSEVariable) -> str:
return f"c10::metal::digamma({x})"
@staticmethod
def tan(x: CSEVariable) -> str:
return f"metal::tan({x})"
@ -391,16 +357,19 @@ class MetalOverrides(OpOverrides):
@staticmethod
def rand(seed: CSEVariable, offset: CSEVariable) -> str:
V.kernel.headers.add("random")
return f"c10::metal::rand({seed}, {offset})"
@staticmethod
def randn(seed: CSEVariable, offset: CSEVariable) -> str:
V.kernel.headers.add("random")
return f"c10::metal::randn({seed}, {offset})"
@staticmethod
def randint64(
seed: CSEVariable, offset: CSEVariable, low: CSEVariable, high: CSEVariable
) -> str:
V.kernel.headers.add("random")
return f"c10::metal::randint64({seed}, {offset}, {low}, {high})"
@staticmethod
@ -413,88 +382,75 @@ class MetalOverrides(OpOverrides):
cast_b = f"static_cast<decltype({a}+{b})>({b})"
return f"metal::pow({cast_a}, {cast_b})"
@staticmethod
def zeta(a: CSEVariable, b: CSEVariable) -> str:
return f"c10::metal::zeta({a}, {b})"
def _special_unary(self, a: CSEVariable, name: str) -> str:
V.kernel.headers.add("special_math")
return f"c10::metal::{name}({a})"
@staticmethod
def spherical_bessel_j0(x: CSEVariable) -> str:
return f"c10::metal::spherical_bessel_j0({x})"
def _special_binary(self, a: CSEVariable, b: CSEVariable, name: str) -> str:
V.kernel.headers.add("special_math")
return f"c10::metal::{name}({a}, {b})"
@staticmethod
def xlog1py(x: CSEVariable) -> str:
return f"c10::metal::xlog1py({x})"
@classmethod
def _initialize_special_ops(cls) -> None:
# Unary special ops
for name in [
"erf",
"erfinv",
"i0",
"i0e",
"i1",
"i1e",
"digamma",
"spherical_bessel_j0",
]:
setattr(cls, name, functools.partialmethod(cls._special_unary, name=name))
@staticmethod
def entr(x: CSEVariable) -> str:
return f"c10::metal::entr({x})"
cls.lgamma = functools.partialmethod(cls._special_unary, name="log_gamma") # type: ignore[assignment]
@staticmethod
def bessel_j0(x: CSEVariable) -> str:
return f"c10::metal::bessel_j0_forward({x})"
# Unary special ops with forward in method name
for name in [
"bessel_j0",
"bessel_j1",
"bessel_y0",
"bessel_y1",
"modified_bessel_i0",
"modified_bessel_i1",
"modified_bessel_k0",
"modified_bessel_k1",
"scaled_modified_bessel_k0",
"scaled_modified_bessel_k1",
]:
setattr(
cls,
name,
functools.partialmethod(cls._special_unary, name=name + "_forward"),
)
@staticmethod
def bessel_j1(x: CSEVariable) -> str:
return f"c10::metal::bessel_j1_forward({x})"
# Binary special ops
for name in [
"polygamma",
"zeta",
]:
setattr(cls, name, functools.partialmethod(cls._special_binary, name=name))
@staticmethod
def bessel_y0(x: CSEVariable) -> str:
return f"c10::metal::bessel_y0_forward({x})"
@staticmethod
def bessel_y1(x: CSEVariable) -> str:
return f"c10::metal::bessel_y1_forward({x})"
@staticmethod
def modified_bessel_i0(x: CSEVariable) -> str:
return f"c10::metal::modified_bessel_i0_forward({x})"
@staticmethod
def modified_bessel_i1(x: CSEVariable) -> str:
return f"c10::metal::modified_bessel_i1_forward({x})"
@staticmethod
def modified_bessel_k0(x: CSEVariable) -> str:
return f"c10::metal::modified_bessel_k0_forward({x})"
@staticmethod
def modified_bessel_k1(x: CSEVariable) -> str:
return f"c10::metal::modified_bessel_k1_forward({x})"
@staticmethod
def scaled_modified_bessel_k0(x: CSEVariable) -> str:
return f"c10::metal::scaled_modified_bessel_k0_forward({x})"
@staticmethod
def scaled_modified_bessel_k1(x: CSEVariable) -> str:
return f"c10::metal::scaled_modified_bessel_k1_forward({x})"
@staticmethod
def chebyshev_polynomial_t(x: CSEVariable, n: CSEVariable) -> str:
return f"c10::metal::chebyshev_polynomial_t_forward({x}, {n})"
@staticmethod
def chebyshev_polynomial_u(x: CSEVariable, n: CSEVariable) -> str:
return f"c10::metal::chebyshev_polynomial_u_forward({x}, {n})"
@staticmethod
def chebyshev_polynomial_v(x: CSEVariable, n: CSEVariable) -> str:
return f"c10::metal::chebyshev_polynomial_v_forward({x}, {n})"
@staticmethod
def chebyshev_polynomial_w(x: CSEVariable, n: CSEVariable) -> str:
return f"c10::metal::chebyshev_polynomial_w_forward({x}, {n})"
@staticmethod
def hermite_polynomial_h(x: CSEVariable, n: CSEVariable) -> str:
return f"c10::metal::hermite_polynomial_h_forward({x}, {n})"
@staticmethod
def hermite_polynomial_he(x: CSEVariable, n: CSEVariable) -> str:
return f"c10::metal::hermite_polynomial_he_forward({x}, {n})"
# Binary special ops with forward in method name
for name in [
"chebyshev_polynomial_t",
"chebyshev_polynomial_u",
"chebyshev_polynomial_v",
"chebyshev_polynomial_w",
"hermite_polynomial_h",
"hermite_polynomial_he",
]:
setattr(
cls,
name,
functools.partialmethod(cls._special_binary, name=name + "_forward"),
)
MetalOverrides._initialize_pointwise_overrides("mps")
MetalOverrides._initialize_special_ops()
class MetalKernel(SIMDKernel):
@ -508,6 +464,7 @@ class MetalKernel(SIMDKernel):
pexpr = PythonPrinter().doprint
sexpr = MetalExprPrinter().doprint
kexpr = sexpr
headers: OrderedSet[str] = OrderedSet(["utils"])
def __init__(
self,
@ -545,6 +502,7 @@ class MetalKernel(SIMDKernel):
if mode is None:
line = f"{var}[{self.index_to_str(index)}] = {cast_val};"
elif mode == "atomic_add":
self.headers.add("atomic")
atomic_type = f"c10::metal::AtomicType<{dtype_str}>"
cast_var = f"reinterpret_cast<device {atomic_type}::type *>({var})"
line = f"{atomic_type}::atomic_add({cast_var}, {self.index_to_str(index)}, {cast_val});"
@ -632,6 +590,9 @@ class MetalKernel(SIMDKernel):
"threadgroup_barrier(metal::mem_flags::mem_threadgroup);"
)
return acc
self.headers.add("reduction_utils")
if reduction_type in ["prod", "sum"]:
acc_dtype = DTYPE_TO_COMPUTATION_DTYPE[src_dtype]
acc_buf = self._new_idxvar(
@ -803,17 +764,8 @@ class MetalKernel(SIMDKernel):
code.writeline('compile_mps_shader("""')
idx_vars = self.active_range_trees()
with code.indent():
code.splice(
"""
#include <c10/metal/atomic.h>
#include <c10/metal/random.h>
#include <c10/metal/special_math.h>
#include <c10/metal/utils.h>
""",
strip=True,
)
if self.inside_reduction:
code.writeline("#include <c10/metal/reduction_utils.h>")
for header in self.headers:
code.writeline(f"#include <c10/metal/{header}.h>")
if self.inside_reduction:
total_reduction_size = math.prod(
t.numel for t in self.range_trees if t.is_reduction