# Sparse benchmarks
This set of benchmarks covers PyTorch's sparse matrix functionality. It compares the performance of sparse matrix routines such as SpMV across different sparse matrix formats and against other frameworks such as TensorFlow.
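
To illustrate the kind of measurement these benchmarks perform, below is a minimal, self-contained SpMV timing sketch. It is not one of the scripts in this directory (the real entry points are `spmv.py` and `spmm.py`); it assumes only the public `torch` sparse tensor API and simply compares COO and CSR layouts for the same matrix.

```python
# Illustrative SpMV timing sketch; not part of the benchmark suite itself.
import time

import torch


def time_spmv(matrix, vector, iters=100):
    """Return average seconds per sparse matrix-vector product."""
    # Warm-up so one-time setup costs do not skew the measurement.
    for _ in range(5):
        matrix @ vector
    start = time.perf_counter()
    for _ in range(iters):
        matrix @ vector
    return (time.perf_counter() - start) / iters


if __name__ == "__main__":
    size, density = 4096, 0.01
    # Build a random matrix with roughly 1% nonzero entries.
    dense = (torch.rand(size, size) < density).float() * torch.randn(size, size)
    vector = torch.randn(size)

    # Compare COO and CSR layouts of the same data.
    coo = dense.to_sparse_coo()
    csr = dense.to_sparse_csr()
    print(f"COO SpMV: {time_spmv(coo, vector):.6e} s/iter")
    print(f"CSR SpMV: {time_spmv(csr, vector):.6e} s/iter")
```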