mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: This PR adds a lowering for `torch._cslt_sparse_mm` to find the optimal alg_id and cache it when running with `torch.compile` Seeing speedups on both bfloat16 and float8 dtypes: <img width="641" alt="Screenshot 2024-10-17 at 2 10 38 PM" src="https://github.com/user-attachments/assets/b928cd11-32a3-43e5-b209-8e4028896f0b"> <img width="1274" alt="Screenshot 2024-10-17 at 1 39 03 PM" src="https://github.com/user-attachments/assets/d9edd684-a8ec-46fd-b3da-2e76dbcb7bb6"> * `torch._cslt_sparse_mm_search` has been modified to return optimal split-k parameters as well as max alg_id. * max_id is now available in `torch.backends.cusparselt` via `torch.backends.cusparselt.get_max_alg_id()` * fixed meta registrations for float8 Test Plan: python test/test_sparse_semi_structured.py Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/137427 Approved by: https://github.com/cpuhrsch
53 lines
1.3 KiB
C++
#include <torch/csrc/utils/pybind.h>
|
|
|
|
#ifdef USE_CUSPARSELT
|
|
#include <ATen/native/sparse/cuda/cuSPARSELtOps.h>
|
|
|
|
namespace {
|
|
|
|
// Returns the cuSPARSELt library version this build was compiled against,
// as the raw integer CUSPARSELT_VERSION macro from the cuSPARSELt headers.
size_t getVersionInt() {
  return CUSPARSELT_VERSION;
}
|
|
|
|
std::tuple<int64_t, int64_t, bool, int64_t> mmSearch(
|
|
const at::Tensor& compressed_A,
|
|
const at::Tensor& dense_B,
|
|
const std::optional<at::Tensor>& bias_opt,
|
|
const std::optional<at::Tensor>& alpha_opt,
|
|
const std::optional<c10::ScalarType> out_dtype_opt,
|
|
bool transpose_result) {
|
|
int alg_id_int = 0;
|
|
int split_k = 1;
|
|
bool split_k_one_kernel = true;
|
|
auto result = at::native::_cslt_sparse_mm_impl(
|
|
compressed_A,
|
|
dense_B,
|
|
bias_opt,
|
|
alpha_opt,
|
|
out_dtype_opt,
|
|
transpose_result,
|
|
alg_id_int,
|
|
split_k,
|
|
split_k_one_kernel,
|
|
true);
|
|
return {
|
|
(int64_t)std::get<1>(result),
|
|
(int64_t)std::get<2>(result),
|
|
(bool)std::get<3>(result),
|
|
(int64_t)std::get<4>(result)};
|
|
}
|
|
|
|
} // namespace
|
|
|
|
namespace torch::cuda::shared {
|
|
|
|
// Registers the `_cusparselt` submodule (version query + algorithm
// search) on the Python module handed in from the torch._C init path.
void initCusparseltBindings(PyObject* module) {
  auto parent_module = py::handle(module).cast<py::module>();
  auto sub = parent_module.def_submodule(
      "_cusparselt", "libcusparselt.so bindings");
  sub.def("getVersionInt", getVersionInt);
  sub.def("mm_search", mmSearch);
}
|
|
|
|
} // namespace torch::cuda::shared
|
|
#endif
|