pytorch/test/bench_mps_ops.py
Nikita Shulga ee97299961 [MPS][Testing] Benchmark reduction ops (#150452)
It compares eager execution vs torch.compile.
On my M4 Pro mini I'm getting the following numbers now:
```
[---------------------------------------------------------------------------------------------  --------------------------------------------------------------------------------------------]
                           |  eager-512x512  |  compile-512x512  |  eager-1024x1024  |  compile-1024x1024  |  eager-2048x2048  |  compile-2048x2048  |  eager-4096x4096  |  compile-4096x4096
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      sum (torch.float32)  |      121.0      |       201.5       |       130.3       |        772.3        |       179.4       |        1470.5       |        476.1      |        2980.0
      max (torch.float32)  |      154.1      |       165.9       |       198.7       |        211.6        |       344.2       |         386.9       |       1326.6      |        1345.6
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150452
Approved by: https://github.com/dcci, https://github.com/manuelcandales
2025-04-02 01:06:27 +00:00
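The script below is self-contained; assuming a PyTorch build with MPS support and a checkout of the repository, it can be run directly (invocation shown as an illustration):
```
python test/bench_mps_ops.py
```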

# Owner(s): ["module: mps"]
# Collection of op level benchmarks for MPS
# Useful as reference tool when migrating ops from MPS to Metal
import itertools
import timeit
import warnings
from typing import Optional

import torch
from torch.utils.benchmark import Compare, Measurement, Timer


def bench_unary_op(func, x, label) -> Measurement:
    # Time a single unary call; append an MPS sync so the async GPU work is included.
    sync_cmd = "torch.mps.synchronize()" if "mps" in str(x.device) else ""
    t = Timer(
        stmt=f"f(x);{sync_cmd}",
        globals={"f": func, "x": x},
        language="python",
        timer=timeit.default_timer,
        sub_label=f"{func.__name__} ({str(x.dtype)})",
        description=label,
        env=torch.__version__,
    )
    return t.blocked_autorange()


def bench_binary_op(func, x, y, label) -> Measurement:
    # Same as bench_unary_op, but times a binary call on (x, y).
    sync_cmd = "torch.mps.synchronize()" if "mps" in str(x.device) else ""
    t = Timer(
        stmt=f"f(x, y);{sync_cmd}",
        globals={"f": func, "x": x, "y": y},
        language="python",
        timer=timeit.default_timer,
        sub_label=f"{func.__name__} ({str(x.dtype)}, {str(y.dtype)})",
        description=label,
        env=torch.__version__,
    )
    return t.blocked_autorange()


def bench_unary(
    unary_func, device: str = "mps", dtype: torch.dtype = torch.float32
) -> list[Measurement]:
    # Benchmark a unary op over dense, transposed, strided, and strided+transposed inputs.
    x = torch.testing.make_tensor(1024, 1024, device=device, dtype=dtype)
    x_s = torch.testing.make_tensor(1024, 2048, device=device, dtype=dtype)[::, ::2]
    rc = []
    rc.append(bench_unary_op(unary_func, x, "dense"))
    rc.append(bench_unary_op(unary_func, x.t(), "transposed"))
    rc.append(bench_unary_op(unary_func, x_s, "strided"))
    rc.append(bench_unary_op(unary_func, x_s.t(), "strided + transposed"))
    return rc


def bench_binary(
    binary_func,
    device: str = "mps",
    dt_a: torch.dtype = torch.float32,
    dt_b: Optional[torch.dtype] = None,
) -> list[Measurement]:
    # Benchmark a binary op over layout and broadcasting combinations of its inputs.
    dt_b = dt_b if dt_b is not None else dt_a
    x = torch.testing.make_tensor(1024, 1024, device=device, dtype=dt_a)
    y = torch.testing.make_tensor(1024, 1024, device=device, dtype=dt_b)
    s = torch.testing.make_tensor((), device=device, dtype=dt_b)
    rc = []
    rc.append(bench_binary_op(binary_func, x, y, "dense-dense"))
    rc.append(bench_binary_op(binary_func, x.t(), y.t(), "transp-transp"))
    rc.append(bench_binary_op(binary_func, x, y.t(), "dense-transp"))
    rc.append(bench_binary_op(binary_func, x.t(), y, "transp-dense"))
    rc.append(bench_binary_op(binary_func, x, s, "dense-scalar"))
    rc.append(bench_binary_op(binary_func, x, y[0], "dense-bcast"))
    return rc


def bench_reduction(
    reduction_func, device: str = "mps", dtype: torch.dtype = torch.float32
) -> list[Measurement]:
    # Benchmark a reduction op in eager and compiled mode over several square sizes,
    # warning if the two paths disagree numerically.
    rc = []

    # Bench 2D with reduction over dim=0
    def f(t):
        return reduction_func(t, dim=0)

    f.__name__ = reduction_func.__name__
    f_c = torch.compile(f, dynamic=False)

    for size in (512, 1024, 2048, 4096):
        x = torch.testing.make_tensor(size, size, device=device, dtype=dtype)
        r_eager, r_comp = f(x), f_c(x)
        if isinstance(r_eager, tuple):
            # Ops like max return (values, indices); compare the values only.
            r_eager, r_comp = r_eager[0], r_comp[0]
        if not torch.allclose(r_eager, r_comp):
            mdiff = (r_eager - r_comp).abs().max()
            warnings.warn(
                f"Eager and compile reduction do not match for {reduction_func.__name__} and {dtype} max_diff={mdiff}",
                stacklevel=2,
            )
        rc.append(bench_unary_op(f, x, f"eager-{size}x{size}"))
        rc.append(bench_unary_op(f_c, x, f"compile-{size}x{size}"))
    return rc


def main() -> None:
    dtypes = [torch.float16, torch.float32]
    if torch.backends.mps.is_macos_or_newer(14, 0):
        dtypes.append(torch.bfloat16)

    # Profile unary ops
    rc = []
    for op, dtype in itertools.product([torch.sqrt, torch.sin], dtypes):
        rc.extend(bench_unary(op, dtype=dtype))
    Compare(rc).print()

    # Profile reduction ops
    rc = []
    for op in [torch.sum, torch.max]:
        rc.extend(bench_reduction(op))
    Compare(rc).print()

    # Profile binary ops
    rc = []
    ops = [torch.fmax, torch.add]
    for op, dtype in itertools.product(ops, dtypes):
        rc.extend(bench_binary(op, dt_a=dtype))
        if dtype == torch.float32:
            rc.extend(bench_binary(op, dt_b=torch.float16))
    Compare(rc).print()


if __name__ == "__main__":
    main()
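
As an illustration of how the helpers compose, here is a hypothetical extension (not part of the file) that reuses `bench_reduction` and `Compare` to benchmark an additional reduction such as `torch.mean`:
```python
# Hypothetical usage sketch, not part of bench_mps_ops.py:
# reuse the script's helpers to benchmark another reduction op.
import torch
from torch.utils.benchmark import Compare

from bench_mps_ops import bench_reduction  # assumes the script is importable from the test directory

results = []
for dtype in (torch.float16, torch.float32):
    results.extend(bench_reduction(torch.mean, dtype=dtype))
Compare(results).print()
```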