From 43a09189c68fe02bd9d8433c4a144ffc9bbf895c Mon Sep 17 00:00:00 2001
From: Manuel Candales
Date: Wed, 25 Jun 2025 14:51:32 -0400
Subject: [PATCH] [MPS] Add benchmark for scan with indices (#156860)

Baseline performance on M4 Max 64GB (macOS 15.5):
```
[-------------------------------- --------------------------------]
                                              |  eager   |  compile
1 threads: ---------------------------------------------------------
      cummin-dim0-32x32 (torch.float16)       |   102.5  |    115.0
      cummin-dim0-128x128 (torch.float16)     |   133.6  |    147.8
      cummin-dim0-512x512 (torch.float16)     |   233.1  |    243.1
      cummin-dim0-1024x1024 (torch.float16)   |   364.2  |    385.2
      cummin-dim1-32x32 (torch.float16)       |    94.4  |    109.8
      cummin-dim1-128x128 (torch.float16)     |   109.9  |    122.5
      cummin-dim1-512x512 (torch.float16)     |   227.0  |    233.8
      cummin-dim1-1024x1024 (torch.float16)   |   985.1  |   1010.5
      cummin-1d-100 (torch.float16)           |   100.7  |    114.3
      cummin-1d-10000 (torch.float16)         |   805.0  |    879.1
      cummin-1d-1000000 (torch.float16)       | 70545.6  |  71310.3
      cummin-dim0-32x32 (torch.float32)       |   102.7  |    115.5
      cummin-dim0-128x128 (torch.float32)     |   137.2  |    143.8
      cummin-dim0-512x512 (torch.float32)     |   209.7  |    222.0
      cummin-dim0-1024x1024 (torch.float32)   |   340.1  |    389.9
      cummin-dim1-32x32 (torch.float32)       |    99.2  |    107.8
      cummin-dim1-128x128 (torch.float32)     |   111.9  |    119.3
      cummin-dim1-512x512 (torch.float32)     |   250.7  |    255.1
      cummin-dim1-1024x1024 (torch.float32)   |   987.9  |   1013.2
      cummin-1d-100 (torch.float32)           |   100.6  |    114.6
      cummin-1d-10000 (torch.float32)         |   794.7  |    862.2
      cummin-1d-1000000 (torch.float32)       | 71995.3  |  71963.5
      cummin-dim0-32x32 (torch.bfloat16)      |   105.9  |    113.9
      cummin-dim0-128x128 (torch.bfloat16)    |   135.7  |    147.9
      cummin-dim0-512x512 (torch.bfloat16)    |   231.9  |    240.7
      cummin-dim0-1024x1024 (torch.bfloat16)  |   327.7  |    366.9
      cummin-dim1-32x32 (torch.bfloat16)      |    91.3  |    103.3
      cummin-dim1-128x128 (torch.bfloat16)    |   108.5  |    117.4
      cummin-dim1-512x512 (torch.bfloat16)    |   222.0  |    233.6
      cummin-dim1-1024x1024 (torch.bfloat16)  |   936.9  |    982.5
      cummin-1d-100 (torch.bfloat16)          |   106.6
|    112.4
      cummin-1d-10000 (torch.bfloat16)        |   795.8  |    819.6
      cummin-1d-1000000 (torch.bfloat16)      | 68667.4  |  68557.9

Times are in microseconds (us).
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156860
Approved by: https://github.com/malfet
---
 test/bench_mps_ops.py | 49 ++++++++++++++++++++++++++----------------
 1 file changed, 30 insertions(+), 19 deletions(-)

diff --git a/test/bench_mps_ops.py b/test/bench_mps_ops.py
index 7bc2da54559..a7d647c120f 100644
--- a/test/bench_mps_ops.py
+++ b/test/bench_mps_ops.py
@@ -71,6 +71,15 @@ def bench_binary(
     return rc
 
 
+def check_eager_vs_compile(rc_c, rc_e, func, dtype):
+    if not torch.allclose(rc_c, rc_e):
+        mdiff = (rc_c - rc_e).abs().max()
+        warnings.warn(
+            f"Eager and compile reduction do not match for {func.__name__} and {dtype} max_diff={mdiff}",
+            stacklevel=2,
+        )
+
+
 def bench_reduction(
     reduction_func, device: str = "mps", dtype: torch.dtype = torch.float32
 ) -> list[Measurement]:
@@ -87,19 +96,17 @@ def bench_reduction(
         x = torch.testing.make_tensor(size, size, device=device, dtype=dtype)
         rc_c, rc_e = f(x), f_c(x)
         rc_c, rc_e = (rc_c[0], rc_e[0]) if isinstance(rc_c, tuple) else (rc_c, rc_e)
-        if not torch.allclose(rc_c, rc_e):
-            mdiff = (rc_c - rc_e).abs().max()
-            warnings.warn(
-                f"Eager and compile reduction do not match for {reduction_func.__name__} and {dtype} max_diff={mdiff}",
-                stacklevel=2,
-            )
+        check_eager_vs_compile(rc_c, rc_e, reduction_func, dtype)
         rc.append(bench_unary_op(f, x, f"eager-{size}x{size}"))
         rc.append(bench_unary_op(f_c, x, f"compile-{size}x{size}"))
     return rc
 
 
 def bench_scan(
-    scan_func, device: str = "mps", dtype: torch.dtype = torch.float32
+    scan_func,
+    device: str = "mps",
+    dtype: torch.dtype = torch.float32,
+    with_indices: bool = False,
 ) -> list[Measurement]:
     rc = []
@@ -116,12 +123,11 @@ def bench_scan(
             f_c.__name__ = f.__name__
             x = torch.testing.make_tensor(size, size, device=device, dtype=dtype)
             rc_c, rc_e = f(x), f_c(x)
-            if not torch.allclose(rc_c, rc_e):
-
mdiff = (rc_c - rc_e).abs().max()
-                warnings.warn(
-                    f"Eager and compile scan do not match for {scan_func.__name__} dim={dim} and {dtype} max_diff={mdiff}",
-                    stacklevel=2,
-                )
+            if with_indices:
+                check_eager_vs_compile(rc_c[0], rc_e[0], scan_func, dtype)
+                check_eager_vs_compile(rc_c[1], rc_e[1], scan_func, dtype)
+            else:
+                check_eager_vs_compile(rc_c, rc_e, scan_func, dtype)
             rc.append(bench_unary_op(f, x, "eager"))
             rc.append(bench_unary_op(f_c, x, "compile"))
@@ -136,12 +142,11 @@ def bench_scan(
         f_1d_c.__name__ = f_1d.__name__
         x = torch.testing.make_tensor(size, device=device, dtype=dtype)
         rc_c, rc_e = f_1d(x), f_1d_c(x)
-        if not torch.allclose(rc_c, rc_e):
-            mdiff = (rc_c - rc_e).abs().max()
-            warnings.warn(
-                f"Eager and compile 1D scan do not match for {scan_func.__name__} and {dtype} max_diff={mdiff}",
-                stacklevel=2,
-            )
+        if with_indices:
+            check_eager_vs_compile(rc_c[0], rc_e[0], scan_func, dtype)
+            check_eager_vs_compile(rc_c[1], rc_e[1], scan_func, dtype)
+        else:
+            check_eager_vs_compile(rc_c, rc_e, scan_func, dtype)
         rc.append(bench_unary_op(f_1d, x, "eager"))
         rc.append(bench_unary_op(f_1d_c, x, "compile"))
@@ -171,6 +176,12 @@ def main() -> None:
         rc.extend(bench_scan(torch.cumsum, dtype=dtype))
     Compare(rc).print()
 
+    # Profile scan with indices ops (cummin)
+    rc = []
+    for dtype in dtypes:
+        rc.extend(bench_scan(torch.cummin, dtype=dtype, with_indices=True))
+    Compare(rc).print()
+
     # Profile binary ops
     rc = []
     ops = [torch.fmax, torch.add]