Summary:
This PR adds a lowering for `torch._cslt_sparse_mm` to find the optimal
alg_id and cache it when running with `torch.compile`
Seeing speedups on both bfloat16 and float8 dtypes:
<img width="641" alt="Screenshot 2024-10-17 at 2 10 38 PM" src="https://github.com/user-attachments/assets/b928cd11-32a3-43e5-b209-8e4028896f0b">
<img width="1274" alt="Screenshot 2024-10-17 at 1 39 03 PM" src="https://github.com/user-attachments/assets/d9edd684-a8ec-46fd-b3da-2e76dbcb7bb6">
* `torch._cslt_sparse_mm_search` has been modified to return optimal
split-k parameters as well as max alg_id.
* max_id is now available in `torch.backends.cusparselt` via
`torch.backends.cusparselt.get_max_alg_id()`
* fixed meta registrations for float8
Test Plan:
python test/test_sparse_semi_structured.py
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137427
Approved by: https://github.com/cpuhrsch
Summary:
This PR adds `torch.float8e4m3fn` support to cuSPARSELt and `to_sparse_semi_structured`.
This will let users to run fp8 + 2:4 sparse matmuls on Hopper GPUs with
cusparselt >= 0.6.2, via to `scaled_mm` API.
```
A = rand_sparse_semi_structured_mask(256, 128, dtype=torch.float16)
B = torch.rand(dense_input_shape, device=device).to(torch.float16).t()
A_fp8, A_scale = to_float8(A)
B_fp8, B_scale = to_float8(B)
dense_result = torch._scaled_mm(
A_fp8, B_fp8,
scale_a=A_scale, scale_b=B_scale,
out_dtype=out_dtype
)
A_fp8_sparse = to_sparse_semi_structured(A_fp8)
sparse_result = torch._scaled_mm(
A_fp8_sparse, B_fp8,
scale_a=A_scale, scale_b=B_scale,
out_dtype=out_dtype
)
```
Note that to keep this consistent with normal torch behavior, calling
`torch.mm(A_fp8_sparse, B_fp8)` will raise a NotImplementedError.
I also turned on cuSPARSELt by default and added CUSPARSELT_MAX_ID to the
backend to make the tests a bit cleaner
Test Plan:
```
python test/test_sparse_semi_structured -k scaled_mm
python test/test_sparse_semi_structured -k fp8
```
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136397
Approved by: https://github.com/drisspg
Summary:
This PR adds in cuSPARSELt as a backend to PyTorch.
It is now possible to see if cuSPARSELt is available and the version if
it is with
```
torch.backends.cusparselt.is_available()
torch.backends.cusparselt.version()
```
Test Plan:
```
python test/test_sparse_semi_structured.py -k test_cusparselt_backend
```
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128534
Approved by: https://github.com/cpuhrsch, https://github.com/eqy, https://github.com/syed-ahmed