mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPS] slightly faster cholesky (#165867)
Slightly faster cholesky, removed one redundant simdgroup_multiply <img width="721" height="593" alt="Screenshot 2025-10-19 at 22 00 19" src="https://github.com/user-attachments/assets/e3a9005b-9347-4e62-a24d-16ba5e28849a" /> Generate benchmarks with(measured on M1 Pro): ``` import torch import numpy as np import time import csv matrix_sizes = [512, 1024, 2048, 4096] batch_sizes = [1, 2, 4, 8, 16] num_runs = 10 warmup_runs = 3 def create_spd_matrix(n, batch_size): torch.manual_seed(42) A = torch.randn(batch_size, n, n, dtype=torch.float32) return A @ A.transpose(-2, -1) + n * torch.eye(n).expand(batch_size, -1, -1) def run_cholesky_mps(A): torch.mps.synchronize() start = time.perf_counter() b = torch.linalg.cholesky(A, upper=False) torch.mps.synchronize() end = time.perf_counter() return b, end - start results = { 'N': [], 'batch_size': [], 'mean_time': [], 'std_time': [] } for n in matrix_sizes: for batch_size in batch_sizes: print(f"\nBenchmarking N={n}, batch_size={batch_size}") try: A_cpu = create_spd_matrix(n, batch_size) A_mps = A_cpu.to("mps") for _ in range(warmup_runs): _, _ = run_cholesky_mps(A_mps) times = [] for _ in range(num_runs): _, t = run_cholesky_mps(A_mps) times.append(t) mean_time = np.mean(times) std_time = np.std(times) results['N'].append(n) results['batch_size'].append(batch_size) results['mean_time'].append(mean_time) results['std_time'].append(std_time) print(f"Mean time: {mean_time:.4f}s ± {std_time:.4f}s") except RuntimeError as e: print(f"Error for N={n}, batch_size={batch_size}: {e}") continue with open('cholesky_benchmark_times.csv', 'w', newline='') as f: writer = csv.writer(f) writer.writerow(['N', 'batch_size', 'mean_time', 'std_time']) for i in range(len(results['N'])): writer.writerow([ results['N'][i], results['batch_size'][i], results['mean_time'][i], results['std_time'][i] ]) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/165867 Approved by: https://github.com/malfet
This commit is contained in:
parent
240c13394e
commit
8f06a1308f
|
|
@ -441,7 +441,7 @@ kernel void applySYRK(
|
|||
uint3 tid [[thread_position_in_threadgroup]],
|
||||
uint3 tgid [[threadgroup_position_in_grid]],
|
||||
uint3 tpg [[threads_per_threadgroup]],
|
||||
uint sgitg [[simdgroup_index_in_threadgroup]]) {
|
||||
uint warp_id [[simdgroup_index_in_threadgroup]]) {
|
||||
const uint tx = tid.x;
|
||||
const uint ty = tid.y;
|
||||
const uint simdGroupsPerThreadgroup = (tpg.x * tpg.y + 31) / 32;
|
||||
|
|
@ -474,11 +474,8 @@ kernel void applySYRK(
|
|||
(actSize_j % 8 == 0) && (actSize_h % 8 == 0) && (actSize_k % 8 == 0);
|
||||
|
||||
if (use_simdgroup) {
|
||||
uint warp_id = sgitg;
|
||||
|
||||
simdgroup_matrix<float, 8, 8> negative_identity =
|
||||
simdgroup_matrix<float, 8, 8>(-1.0);
|
||||
simdgroup_matrix<float, 8, 8> identity = simdgroup_matrix<float, 8, 8>(1.0);
|
||||
simdgroup_matrix<float, 8, 8> Prod;
|
||||
simdgroup_matrix<float, 8, 8> Afrag;
|
||||
simdgroup_matrix<float, 8, 8> Bfrag;
|
||||
|
|
@ -521,8 +518,7 @@ kernel void applySYRK(
|
|||
/* transpose = */ upper);
|
||||
|
||||
simdgroup_multiply(Prod, Afrag, Bfrag);
|
||||
simdgroup_multiply(Prod, Prod, negative_identity);
|
||||
simdgroup_multiply_accumulate(Cfrag, Cfrag, identity, Prod);
|
||||
simdgroup_multiply_accumulate(Cfrag, Prod, negative_identity, Cfrag);
|
||||
}
|
||||
|
||||
simdgroup_store(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user