[MPS] Add benchmark for scan with indices (#156860)

Baseline performance on M4 Max 64GB (macOS 15.5):
```
[--------------------------------  --------------------------------]
                                              |   eager   |  compile
1 threads: ---------------------------------------------------------
      cummin-dim0-32x32 (torch.float16)       |    102.5  |    115.0
      cummin-dim0-128x128 (torch.float16)     |    133.6  |    147.8
      cummin-dim0-512x512 (torch.float16)     |    233.1  |    243.1
      cummin-dim0-1024x1024 (torch.float16)   |    364.2  |    385.2
      cummin-dim1-32x32 (torch.float16)       |     94.4  |    109.8
      cummin-dim1-128x128 (torch.float16)     |    109.9  |    122.5
      cummin-dim1-512x512 (torch.float16)     |    227.0  |    233.8
      cummin-dim1-1024x1024 (torch.float16)   |    985.1  |   1010.5
      cummin-1d-100 (torch.float16)           |    100.7  |    114.3
      cummin-1d-10000 (torch.float16)         |    805.0  |    879.1
      cummin-1d-1000000 (torch.float16)       |  70545.6  |  71310.3
      cummin-dim0-32x32 (torch.float32)       |    102.7  |    115.5
      cummin-dim0-128x128 (torch.float32)     |    137.2  |    143.8
      cummin-dim0-512x512 (torch.float32)     |    209.7  |    222.0
      cummin-dim0-1024x1024 (torch.float32)   |    340.1  |    389.9
      cummin-dim1-32x32 (torch.float32)       |     99.2  |    107.8
      cummin-dim1-128x128 (torch.float32)     |    111.9  |    119.3
      cummin-dim1-512x512 (torch.float32)     |    250.7  |    255.1
      cummin-dim1-1024x1024 (torch.float32)   |    987.9  |   1013.2
      cummin-1d-100 (torch.float32)           |    100.6  |    114.6
      cummin-1d-10000 (torch.float32)         |    794.7  |    862.2
      cummin-1d-1000000 (torch.float32)       |  71995.3  |  71963.5
      cummin-dim0-32x32 (torch.bfloat16)      |    105.9  |    113.9
      cummin-dim0-128x128 (torch.bfloat16)    |    135.7  |    147.9
      cummin-dim0-512x512 (torch.bfloat16)    |    231.9  |    240.7
      cummin-dim0-1024x1024 (torch.bfloat16)  |    327.7  |    366.9
      cummin-dim1-32x32 (torch.bfloat16)      |     91.3  |    103.3
      cummin-dim1-128x128 (torch.bfloat16)    |    108.5  |    117.4
      cummin-dim1-512x512 (torch.bfloat16)    |    222.0  |    233.6
      cummin-dim1-1024x1024 (torch.bfloat16)  |    936.9  |    982.5
      cummin-1d-100 (torch.bfloat16)          |    106.6  |    112.4
      cummin-1d-10000 (torch.bfloat16)        |    795.8  |    819.6
      cummin-1d-1000000 (torch.bfloat16)      |  68667.4  |  68557.9

Times are in microseconds (us).
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156860
Approved by: https://github.com/malfet
This commit is contained in:
Manuel Candales 2025-06-25 14:51:32 -04:00 committed by PyTorch MergeBot
parent 9fe2d156a9
commit 43a09189c6

View File

@ -71,6 +71,15 @@ def bench_binary(
return rc
def check_eager_vs_compile(rc_c, rc_e, func, dtype):
    """Warn if the compiled and eager results of *func* diverge.

    Args:
        rc_c: result tensor produced by the compiled run.
        rc_e: result tensor produced by the eager run.
        func: the benchmarked callable; only its ``__name__`` is used in
            the warning text.
        dtype: the torch dtype under test, included in the warning text.
    """
    if not torch.allclose(rc_c, rc_e):
        mdiff = (rc_c - rc_e).abs().max()
        # NOTE: this helper is shared by the reduction and scan benchmarks,
        # so the message must not claim the mismatching op is a reduction.
        warnings.warn(
            f"Eager and compile results do not match for {func.__name__} and {dtype} max_diff={mdiff}",
            stacklevel=2,
        )
def bench_reduction(
reduction_func, device: str = "mps", dtype: torch.dtype = torch.float32
) -> list[Measurement]:
@ -87,19 +96,17 @@ def bench_reduction(
x = torch.testing.make_tensor(size, size, device=device, dtype=dtype)
rc_c, rc_e = f(x), f_c(x)
rc_c, rc_e = (rc_c[0], rc_e[0]) if isinstance(rc_c, tuple) else (rc_c, rc_e)
if not torch.allclose(rc_c, rc_e):
mdiff = (rc_c - rc_e).abs().max()
warnings.warn(
f"Eager and compile reduction do not match for {reduction_func.__name__} and {dtype} max_diff={mdiff}",
stacklevel=2,
)
check_eager_vs_compile(rc_c, rc_e, reduction_func, dtype)
rc.append(bench_unary_op(f, x, f"eager-{size}x{size}"))
rc.append(bench_unary_op(f_c, x, f"compile-{size}x{size}"))
return rc
def bench_scan(
scan_func, device: str = "mps", dtype: torch.dtype = torch.float32
scan_func,
device: str = "mps",
dtype: torch.dtype = torch.float32,
with_indices: bool = False,
) -> list[Measurement]:
rc = []
@ -116,12 +123,11 @@ def bench_scan(
f_c.__name__ = f.__name__
x = torch.testing.make_tensor(size, size, device=device, dtype=dtype)
rc_c, rc_e = f(x), f_c(x)
if not torch.allclose(rc_c, rc_e):
mdiff = (rc_c - rc_e).abs().max()
warnings.warn(
f"Eager and compile scan do not match for {scan_func.__name__} dim={dim} and {dtype} max_diff={mdiff}",
stacklevel=2,
)
if with_indices:
check_eager_vs_compile(rc_c[0], rc_e[0], scan_func, dtype)
check_eager_vs_compile(rc_c[1], rc_e[1], scan_func, dtype)
else:
check_eager_vs_compile(rc_c, rc_e, scan_func, dtype)
rc.append(bench_unary_op(f, x, "eager"))
rc.append(bench_unary_op(f_c, x, "compile"))
@ -136,12 +142,11 @@ def bench_scan(
f_1d_c.__name__ = f_1d.__name__
x = torch.testing.make_tensor(size, device=device, dtype=dtype)
rc_c, rc_e = f_1d(x), f_1d_c(x)
if not torch.allclose(rc_c, rc_e):
mdiff = (rc_c - rc_e).abs().max()
warnings.warn(
f"Eager and compile 1D scan do not match for {scan_func.__name__} and {dtype} max_diff={mdiff}",
stacklevel=2,
)
if with_indices:
check_eager_vs_compile(rc_c[0], rc_e[0], scan_func, dtype)
check_eager_vs_compile(rc_c[1], rc_e[1], scan_func, dtype)
else:
check_eager_vs_compile(rc_c, rc_e, scan_func, dtype)
rc.append(bench_unary_op(f_1d, x, "eager"))
rc.append(bench_unary_op(f_1d_c, x, "compile"))
@ -171,6 +176,12 @@ def main() -> None:
rc.extend(bench_scan(torch.cumsum, dtype=dtype))
Compare(rc).print()
# Profile scan with indices ops (cummin)
rc = []
for dtype in dtypes:
rc.extend(bench_scan(torch.cummin, dtype=dtype, with_indices=True))
Compare(rc).print()
# Profile binary ops
rc = []
ops = [torch.fmax, torch.add]