mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPS] Add benchmark for scan with indices (#156860)
Baseline performance on M4 Max 64GB (macOS 15.5):
```
[-------------------------------- --------------------------------]
| eager | compile
1 threads: ---------------------------------------------------------
cummin-dim0-32x32 (torch.float16) | 102.5 | 115.0
cummin-dim0-128x128 (torch.float16) | 133.6 | 147.8
cummin-dim0-512x512 (torch.float16) | 233.1 | 243.1
cummin-dim0-1024x1024 (torch.float16) | 364.2 | 385.2
cummin-dim1-32x32 (torch.float16) | 94.4 | 109.8
cummin-dim1-128x128 (torch.float16) | 109.9 | 122.5
cummin-dim1-512x512 (torch.float16) | 227.0 | 233.8
cummin-dim1-1024x1024 (torch.float16) | 985.1 | 1010.5
cummin-1d-100 (torch.float16) | 100.7 | 114.3
cummin-1d-10000 (torch.float16) | 805.0 | 879.1
cummin-1d-1000000 (torch.float16) | 70545.6 | 71310.3
cummin-dim0-32x32 (torch.float32) | 102.7 | 115.5
cummin-dim0-128x128 (torch.float32) | 137.2 | 143.8
cummin-dim0-512x512 (torch.float32) | 209.7 | 222.0
cummin-dim0-1024x1024 (torch.float32) | 340.1 | 389.9
cummin-dim1-32x32 (torch.float32) | 99.2 | 107.8
cummin-dim1-128x128 (torch.float32) | 111.9 | 119.3
cummin-dim1-512x512 (torch.float32) | 250.7 | 255.1
cummin-dim1-1024x1024 (torch.float32) | 987.9 | 1013.2
cummin-1d-100 (torch.float32) | 100.6 | 114.6
cummin-1d-10000 (torch.float32) | 794.7 | 862.2
cummin-1d-1000000 (torch.float32) | 71995.3 | 71963.5
cummin-dim0-32x32 (torch.bfloat16) | 105.9 | 113.9
cummin-dim0-128x128 (torch.bfloat16) | 135.7 | 147.9
cummin-dim0-512x512 (torch.bfloat16) | 231.9 | 240.7
cummin-dim0-1024x1024 (torch.bfloat16) | 327.7 | 366.9
cummin-dim1-32x32 (torch.bfloat16) | 91.3 | 103.3
cummin-dim1-128x128 (torch.bfloat16) | 108.5 | 117.4
cummin-dim1-512x512 (torch.bfloat16) | 222.0 | 233.6
cummin-dim1-1024x1024 (torch.bfloat16) | 936.9 | 982.5
cummin-1d-100 (torch.bfloat16) | 106.6 | 112.4
cummin-1d-10000 (torch.bfloat16) | 795.8 | 819.6
cummin-1d-1000000 (torch.bfloat16) | 68667.4 | 68557.9
Times are in microseconds (us).
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156860
Approved by: https://github.com/malfet
This commit is contained in:
parent
9fe2d156a9
commit
43a09189c6
|
|
@ -71,6 +71,15 @@ def bench_binary(
|
||||||
return rc
|
return rc
|
||||||
|
|
||||||
|
|
||||||
|
def check_eager_vs_compile(rc_c, rc_e, func, dtype):
    """Warn if the two benchmark results for *func* diverge beyond allclose tolerance.

    Args:
        rc_c: one result tensor to compare (per callers, ``f(x)``).
        rc_e: the other result tensor (per callers, ``f_c(x)``).
        func: the benchmarked callable; its ``__name__`` is used in the warning.
        dtype: dtype under test, included in the warning message.

    Emits a ``UserWarning`` (instead of raising) so a numerical mismatch does
    not abort the rest of the benchmark run.
    """
    if not torch.allclose(rc_c, rc_e):
        # Report the worst-case absolute divergence to aid debugging.
        mdiff = (rc_c - rc_e).abs().max()
        warnings.warn(
            # Say "results", not "reduction": this helper is shared by the
            # reduction AND scan benchmark paths, so the old hardcoded
            # "reduction" wording was misleading for cummin/cumsum mismatches.
            f"Eager and compile results do not match for {func.__name__} and {dtype} max_diff={mdiff}",
            stacklevel=2,
        )
|
||||||
|
|
||||||
|
|
||||||
def bench_reduction(
|
def bench_reduction(
|
||||||
reduction_func, device: str = "mps", dtype: torch.dtype = torch.float32
|
reduction_func, device: str = "mps", dtype: torch.dtype = torch.float32
|
||||||
) -> list[Measurement]:
|
) -> list[Measurement]:
|
||||||
|
|
@ -87,19 +96,17 @@ def bench_reduction(
|
||||||
x = torch.testing.make_tensor(size, size, device=device, dtype=dtype)
|
x = torch.testing.make_tensor(size, size, device=device, dtype=dtype)
|
||||||
rc_c, rc_e = f(x), f_c(x)
|
rc_c, rc_e = f(x), f_c(x)
|
||||||
rc_c, rc_e = (rc_c[0], rc_e[0]) if isinstance(rc_c, tuple) else (rc_c, rc_e)
|
rc_c, rc_e = (rc_c[0], rc_e[0]) if isinstance(rc_c, tuple) else (rc_c, rc_e)
|
||||||
if not torch.allclose(rc_c, rc_e):
|
check_eager_vs_compile(rc_c, rc_e, reduction_func, dtype)
|
||||||
mdiff = (rc_c - rc_e).abs().max()
|
|
||||||
warnings.warn(
|
|
||||||
f"Eager and compile reduction do not match for {reduction_func.__name__} and {dtype} max_diff={mdiff}",
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
rc.append(bench_unary_op(f, x, f"eager-{size}x{size}"))
|
rc.append(bench_unary_op(f, x, f"eager-{size}x{size}"))
|
||||||
rc.append(bench_unary_op(f_c, x, f"compile-{size}x{size}"))
|
rc.append(bench_unary_op(f_c, x, f"compile-{size}x{size}"))
|
||||||
return rc
|
return rc
|
||||||
|
|
||||||
|
|
||||||
def bench_scan(
|
def bench_scan(
|
||||||
scan_func, device: str = "mps", dtype: torch.dtype = torch.float32
|
scan_func,
|
||||||
|
device: str = "mps",
|
||||||
|
dtype: torch.dtype = torch.float32,
|
||||||
|
with_indices: bool = False,
|
||||||
) -> list[Measurement]:
|
) -> list[Measurement]:
|
||||||
rc = []
|
rc = []
|
||||||
|
|
||||||
|
|
@ -116,12 +123,11 @@ def bench_scan(
|
||||||
f_c.__name__ = f.__name__
|
f_c.__name__ = f.__name__
|
||||||
x = torch.testing.make_tensor(size, size, device=device, dtype=dtype)
|
x = torch.testing.make_tensor(size, size, device=device, dtype=dtype)
|
||||||
rc_c, rc_e = f(x), f_c(x)
|
rc_c, rc_e = f(x), f_c(x)
|
||||||
if not torch.allclose(rc_c, rc_e):
|
if with_indices:
|
||||||
mdiff = (rc_c - rc_e).abs().max()
|
check_eager_vs_compile(rc_c[0], rc_e[0], scan_func, dtype)
|
||||||
warnings.warn(
|
check_eager_vs_compile(rc_c[1], rc_e[1], scan_func, dtype)
|
||||||
f"Eager and compile scan do not match for {scan_func.__name__} dim={dim} and {dtype} max_diff={mdiff}",
|
else:
|
||||||
stacklevel=2,
|
check_eager_vs_compile(rc_c, rc_e, scan_func, dtype)
|
||||||
)
|
|
||||||
rc.append(bench_unary_op(f, x, "eager"))
|
rc.append(bench_unary_op(f, x, "eager"))
|
||||||
rc.append(bench_unary_op(f_c, x, "compile"))
|
rc.append(bench_unary_op(f_c, x, "compile"))
|
||||||
|
|
||||||
|
|
@ -136,12 +142,11 @@ def bench_scan(
|
||||||
f_1d_c.__name__ = f_1d.__name__
|
f_1d_c.__name__ = f_1d.__name__
|
||||||
x = torch.testing.make_tensor(size, device=device, dtype=dtype)
|
x = torch.testing.make_tensor(size, device=device, dtype=dtype)
|
||||||
rc_c, rc_e = f_1d(x), f_1d_c(x)
|
rc_c, rc_e = f_1d(x), f_1d_c(x)
|
||||||
if not torch.allclose(rc_c, rc_e):
|
if with_indices:
|
||||||
mdiff = (rc_c - rc_e).abs().max()
|
check_eager_vs_compile(rc_c[0], rc_e[0], scan_func, dtype)
|
||||||
warnings.warn(
|
check_eager_vs_compile(rc_c[1], rc_e[1], scan_func, dtype)
|
||||||
f"Eager and compile 1D scan do not match for {scan_func.__name__} and {dtype} max_diff={mdiff}",
|
else:
|
||||||
stacklevel=2,
|
check_eager_vs_compile(rc_c, rc_e, scan_func, dtype)
|
||||||
)
|
|
||||||
rc.append(bench_unary_op(f_1d, x, "eager"))
|
rc.append(bench_unary_op(f_1d, x, "eager"))
|
||||||
rc.append(bench_unary_op(f_1d_c, x, "compile"))
|
rc.append(bench_unary_op(f_1d_c, x, "compile"))
|
||||||
|
|
||||||
|
|
@ -171,6 +176,12 @@ def main() -> None:
|
||||||
rc.extend(bench_scan(torch.cumsum, dtype=dtype))
|
rc.extend(bench_scan(torch.cumsum, dtype=dtype))
|
||||||
Compare(rc).print()
|
Compare(rc).print()
|
||||||
|
|
||||||
|
# Profile scan with indices ops (cummin)
|
||||||
|
rc = []
|
||||||
|
for dtype in dtypes:
|
||||||
|
rc.extend(bench_scan(torch.cummin, dtype=dtype, with_indices=True))
|
||||||
|
Compare(rc).print()
|
||||||
|
|
||||||
# Profile binary ops
|
# Profile binary ops
|
||||||
rc = []
|
rc = []
|
||||||
ops = [torch.fmax, torch.add]
|
ops = [torch.fmax, torch.add]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user