pytorch/torch/csrc/cuda/shared/cusparselt.cpp
Ting Lu c2bc7e2827 API change for new enum values in cusparseLtSplitKMode_t for cuSPARSELt 0.7.0+ (#150536)
Change the bool to an int to express split_k_mode. Before 0.7.0, cusparseLtSplitKMode_t had only two enum values, ONE_KERNEL and TWO_KERNELS, so a boolean was enough; since 0.7.0 there are more.

For Blackwell, the parameter split_k_one_kernel (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp#L103) needs a minor change: new values were introduced to the enum [cusparseLtSplitKMode_t](https://docs.nvidia.com/cuda/cusparselt/types.html#cusparseltsplitkmode-t), so a bool can no longer represent it and has to be replaced with an integer.
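
To illustrate why a bool stops being sufficient once an enum grows past two values, here is a minimal sketch. The `SplitKMode` enum, its `NewMode` enumerator, and the helpers `fromBool` and `fromInt` are hypothetical stand-ins invented for this example; they are not the cuSPARSELt definitions or the PyTorch implementation.

```cpp
#include <optional>

// Hypothetical stand-in for cusparseLtSplitKMode_t. The enumerator names and
// values are illustrative only, not copied from the cuSPARSELt headers.
enum class SplitKMode : int {
  OneKernel = 0,
  TwoKernels = 1,
  NewMode = 2,  // placeholder for a value added in a later release
};

// Old-style parameter: a bool can select only two of the enumerators, so
// NewMode is unreachable through this interface.
constexpr SplitKMode fromBool(bool split_k_one_kernel) {
  return split_k_one_kernel ? SplitKMode::OneKernel : SplitKMode::TwoKernels;
}

// New-style parameter: an int can name any enumerator. Values outside the
// known range (including a negative "unset" sentinel such as -1) map to
// std::nullopt so the caller can skip setting the attribute.
inline std::optional<SplitKMode> fromInt(int split_k_mode) {
  if (split_k_mode < 0 ||
      split_k_mode > static_cast<int>(SplitKMode::NewMode)) {
    return std::nullopt;
  }
  return static_cast<SplitKMode>(split_k_mode);
}
```

With this sketch, `fromInt(2)` yields `SplitKMode::NewMode`, which no argument to `fromBool` can produce, while `fromInt(-1)` yields an empty optional.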

Error seen without the change:
```
RuntimeError: CUDA error: invalid value when calling `cusparseLtMatmulAlgSetAttribute( &handle, &alg_sel, CUSPARSELT_MATMUL_SPLIT_K_MODE, &splitKMode, sizeof(splitKMode))`

To execute this test, run the following from the base repo dir:
    python test/test_sparse_semi_structured.py TestSparseSemiStructuredCUSPARSELTCUDA.test_csrc_cslt_sparse_mm_search_cuda_int8
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150536
Approved by: https://github.com/jcaip, https://github.com/atalman
2025-05-14 23:36:53 +00:00

#include <torch/csrc/utils/pybind.h>
#ifdef USE_CUSPARSELT
#include <ATen/native/sparse/cuda/cuSPARSELtOps.h>
namespace {
size_t getVersionInt() {
  return CUSPARSELT_VERSION;
}
std::tuple<int64_t, int64_t, int64_t, int64_t> mmSearch(
    const at::Tensor& compressed_A,
    const at::Tensor& dense_B,
    const std::optional<at::Tensor>& bias_opt,
    const std::optional<at::Tensor>& alpha_opt,
    const std::optional<c10::ScalarType> out_dtype_opt,
    bool transpose_result) {
  // Starting values passed to _cslt_sparse_mm_impl; the values it reports
  // back are read from `result` below.
  int alg_id_int = 0;
  int split_k = 1;
  int split_k_mode = -1;
  auto result = at::native::_cslt_sparse_mm_impl(
      compressed_A,
      dense_B,
      bias_opt,
      alpha_opt,
      out_dtype_opt,
      transpose_result,
      alg_id_int,
      split_k,
      split_k_mode,
      true);
  // Element 0 of `result` is not returned; only elements 1 through 4 are
  // exposed to the caller, widened to int64_t.
  return {
      (int64_t)std::get<1>(result),
      (int64_t)std::get<2>(result),
      (int64_t)std::get<3>(result),
      (int64_t)std::get<4>(result)};
}
} // namespace
namespace torch::cuda::shared {
void initCusparseltBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  auto cusparselt = m.def_submodule("_cusparselt", "libcusparselt.so bindings");
  cusparselt.def("getVersionInt", getVersionInt);
  cusparselt.def("mm_search", mmSearch);
}
} // namespace torch::cuda::shared
#endif
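
The integer returned by `getVersionInt` can be used to decide whether the extended split-K modes mentioned in the commit message are available. The sketch below assumes `CUSPARSELT_VERSION` is encoded as `major * 1000 + minor * 100 + patch` (so 0.7.0 would be 700); that encoding is an assumption to verify against the installed cuSPARSELt header. The `CusparseLtVersion` struct and both helper functions are hypothetical names introduced for illustration.

```cpp
#include <cstddef>

// Hypothetical holder for a decoded version number.
struct CusparseLtVersion {
  std::size_t ver_major;
  std::size_t ver_minor;
  std::size_t ver_patch;
};

// Decode a packed version integer, ASSUMING the encoding
// major * 1000 + minor * 100 + patch. Check the cuSPARSELt header
// before relying on this.
constexpr CusparseLtVersion decodeVersion(std::size_t packed) {
  return CusparseLtVersion{packed / 1000, (packed % 1000) / 100, packed % 100};
}

// True when the packed version is at least 0.7.0, the release from which
// the commit message says more than two split-K modes exist.
constexpr bool hasExtendedSplitKModes(std::size_t packed) {
  return packed >= 700;
}
```

Under the stated encoding, `decodeVersion(602)` gives major 0, minor 6, patch 2, and `hasExtendedSplitKModes(602)` is false while `hasExtendedSplitKModes(700)` is true.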