mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPSInductor][BE] Only include headers when needed (#152266)
- Store headers used by the shader in `MetalKernel.headers`.
- Add a header when a function depending on it gets invoked.
- Generate the majority of the special ops from a template.
- Delete two unused functors, `entr` and `xlog1py`, as they are decomposed by inductor anyway.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152266
Approved by: https://github.com/Skylion007, https://github.com/jansel, https://github.com/dcci, https://github.com/cyyever
This commit is contained in:
parent
a0d440a26a
commit
cbcc03c2ad
|
|
@ -2,6 +2,7 @@
|
|||
# Just an early prototype that shows that one can compile elementwise ops into a Metal shader
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
|
|
@ -11,6 +12,7 @@ import sympy
|
|||
from sympy.printing.precedence import PRECEDENCE
|
||||
|
||||
import torch
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._sympy.printers import ExprPrinter as ExprPrinter_
|
||||
from torch.utils._sympy.value_ranges import ValueRanges
|
||||
|
||||
|
|
@ -274,42 +276,6 @@ class MetalOverrides(OpOverrides):
|
|||
def cos(x: CSEVariable) -> str:
|
||||
return f"metal::precise::cos({x})"
|
||||
|
||||
@staticmethod
|
||||
def i0(x: CSEVariable) -> str:
|
||||
return f"c10::metal::i0({x})"
|
||||
|
||||
@staticmethod
|
||||
def i0e(x: CSEVariable) -> str:
|
||||
return f"c10::metal::i0e({x})"
|
||||
|
||||
@staticmethod
|
||||
def i1(x: CSEVariable) -> str:
|
||||
return f"c10::metal::i1({x})"
|
||||
|
||||
@staticmethod
|
||||
def i1e(x: CSEVariable) -> str:
|
||||
return f"c10::metal::i1e({x})"
|
||||
|
||||
@staticmethod
|
||||
def erf(x: CSEVariable) -> str:
|
||||
return f"c10::metal::erf({x})"
|
||||
|
||||
@staticmethod
|
||||
def erfinv(x: CSEVariable) -> str:
|
||||
return f"c10::metal::erfinv({x})"
|
||||
|
||||
@staticmethod
|
||||
def lgamma(x: CSEVariable) -> str:
|
||||
return f"c10::metal::log_gamma({x})"
|
||||
|
||||
@staticmethod
|
||||
def polygamma(x: CSEVariable, y: CSEVariable) -> str:
|
||||
return f"c10::metal::polygamma({x}, {y})"
|
||||
|
||||
@staticmethod
|
||||
def digamma(x: CSEVariable) -> str:
|
||||
return f"c10::metal::digamma({x})"
|
||||
|
||||
@staticmethod
|
||||
def tan(x: CSEVariable) -> str:
|
||||
return f"metal::tan({x})"
|
||||
|
|
@ -391,16 +357,19 @@ class MetalOverrides(OpOverrides):
|
|||
|
||||
@staticmethod
|
||||
def rand(seed: CSEVariable, offset: CSEVariable) -> str:
|
||||
V.kernel.headers.add("random")
|
||||
return f"c10::metal::rand({seed}, {offset})"
|
||||
|
||||
@staticmethod
|
||||
def randn(seed: CSEVariable, offset: CSEVariable) -> str:
|
||||
V.kernel.headers.add("random")
|
||||
return f"c10::metal::randn({seed}, {offset})"
|
||||
|
||||
@staticmethod
|
||||
def randint64(
|
||||
seed: CSEVariable, offset: CSEVariable, low: CSEVariable, high: CSEVariable
|
||||
) -> str:
|
||||
V.kernel.headers.add("random")
|
||||
return f"c10::metal::randint64({seed}, {offset}, {low}, {high})"
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -413,88 +382,75 @@ class MetalOverrides(OpOverrides):
|
|||
cast_b = f"static_cast<decltype({a}+{b})>({b})"
|
||||
return f"metal::pow({cast_a}, {cast_b})"
|
||||
|
||||
@staticmethod
|
||||
def zeta(a: CSEVariable, b: CSEVariable) -> str:
|
||||
return f"c10::metal::zeta({a}, {b})"
|
||||
def _special_unary(self, a: CSEVariable, name: str) -> str:
|
||||
V.kernel.headers.add("special_math")
|
||||
return f"c10::metal::{name}({a})"
|
||||
|
||||
@staticmethod
|
||||
def spherical_bessel_j0(x: CSEVariable) -> str:
|
||||
return f"c10::metal::spherical_bessel_j0({x})"
|
||||
def _special_binary(self, a: CSEVariable, b: CSEVariable, name: str) -> str:
|
||||
V.kernel.headers.add("special_math")
|
||||
return f"c10::metal::{name}({a}, {b})"
|
||||
|
||||
@staticmethod
|
||||
def xlog1py(x: CSEVariable) -> str:
|
||||
return f"c10::metal::xlog1py({x})"
|
||||
@classmethod
|
||||
def _initialize_special_ops(cls) -> None:
|
||||
# Unary special ops
|
||||
for name in [
|
||||
"erf",
|
||||
"erfinv",
|
||||
"i0",
|
||||
"i0e",
|
||||
"i1",
|
||||
"i1e",
|
||||
"digamma",
|
||||
"spherical_bessel_j0",
|
||||
]:
|
||||
setattr(cls, name, functools.partialmethod(cls._special_unary, name=name))
|
||||
|
||||
@staticmethod
|
||||
def entr(x: CSEVariable) -> str:
|
||||
return f"c10::metal::entr({x})"
|
||||
cls.lgamma = functools.partialmethod(cls._special_unary, name="log_gamma") # type: ignore[assignment]
|
||||
|
||||
@staticmethod
|
||||
def bessel_j0(x: CSEVariable) -> str:
|
||||
return f"c10::metal::bessel_j0_forward({x})"
|
||||
# Unary special ops with forward in method name
|
||||
for name in [
|
||||
"bessel_j0",
|
||||
"bessel_j1",
|
||||
"bessel_y0",
|
||||
"bessel_y1",
|
||||
"modified_bessel_i0",
|
||||
"modified_bessel_i1",
|
||||
"modified_bessel_k0",
|
||||
"modified_bessel_k1",
|
||||
"scaled_modified_bessel_k0",
|
||||
"scaled_modified_bessel_k1",
|
||||
]:
|
||||
setattr(
|
||||
cls,
|
||||
name,
|
||||
functools.partialmethod(cls._special_unary, name=name + "_forward"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def bessel_j1(x: CSEVariable) -> str:
|
||||
return f"c10::metal::bessel_j1_forward({x})"
|
||||
# Binary special ops
|
||||
for name in [
|
||||
"polygamma",
|
||||
"zeta",
|
||||
]:
|
||||
setattr(cls, name, functools.partialmethod(cls._special_binary, name=name))
|
||||
|
||||
@staticmethod
|
||||
def bessel_y0(x: CSEVariable) -> str:
|
||||
return f"c10::metal::bessel_y0_forward({x})"
|
||||
|
||||
@staticmethod
|
||||
def bessel_y1(x: CSEVariable) -> str:
|
||||
return f"c10::metal::bessel_y1_forward({x})"
|
||||
|
||||
@staticmethod
|
||||
def modified_bessel_i0(x: CSEVariable) -> str:
|
||||
return f"c10::metal::modified_bessel_i0_forward({x})"
|
||||
|
||||
@staticmethod
|
||||
def modified_bessel_i1(x: CSEVariable) -> str:
|
||||
return f"c10::metal::modified_bessel_i1_forward({x})"
|
||||
|
||||
@staticmethod
|
||||
def modified_bessel_k0(x: CSEVariable) -> str:
|
||||
return f"c10::metal::modified_bessel_k0_forward({x})"
|
||||
|
||||
@staticmethod
|
||||
def modified_bessel_k1(x: CSEVariable) -> str:
|
||||
return f"c10::metal::modified_bessel_k1_forward({x})"
|
||||
|
||||
@staticmethod
|
||||
def scaled_modified_bessel_k0(x: CSEVariable) -> str:
|
||||
return f"c10::metal::scaled_modified_bessel_k0_forward({x})"
|
||||
|
||||
@staticmethod
|
||||
def scaled_modified_bessel_k1(x: CSEVariable) -> str:
|
||||
return f"c10::metal::scaled_modified_bessel_k1_forward({x})"
|
||||
|
||||
@staticmethod
|
||||
def chebyshev_polynomial_t(x: CSEVariable, n: CSEVariable) -> str:
|
||||
return f"c10::metal::chebyshev_polynomial_t_forward({x}, {n})"
|
||||
|
||||
@staticmethod
|
||||
def chebyshev_polynomial_u(x: CSEVariable, n: CSEVariable) -> str:
|
||||
return f"c10::metal::chebyshev_polynomial_u_forward({x}, {n})"
|
||||
|
||||
@staticmethod
|
||||
def chebyshev_polynomial_v(x: CSEVariable, n: CSEVariable) -> str:
|
||||
return f"c10::metal::chebyshev_polynomial_v_forward({x}, {n})"
|
||||
|
||||
@staticmethod
|
||||
def chebyshev_polynomial_w(x: CSEVariable, n: CSEVariable) -> str:
|
||||
return f"c10::metal::chebyshev_polynomial_w_forward({x}, {n})"
|
||||
|
||||
@staticmethod
|
||||
def hermite_polynomial_h(x: CSEVariable, n: CSEVariable) -> str:
|
||||
return f"c10::metal::hermite_polynomial_h_forward({x}, {n})"
|
||||
|
||||
@staticmethod
|
||||
def hermite_polynomial_he(x: CSEVariable, n: CSEVariable) -> str:
|
||||
return f"c10::metal::hermite_polynomial_he_forward({x}, {n})"
|
||||
# Binary special ops with forward in method name
|
||||
for name in [
|
||||
"chebyshev_polynomial_t",
|
||||
"chebyshev_polynomial_u",
|
||||
"chebyshev_polynomial_v",
|
||||
"chebyshev_polynomial_w",
|
||||
"hermite_polynomial_h",
|
||||
"hermite_polynomial_he",
|
||||
]:
|
||||
setattr(
|
||||
cls,
|
||||
name,
|
||||
functools.partialmethod(cls._special_binary, name=name + "_forward"),
|
||||
)
|
||||
|
||||
|
||||
MetalOverrides._initialize_pointwise_overrides("mps")
|
||||
MetalOverrides._initialize_special_ops()
|
||||
|
||||
|
||||
class MetalKernel(SIMDKernel):
|
||||
|
|
@ -508,6 +464,7 @@ class MetalKernel(SIMDKernel):
|
|||
pexpr = PythonPrinter().doprint
|
||||
sexpr = MetalExprPrinter().doprint
|
||||
kexpr = sexpr
|
||||
headers: OrderedSet[str] = OrderedSet(["utils"])
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -545,6 +502,7 @@ class MetalKernel(SIMDKernel):
|
|||
if mode is None:
|
||||
line = f"{var}[{self.index_to_str(index)}] = {cast_val};"
|
||||
elif mode == "atomic_add":
|
||||
self.headers.add("atomic")
|
||||
atomic_type = f"c10::metal::AtomicType<{dtype_str}>"
|
||||
cast_var = f"reinterpret_cast<device {atomic_type}::type *>({var})"
|
||||
line = f"{atomic_type}::atomic_add({cast_var}, {self.index_to_str(index)}, {cast_val});"
|
||||
|
|
@ -632,6 +590,9 @@ class MetalKernel(SIMDKernel):
|
|||
"threadgroup_barrier(metal::mem_flags::mem_threadgroup);"
|
||||
)
|
||||
return acc
|
||||
|
||||
self.headers.add("reduction_utils")
|
||||
|
||||
if reduction_type in ["prod", "sum"]:
|
||||
acc_dtype = DTYPE_TO_COMPUTATION_DTYPE[src_dtype]
|
||||
acc_buf = self._new_idxvar(
|
||||
|
|
@ -803,17 +764,8 @@ class MetalKernel(SIMDKernel):
|
|||
code.writeline('compile_mps_shader("""')
|
||||
idx_vars = self.active_range_trees()
|
||||
with code.indent():
|
||||
code.splice(
|
||||
"""
|
||||
#include <c10/metal/atomic.h>
|
||||
#include <c10/metal/random.h>
|
||||
#include <c10/metal/special_math.h>
|
||||
#include <c10/metal/utils.h>
|
||||
""",
|
||||
strip=True,
|
||||
)
|
||||
if self.inside_reduction:
|
||||
code.writeline("#include <c10/metal/reduction_utils.h>")
|
||||
for header in self.headers:
|
||||
code.writeline(f"#include <c10/metal/{header}.h>")
|
||||
if self.inside_reduction:
|
||||
total_reduction_size = math.prod(
|
||||
t.numel for t in self.range_trees if t.is_reduction
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user