Alas, PythonPrinter would not work here, nor would CppPrinter, so start building MetalPrinter.

After this change, `pytest test/inductor/test_torchinductor.py -k _mps` reports 474 failed, 277 passed, 32 skipped; before this change it reported 506 failed, 245 passed, 32 skipped.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143973
Approved by: https://github.com/jansel
ghstack dependencies: #143948, #143949
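As a rough illustration of what this prototype enables (not part of the PR; the function, shapes, and device check are illustrative assumptions), compiling a simple elementwise function on an MPS device routes Inductor through the `MetalScheduling`/`MetalKernel` classes below, which emit a Metal shader and compile it via `torch.mps._compile_shader`:

```python
# Illustrative sketch only: exercises the prototype MPS Inductor backend below.
# Assumes a Metal-capable machine; f and the tensor shapes are made up for the example.
import torch

def f(x, y):
    return torch.sin(x) + y

if torch.backends.mps.is_available():
    x = torch.rand(1024, device="mps")
    y = torch.rand(1024, device="mps")
    out = torch.compile(f)(x, y)  # Inductor generates a Metal kernel for the fused elementwise op
    torch.testing.assert_close(out, f(x, y))
```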
# This is not a feature-complete compiler backend
# Just an early prototype that shows that one can compile elementwise ops into a Metal shader
from typing import Any, Optional

import sympy

import torch
from torch.utils._sympy.printers import ExprPrinter as ExprPrinter_

from ..ops_handler import StoreMode
from ..scheduler import SchedulerNode
from ..utils import get_kernel_metadata
from ..virtualized import V
from .common import CSEVariable, DeferredLine, IndentedBuffer, OpOverrides
from .simd import IterationRangesEntry, SIMDKernel, SIMDScheduling

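# Maps torch dtypes to their Metal Shading Language scalar type names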
DTYPE_TO_METAL = {
    torch.bool: "bool",
    torch.int8: "char",
    torch.int16: "short",
    torch.int32: "int",
    torch.int64: "long",
    torch.uint8: "uchar",
    torch.float: "float",
    torch.half: "half",
    torch.bfloat16: "bfloat",
}

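# Prints sympy index expressions in Metal syntax; per the commit message above,
# neither PythonPrinter nor CppPrinter emits valid Metal for FloorDiv/ModularIndexing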
class MetalExprPrinter(ExprPrinter_):
    def _print_FloorDiv(self, expr: sympy.Expr) -> str:
        x, div = expr.args
        x = self.doprint(x)
        div = self.doprint(div)
        if expr.is_integer:
            return f"({x}) / ({div})"
        return f"metal::floor({x}) / ({div})"

    def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
        x, div, mod = expr.args
        x = self.doprint(x)
        if div != 1:
            div = self.doprint(div)
            if expr.is_integer:
                x = f"({x}) / ({div})"
            else:
                x = f"metal::floor({x}) / ({div})"
        mod = self.doprint(mod)
        return f"({x}) % ({mod})"

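# Renders elementwise ops as Metal Shading Language expressions
# (metal:: math functions, ternary select, static_cast)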
class MetalOverrides(OpOverrides):
    @staticmethod
    def to_dtype(
        x: CSEVariable,
        dtype: torch.dtype,
        src_dtype: Optional[torch.dtype] = None,
        use_compute_types: bool = True,
    ) -> str:
        return f"static_cast<{DTYPE_TO_METAL[dtype]}>({x})"

    @staticmethod
    def where(a: CSEVariable, b: CSEVariable, c: CSEVariable) -> str:
        return f"{a} ? {b} : {c}"

    @staticmethod
    def logical_or(a: CSEVariable, b: CSEVariable) -> str:
        return f"{a} | {b}"

    @staticmethod
    def logical_and(a: CSEVariable, b: CSEVariable) -> str:
        return f"{a} & {b}"

    @staticmethod
    def abs(x: CSEVariable) -> str:
        return f"metal::abs({x})"

    @staticmethod
    def sin(x: CSEVariable) -> str:
        return f"metal::precise::sin({x})"

    @staticmethod
    def cos(x: CSEVariable) -> str:
        return f"metal::precise::cos({x})"

    @staticmethod
    def tan(x: CSEVariable) -> str:
        return f"metal::tan({x})"

    @staticmethod
    def asin(x: CSEVariable) -> str:
        return f"metal::asin({x})"

    @staticmethod
    def acos(x: CSEVariable) -> str:
        return f"metal::acos({x})"

    @staticmethod
    def atan(x: CSEVariable) -> str:
        return f"metal::atan({x})"

    @staticmethod
    def sqrt(x: CSEVariable) -> str:
        return f"metal::sqrt({x})"

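# Emits a single Metal compute kernel; loads, compute, and stores all share one body buffer,
# and codegen_kernel() wraps the result in a torch.mps._compile_shader(...) call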
class MetalKernel(SIMDKernel):
    overrides = MetalOverrides  # type: ignore[assignment]
    suffix = ";"
    newvar_prefix = "auto "
    sexpr = MetalExprPrinter().doprint

    def __init__(
        self,
        tiling: dict[str, sympy.Expr],
        **kwargs: Any,
    ) -> None:
        super().__init__(tiling, **kwargs)
        self.compute = self.body
        self.loads = self.body
        self.stores = self.body

    def dtype_to_str(self, dtype: torch.dtype) -> str:
        return DTYPE_TO_METAL[dtype]

    def load(self, name: str, index: sympy.Expr) -> CSEVariable:
        """Codegen a load from an InputBuffer"""
        var = self.args.input(name)
        index = self.prepare_indexing(index)
        line = f"{var}[{index}]"
        return self.cse.generate(self.body, line)

    def store(
        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
    ) -> None:
        var = self.args.output(name)
        index = self.prepare_indexing(index)
        dtype_str = self.dtype_to_str(V.graph.get_dtype(name))
        line = f"{var}[{index}] = static_cast<{dtype_str}>({value});"
        self.body.writeline(DeferredLine(name, line))

    def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry) -> None:
        index_expr = self.rename_indexing(entry.expr)
        index_str = self.sexpr(index_expr)  # type: ignore[misc]
        self.body.writeline(f"{self.index_dtype} {entry.name} = {index_str};")

    def codegen_kernel(self, name: Optional[str] = None) -> str:
        """Called at the end to generate a final kernel string"""
        code = IndentedBuffer()
        code.writeline('torch.mps._compile_shader("""')
        with code.indent():
            code.writeline("kernel void kernel_0(")
            with code.indent():
                for outer, inner in self.args.output_buffers.items():
                    if outer in self.removed_buffers:
                        continue
                    dtype_str = self.dtype_to_str(V.graph.get_dtype(outer))
                    code.writeline(f"device {dtype_str}* {inner},")
                for outer, inner in self.args.input_buffers.items():
                    dtype_str = self.dtype_to_str(V.graph.get_dtype(outer))
                    code.writeline(f"constant {dtype_str}* {inner},")
                code.writeline("uint xindex [[thread_position_in_grid]]")
            code.writeline(") {")
            with code.indent():
                code.splice(self.body)
            code.writeline("}")
        code.writeline('""")')

        return code.getvalue()

    def call_kernel(self, name: str, node: Any = None) -> None:
        """Codegen a call to this kernel"""
        wrapper = V.graph.wrapper_code
        args = [*self.args.output_buffers.keys(), *self.args.input_buffers.keys()]
        args = [arg for arg in args if arg not in self.removed_buffers]
        wrapper.generate_kernel_call(
            name,
            args,
            gpu=False,  # TODO: Fix me, MPS does not expose streams now
            triton=False,
        )

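# Scheduling backend that registers each generated shader with the wrapper code
# under the "mps_lib" library, reusing the kernel name when the source was already defined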
class MetalScheduling(SIMDScheduling):
    kernel_type = MetalKernel  # type: ignore[assignment]

    def define_kernel(
        self, src_code: str, node_schedule: list[SchedulerNode], kernel: MetalKernel
    ) -> str:
        wrapper = V.graph.wrapper_code
        if src_code in wrapper.src_to_kernel:
            kernel_name = wrapper.src_to_kernel[src_code]
        else:
            kernel_name = f"mps_lib.kernel_{wrapper.next_kernel_suffix()}"
            wrapper.src_to_kernel[src_code] = kernel_name
            origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
            metadata_comment = f"{origins}\n{detailed_origins}"
            wrapper.define_kernel("mps_lib", src_code, metadata_comment)

        return kernel_name