diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 01e090b2d37..3625cd87124 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3357,7 +3357,7 @@
   dispatch:
     CUDA: _cslt_compress

-- func: _cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0, int split_k=1, bool split_k_one_kernel=True) -> Tensor
+- func: _cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0) -> Tensor
   dispatch:
     CUDA: _cslt_sparse_mm

diff --git a/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp b/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp
index 8fb56ec40a7..ca3996f00e7 100644
--- a/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp
+++ b/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp
@@ -1,7 +1,20 @@
-#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include

 #if AT_CUSPARSELT_ENABLED()

+#include
+
 namespace at::native {

 // Ideally we would use the same DeviceThreadHandlePool mechanism as used in aten/src/ATen/cuda/CuSparseHandlePool.cpp
@@ -43,7 +56,6 @@ at::Tensor _cslt_compress(const Tensor& sparse_input)
 #if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
     case at::ScalarType::Float8_e4m3fn:
       type = CUDA_R_8F_E4M3;
-      compression_factor = 10;
       break;
 #endif
     default:
@@ -91,7 +103,7 @@ at::Tensor _cslt_compress(const Tensor& sparse_input)
   return compressed_tensor;
 }

-std::tuple _cslt_sparse_mm_impl(
+std::tuple _cslt_sparse_mm_impl(
     const Tensor& compressed_A,
     const Tensor& dense_B,
     const std::optional& bias_opt,
@@ -99,8 +111,6 @@ std::tuple _cslt_sparse_mm_impl(
     const std::optional out_dtype_opt,
     bool transpose_result,
     int alg_id,
-    int split_k,
-    bool split_k_one_kernel,
     bool search_alg_id
 )
 {
@@ -159,7 +169,6 @@ std::tuple _cslt_sparse_mm_impl(
       output_type = CUDA_R_8F_E4M3;
       C_type = CUDA_R_16F;
       compute_type = CUSPARSE_COMPUTE_32F;
-      compression_factor = 10;
       break;
 #endif
 // cuSPARSELt <= v0.5.2 uses CUSPARSE_COMPUTE_TF32, CUSPARSE_COMPUTE_16F
@@ -326,21 +335,10 @@ std::tuple _cslt_sparse_mm_impl(
   TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSelectionInit(
       &handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT));

-  // set matmul search params
+  // set alg_id
   TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute(
       &handle, &alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg_id, sizeof(alg_id)));

-  cusparseLtSplitKMode_t splitKMode;
-  int max_alg_id;
-  if (split_k != 1) {
-    TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute(
-      &handle, &alg_sel, CUSPARSELT_MATMUL_SPLIT_K, &split_k, sizeof(split_k)));
-
-    splitKMode = split_k_one_kernel ? CUSPARSELT_SPLIT_K_MODE_ONE_KERNEL : CUSPARSELT_SPLIT_K_MODE_TWO_KERNELS;
-    TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute(
-      &handle, &alg_sel, CUSPARSELT_MATMUL_SPLIT_K_MODE, &splitKMode, sizeof(splitKMode)));
-  }
-
   // set tensor_alpha_mode and alpha pointer for matmul
   const auto alpha_tensor = alpha_opt.has_value() ? *alpha_opt: Tensor{};
   auto alpha_ptr = &alpha;
@@ -383,23 +381,9 @@ std::tuple _cslt_sparse_mm_impl(
       &stream, 1));

-    // get matmul params used
+    // get alg_id used
     TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgGetAttribute(
         &handle, &alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg_id, sizeof(alg_id)));
-
-    TORCH_CUDASPARSE_CHECK( cusparseLtMatmulAlgGetAttribute(&handle, &alg_sel,
-                                                CUSPARSELT_MATMUL_SPLIT_K,
-                                                &split_k, sizeof(split_k)));
-
-    TORCH_CUDASPARSE_CHECK( cusparseLtMatmulAlgGetAttribute(&handle, &alg_sel,
-                                                CUSPARSELT_MATMUL_SPLIT_K_MODE,
-                                                &splitKMode, sizeof(splitKMode)));
-
-    TORCH_CUDASPARSE_CHECK( cusparseLtMatmulAlgGetAttribute(&handle, &alg_sel,
-                                                CUSPARSELT_MATMUL_ALG_CONFIG_MAX_ID,
-                                                &max_alg_id, sizeof(max_alg_id)));
-
-
   } else {
     // do normal matmul
@@ -427,7 +411,7 @@ std::tuple _cslt_sparse_mm_impl(
   // destroy plan
   TORCH_CUDASPARSE_CHECK(cusparseLtMatmulPlanDestroy(&plan));

-  return {res, alg_id, split_k, splitKMode == CUSPARSELT_SPLIT_K_MODE_ONE_KERNEL, max_alg_id};
+  return {alg_id, res};
 }

 at::Tensor _cslt_sparse_mm(
@@ -437,9 +421,7 @@ at::Tensor _cslt_sparse_mm(
     const std::optional& alpha_opt,
     const std::optional out_dtype_opt,
     bool transpose_result,
-    int64_t alg_id,
-    int64_t split_k,
-    bool split_k_one_kernel
+    int64_t alg_id
 )
 {
   auto result = _cslt_sparse_mm_impl(
@@ -450,10 +432,8 @@ at::Tensor _cslt_sparse_mm(
       out_dtype_opt,
       transpose_result,
       (int) alg_id,
-      (int) split_k,
-      split_k_one_kernel,
       false);
-  return std::get<0>(result);
+  return std::get<1>(result);
 }

 int64_t _cslt_sparse_mm_search(
@@ -465,10 +445,7 @@ int64_t _cslt_sparse_mm_search(
     bool transpose_result
 )
 {
-    TORCH_WARN_ONCE("torch._cslt_sparse_mm_search is deprecated and will be removed in a future PyTorch release. Please use torch._C._cusparselt.mm_search instead.");
     int alg_id_int = 0;
-    int split_k = 1;
-    bool split_k_one_kernel= true;
     auto result = _cslt_sparse_mm_impl(
         compressed_A,
         dense_B,
@@ -477,12 +454,11 @@ int64_t _cslt_sparse_mm_search(
         out_dtype_opt,
         transpose_result,
         alg_id_int,
-        split_k,
-        split_k_one_kernel,
         true);
-    return (int64_t) std::get<1>(result);
+    return (int64_t) std::get<0>(result);
 }
+
 } // namespace at::native

 #else // No cuSPARSELt support, throw error if these functions are called.
@@ -500,9 +476,7 @@ at::Tensor _cslt_sparse_mm( const std::optional& alpha_opt, const std::optional out_dtype, bool transpose_result, - int64_t alg_id, - int64_t split_k, - bool split_k_one_kernel) + int64_t alg_id) { TORCH_CHECK(false, "cuSPARSELt not supported on your machine."); } diff --git a/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.h b/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.h deleted file mode 100644 index 00e7a8e1477..00000000000 --- a/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.h +++ /dev/null @@ -1,58 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if AT_CUSPARSELT_ENABLED() -#include -#endif - -namespace at::native { - -at::Tensor _cslt_compress(const Tensor& sparse_input); - -TORCH_CUDA_CPP_API std::tuple _cslt_sparse_mm_impl( - const Tensor& compressed_A, - const Tensor& dense_B, - const std::optional& bias_opt, - const std::optional& alpha_opt, - const std::optional out_dtype_opt, - bool transpose_result, - int alg_id, - int split_k, - bool split_k_one_kernel, - bool search_alg_id -); - -at::Tensor _cslt_sparse_mm( - const Tensor& compressed_A, - const Tensor& dense_B, - const std::optional& bias_opt, - const std::optional& alpha_opt, - const std::optional out_dtype_opt, - bool transpose_result, - int64_t alg_id, - int64_t split_k, - bool split_k_one_kernel -); - -int64_t _cslt_sparse_mm_search( - const Tensor& compressed_A, - const Tensor& dense_B, - const std::optional& bias_opt, - const std::optional& alpha_opt, - const std::optional out_dtype_opt, - bool transpose_result -); - -} // namespace at::native diff --git a/benchmarks/sparse/benchmark_semi_structured_sparsity.py b/benchmarks/sparse/benchmark_semi_structured_sparsity.py new file mode 100644 index 00000000000..66311c40428 --- /dev/null +++ b/benchmarks/sparse/benchmark_semi_structured_sparsity.py @@ -0,0 +1,253 @@ +import argparse +import random + +import pandas as pd +from tqdm import tqdm + +import torch +import torch.utils.benchmark as benchmark +from torch import nn +from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured + + +torch.set_printoptions( + precision=2, + threshold=None, + edgeitems=16, + linewidth=480, + profile=None, + sci_mode=False, +) + + +# helper model definition for pruner +class Model(nn.Module): + def __init__(self, m, k, dtype=None): + super().__init__() + # transposed so reversed + self.linear = nn.Linear(k, m) + + def forward(self, x): + return self.linear(x) + + +def rand_sparse_semi_structured_mask( + r, c, dtype=torch.float16, device="cuda", choice=None +): + """ + This function returns a 1:2 sparse matrix of size (r, c). + Note that this means this matrix will also be 2:4 and 4:8 sparse as well. 
+ """ + + choices = [[0, 1], [1, 0]] + mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)] + + return ( + torch.tensor(mask_entries, dtype=dtype, device=device) + .reshape(r, c) + .contiguous() + ) + + +def test_linear(m, k, n, dtype, contiguous, backend): + SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass" + mask = rand_sparse_semi_structured_mask(m, k, dtype=dtype) + sparse_weight = torch.rand(m, k).to(dtype).cuda() * mask + input_tensor = torch.zeros(n, k).to(dtype).cuda() + model = Model(m, k).to(dtype).cuda().eval() + + dense_measurement = benchmark.Timer( + stmt="model(input_tensor)", + globals=locals(), + ).blocked_autorange() + + dense_output = model(input_tensor) + print(dense_output.shape) + + # sparsify weights + model.linear.weight = nn.Parameter( + to_sparse_semi_structured( + sparse_weight, + ) + ) + + sparse_output = model(input_tensor) + print(sparse_output.shape) + + sparse_measurement = benchmark.Timer( + stmt="model(input_tensor)", + globals=locals(), + ).blocked_autorange() + + correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3) + + return { + "test_function": "linear", + "m": m, + "k": k, + "n": n, + "dtype": str(dtype), + "backend": backend, + "sparse_latency (ms)": sparse_measurement.median * 1000, + "dense_latency (ms)": dense_measurement.median * 1000, + "speedup (d/s)": dense_measurement.median / sparse_measurement.median, + "correct": correct, + "contiguous": sparse_output.is_contiguous(), + } + + +def test_tensor(m, k, n, dtype, contiguous, backend): + A = rand_sparse_semi_structured_mask(m, k, dtype=dtype) + B = torch.zeros(k, n).to(dtype).cuda() + bias = torch.rand(n).to(dtype).cuda() + + sA = to_sparse_semi_structured(A) + + # torch.mm calculation + if dtype is not torch.int8: + dense_output = torch.mm(A, B) + + dense_measurement = benchmark.Timer( + stmt="torch.mm(A, B)", + globals=locals(), + ).blocked_autorange() + + else: + print("int8 baseline not supported") + dense_output = torch.mm(sA, B) + + dense_measurement = benchmark.Timer( + stmt="torch.mm(sA, B)", + globals=locals(), + ).blocked_autorange() + + sparse_output = torch.mm(sA, B) + sparse_measurement = benchmark.Timer( + stmt="torch.mm(sA, B)", + globals=locals(), + ).blocked_autorange() + + correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3) + + return { + "test_function": "tensor", + "m": m, + "k": k, + "n": n, + "dtype": str(dtype), + "backend": backend, + "sparse_latency (ms)": sparse_measurement.median * 1000, + "dense_latency (ms)": dense_measurement.median * 1000, + "speedup (d/s)": dense_measurement.median / sparse_measurement.median, + "correct": correct, + "contiguous": sparse_output.is_contiguous(), + } + + +if __name__ == "__main__": + dtype_lookup = { + "int8": torch.int8, + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, + } + + parser = argparse.ArgumentParser(description="Semi-Structured Sparsity Benchmarks") + parser.add_argument( + "--mode", + type=str, + choices=[ + "nvidia-bert", + "nvidia-fixed-k", + "nvidia-fixed-mn", + ], + ) + parser.add_argument( + "--dtype", + type=str, + choices=dtype_lookup.keys(), + default="fp16", + ) + parser.add_argument( + "--backend", type=str, choices=["cutlass", "cusparselt"], default="cusparselt" + ) + parser.add_argument("-contiguous", action="store_true") + parser.add_argument("-e2e", action="store_true") + parser.add_argument("-save", action="store_true") + args = parser.parse_args() + + if args.e2e: + eval_fn = test_linear + 
else: + eval_fn = test_tensor + + print(f"Started benchmark: {args.mode} | dtype: {args.dtype}") + dtype = dtype_lookup[args.dtype] + + if args.mode == "nvidia-bert": + bert_shapes = [ + (3072, 1024, 16384), + (4096, 1024, 16384), + (1024, 1024, 16384), + (1024, 4096, 16384), + ] + results = ( + eval_fn(m, k, n, dtype, args.contiguous, args.backend) + for (m, k, n) in tqdm(bert_shapes) + ) + + elif args.mode == "nvidia-fixed-k": + mn_vals = [ + 3072, + 4096, + 5120, + 6144, + 7168, + 8192, + 9216, + 10240, + 11264, + 12288, + 13312, + 14336, + 15360, + 16384, + 17408, + 18432, + 19456, + 20480, + ] + results = ( + eval_fn(mn, 10240, mn, dtype, args.contiguous, args.backend) + for mn in tqdm(mn_vals) + ) + + elif args.mode == "nvidia-fixed-mn": + k_vals = [ + 2560, + 3840, + 5120, + 6400, + 7680, + 8960, + 10240, + 11520, + 12800, + 14080, + 15360, + 16640, + 17920, + 19200, + 20480, + ] + results = ( + eval_fn(10240, k, 10240, dtype, args.contiguous, args.backend) + for k in tqdm(k_vals) + ) + + df = pd.DataFrame.from_records(results) + if args.save: + save_file = f"{args.mode}_{args.dtype}_{args.backend}.csv" + df.to_csv(save_file) + print(f"Finished benchmark: {args.mode} saved results to {save_file}") + print(df) diff --git a/test/test_sparse_semi_structured.py b/test/test_sparse_semi_structured.py index 52af386cd2b..2292dca8c97 100644 --- a/test/test_sparse_semi_structured.py +++ b/test/test_sparse_semi_structured.py @@ -244,17 +244,18 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase): @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine") def test_sp24_compile(self) -> None: x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True) + e = torch.eye(x.shape[0], x.shape[0], device="cuda", dtype=torch.float16) - def fn(x): + def fn(x, e): y = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(x) y = y.t() return x @ y # Eager - output = fn(x) + output = fn(x, e) output.backward(output) # Torch compile - output = torch.compile(fn)(x) + output = torch.compile(fn)(x, e) output.backward(output) class TestSparseSemiStructured(TestCase): @@ -1155,9 +1156,8 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase): B = torch.ones((128, 128), device=device).to(dtype) A_compressed = torch._cslt_compress(A) - alg_id, split_k, split_k_one_kernel, _ = torch._C._cusparselt.mm_search(A_compressed, B.t(), None, None, None, False) - sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), - alg_id=alg_id, split_k=split_k, split_k_one_kernel=split_k_one_kernel) + alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t()) + sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id) dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32)) dense_result = dense_result.to(dtype) @@ -1174,16 +1174,6 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase): alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t()) assert alg_id in range(torch.backends.cusparselt.get_max_alg_id()) - @inference_dtypes - def test_csrc_cslt_sparse_mm_search(self, device, dtype): - A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype) - A_compressed = torch._cslt_compress(A) - B = torch.ones((128, 128), device=device).to(dtype) - - A_compressed = torch._cslt_compress(A) - alg_id, _, _, _ = torch._C._cusparselt.mm_search(A_compressed, B.t(), None, None, None, False) - assert alg_id in range(torch.backends.cusparselt.get_max_alg_id()) - def test_cusparselt_backend(self): version = 
_get_torch_cuda_version() assert torch.backends.cusparselt.is_available() @@ -1191,11 +1181,9 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase): # CUDA 11.8 has cuSPARSELt v0.4.0 support if version == (11, 8): assert torch.backends.cusparselt.version() == 400 - assert torch.backends.cusparselt.get_max_alg_id() == 4 # CUDA 12.1 has cuSPARSELt v0.5.2 support elif version == (12, 1): assert torch.backends.cusparselt.version() == 502 - assert torch.backends.cusparselt.get_max_alg_id() == 4 # CUDA 12.4+ has cuSPARSELt v0.6.2 support elif version >= (12, 4): assert torch.backends.cusparselt.version() == 602 diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index 56107712a7a..d7aed0214e9 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import functools import logging -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional import torch from torch._inductor.autoheuristic.autoheuristic import AutoHeuristicSelectAlgorithm @@ -144,12 +144,6 @@ aten__sparse_semi_structured_mm = ExternKernelChoice( has_out_variant=False, ) -aten__cslt_sparse_mm = ExternKernelChoice( - torch._cslt_sparse_mm, - "at::_cslt_sparse_mm", - has_out_variant=False, -) - def _is_int8_mat(mat): return mat.get_dtype() in (torch.int8, torch.uint8) @@ -527,124 +521,6 @@ def tuned_sparse_semi_structured_mm( ) -@register_lowering(aten._cslt_sparse_mm, type_promotion_kind=None) -def tuned_cslt_sparse_mm( - mat1_compressed, - mat2, - bias=None, - alpha=None, - out_dtype=None, - transpose_result=False, - alg_id=0, - split_k=1, - split_k_one_kernel=True, - layout=None, -): - from torch._inductor.select_algorithm import AlgorithmSelectorCache, realize_inputs - - mat1_compressed, mat2 = realize_inputs(mat1_compressed, mat2) - input_nodes: Tuple[Any, ...] 
= (mat1_compressed, mat2) - k, n = mat2.get_size() - - is_8bit_input_type = mat1_compressed.dtype in [torch.int8, torch.float8_e4m3fn] - compression_factor = 10 if is_8bit_input_type else 9 - m = (mat1_compressed.get_numel() * 16) // (compression_factor * k) - - from torch._inductor.ir import FixedLayout - - if transpose_result: - layout = FixedLayout( - mat2.get_device(), - out_dtype if out_dtype else mat2.get_dtype(), - [n, m], - [m, 1], - ) - else: - layout = FixedLayout( - mat2.get_device(), - out_dtype if out_dtype else mat2.get_dtype(), - [m, n], - [n, 1], - ) - # workaround for Inductor not supporting optional tensor input arguments - if bias is not None: - bias = realize_inputs(bias) - input_nodes = input_nodes + (bias,) - - if alpha is not None: - alpha = realize_inputs(alpha) - input_nodes = input_nodes + (alpha,) - - # cuSPARSELt alg_id search, not that we cannot use - # AlgorithmSelectorCache.benchmark_example_value() because this will return the base view - # and mat2 needs to have transpose properties preserved for cslt mm - ( - searched_alg_id, - searched_split_k, - searched_split_k_one_kernel, - _, - ) = torch._C._cusparselt.mm_search( # type: ignore[attr-defined] - AlgorithmSelectorCache.generate_example_value( - V.graph.sizevars.size_hints(mat1_compressed.get_size()), - V.graph.sizevars.size_hints(mat1_compressed.get_stride()), - mat1_compressed.get_device(), - mat1_compressed.dtype, - mat1_compressed.layout.offset, - ), - AlgorithmSelectorCache.generate_example_value( - V.graph.sizevars.size_hints(mat2.get_size()), - V.graph.sizevars.size_hints(mat2.get_stride()), - mat2.get_device(), - mat2.dtype, - mat2.layout.offset, - ), - AlgorithmSelectorCache.generate_example_value( - V.graph.sizevars.size_hints(bias.get_size()), - V.graph.sizevars.size_hints(bias.get_stride()), - bias.get_device(), - bias.dtype, - bias.layout.offset, - ) - if bias is not None - else None, - AlgorithmSelectorCache.generate_example_value( - V.graph.sizevars.size_hints(alpha.get_size()), - V.graph.sizevars.size_hints(alpha.get_stride()), - alpha.get_device(), - alpha.dtype, - alpha.layout.offset, - ) - if alpha is not None - else None, - out_dtype, - transpose_result, - ) - - baseline = aten__cslt_sparse_mm.bind( - input_nodes, - layout, - out_dtype=out_dtype, - alg_id=0, - split_k=1, - split_k_one_kernel=True, - transpose_result=transpose_result, - ) - baseline.description = f"ALG_ID: 0 SPLIT_K: 1 SPLIT_K_ONE_KERNEL: True TRANSPOSE_RESULT: {transpose_result}" - searched = aten__cslt_sparse_mm.bind( - input_nodes, - layout, - out_dtype=out_dtype, - alg_id=searched_alg_id, - split_k=searched_split_k, - split_k_one_kernel=searched_split_k_one_kernel, - transpose_result=transpose_result, - ) - searched.description = f"ALG_ID: {searched_alg_id} SPLIT_K: {searched_split_k} SPLIT_K_ONE_KERNEL: {searched_split_k_one_kernel} TRANSPOSE_RESULT: {transpose_result}" # noqa: B950 - choices = [baseline, searched] - - return autotune_select_algorithm("cslt_sparse_mm", choices, input_nodes, layout) - - def fallback_mixed_mm(mat1, mat2, *, out): return torch.mm(mat1, mat2.to(mat1.dtype), out=out) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index a022a75e804..0da6b58bdb4 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -520,44 +520,32 @@ def meta__cslt_sparse_mm( alpha: Optional[Tensor] = None, out_dtype: Optional[torch.dtype] = None, transpose_result: bool = False, - alg_id: int = 0, - split_k: int = 1, - split_k_one_kernel: bool = False, ): assert 
dense_B.dtype in { torch.float32, torch.float16, torch.bfloat16, torch.int8, - torch.float8_e4m3fn, - }, "_cslt_sparse_mm only supports fp16, bf16, int8, and fp8e4m3" + }, "_cslt_sparse_mm only supports fp16, bf16, and int8" assert compressed_A.dtype == dense_B.dtype, "inputs must have the same dtype" assert len(dense_B.shape) == 2, "_cslt_sparse_mm only supports 2d inputs" - is_8bit_input_type = compressed_A.dtype in [torch.int8, torch.float8_e4m3fn] - compression_factor = 10 if is_8bit_input_type else 9 + is_int8_input_type = compressed_A.dtype == torch.int8 + compression_factor = 10 if is_int8_input_type else 9 k = dense_B.size(0) n = dense_B.size(1) m = (compressed_A.numel() * 16) // (compression_factor * k) if bias is not None: assert m == bias.size(0) - if is_8bit_input_type: - assert not dense_B.is_contiguous() - if out_dtype is not None: - assert ( - is_8bit_input_type - and out_dtype - in { - torch.float16, - torch.bfloat16, - torch.int32, - torch.float8_e4m3fn, - } - ), "out_dtype is not supported for {compressed_A.dtype} x {dense_B.dtype} -> {out_dtype} matmul!" + assert is_int8_input_type and out_dtype in { + torch.float16, + torch.bfloat16, + torch.int32, + }, "out_dtype is only supported for i8i8->fp16, bf16, or i32 matmul" output_shape = (n, m) if transpose_result else (m, n) - result = torch.empty(output_shape, dtype=out_dtype, device=compressed_A.device) + result = dense_B.new_empty(output_shape, dtype=out_dtype) return result diff --git a/torch/backends/cusparselt/__init__.py b/torch/backends/cusparselt/__init__.py index ebd33636f55..da46274a284 100644 --- a/torch/backends/cusparselt/__init__.py +++ b/torch/backends/cusparselt/__init__.py @@ -25,12 +25,12 @@ if _cusparselt is not None: global __MAX_ALG_ID if __cusparselt_version is None: __cusparselt_version = _cusparselt.getVersionInt() - - # only way to get MAX_ALG_ID is to run a matmul - A = torch.zeros(128, 128, dtype=torch.float16).cuda() - A = torch._cslt_compress(A) - B = torch.zeros(128, 128, dtype=torch.float16).cuda() - _, _, _, __MAX_ALG_ID = _cusparselt.mm_search(A, B, None, None, None, False) # type: ignore[attr-defined] + if __cusparselt_version == 400: + __MAX_ALG_ID = 4 + elif __cusparselt_version == 502: + __MAX_ALG_ID = 5 + elif __cusparselt_version == 602: + __MAX_ALG_ID = 37 return True else: @@ -52,7 +52,6 @@ def is_available() -> bool: def get_max_alg_id() -> Optional[int]: - r"""Return the maximum algorithm id supported by the current version of cuSPARSELt""" if not _init(): return None return __MAX_ALG_ID diff --git a/torch/csrc/cuda/shared/cusparselt.cpp b/torch/csrc/cuda/shared/cusparselt.cpp index 02be708e913..ca020b75a70 100644 --- a/torch/csrc/cuda/shared/cusparselt.cpp +++ b/torch/csrc/cuda/shared/cusparselt.cpp @@ -1,7 +1,7 @@ #include #ifdef USE_CUSPARSELT -#include +#include namespace { @@ -9,34 +9,6 @@ size_t getVersionInt() { return CUSPARSELT_VERSION; } -std::tuple mmSearch( - const at::Tensor& compressed_A, - const at::Tensor& dense_B, - const std::optional& bias_opt, - const std::optional& alpha_opt, - const std::optional out_dtype_opt, - bool transpose_result) { - int alg_id_int = 0; - int split_k = 1; - bool split_k_one_kernel = true; - auto result = at::native::_cslt_sparse_mm_impl( - compressed_A, - dense_B, - bias_opt, - alpha_opt, - out_dtype_opt, - transpose_result, - alg_id_int, - split_k, - split_k_one_kernel, - true); - return { - (int64_t)std::get<1>(result), - (int64_t)std::get<2>(result), - (bool)std::get<3>(result), - (int64_t)std::get<4>(result)}; -} - } // namespace 
 namespace torch::cuda::shared {

@@ -45,7 +17,6 @@ void initCusparseltBindings(PyObject* module) {
   auto m = py::handle(module).cast();
   auto cusparselt = m.def_submodule("_cusparselt", "libcusparselt.so bindings");
   cusparselt.def("getVersionInt", getVersionInt);
-  cusparselt.def("mm_search", mmSearch);
 }

 } // namespace torch::cuda::shared
diff --git a/torch/sparse/_semi_structured_ops.py b/torch/sparse/_semi_structured_ops.py
index 9a5fdca947a..eb5557bf8b0 100644
--- a/torch/sparse/_semi_structured_ops.py
+++ b/torch/sparse/_semi_structured_ops.py
@@ -103,8 +103,6 @@ def semi_sparse_detach(func, types, args, kwargs) -> torch.Tensor:
         packed_t=self.packed_t,
         meta_t=self.meta_t,
         compressed_swizzled_bitmask=self.compressed_swizzled_bitmask,
-        fuse_transpose_cusparselt=self.fuse_transpose_cusparselt,
-        alg_id_cusparselt=self.alg_id_cusparselt,
         requires_grad=False,
     )

@@ -179,37 +177,19 @@ def semi_sparse_scaled_mm(func, types, args=(), kwargs=None) -> torch.Tensor:
     assert A.dtype == torch.float8_e4m3fn
     assert B.dtype == torch.float8_e4m3fn

-    # cuSPARSELt lacks the A and B operand scaling support, so instead we use alpha to scale the result.
-    # Note that this limits us to per-tensor scalig only.
+    # only cuSPARSELt supports float8_e4m3fn currently
+    assert isinstance(A, torch.sparse.SparseSemiStructuredTensorCUSPARSELT)
+    assert A.packed is not None
+    # Currently we only support per-tensor scaling, with float32 scales
     assert A_scale.numel() == 1 and B_scale.numel() == 1
     assert A_scale.dtype == torch.float32 and B_scale.dtype == torch.float32
-    # only cuSPARSELt supports float8_e4m3fn currentl
-    if isinstance(A, torch.sparse.SparseSemiStructuredTensorCUSPARSELT):
-        assert A.packed is not None
-        row, col = B.shape
-        B_padded = A._pad_dense_input(B).contiguous().t()
-        sparse_result = torch._cslt_sparse_mm(
-            A.packed,
-            B_padded,
-            alpha=A_scale * B_scale,
-            out_dtype=out_dtype,
-            bias=bias,
-        )
-        return sparse_result[:, :col]
-    else:
-        assert isinstance(B, torch.sparse.SparseSemiStructuredTensor)
-        assert B.packed is not None
-        row, col = A.shape
-        A_padded = B._pad_dense_input(A)
-        sparse_result = torch._cslt_sparse_mm(
-            B.packed,
-            A_padded.t(),
-            alpha=A_scale * B_scale,
-            out_dtype=out_dtype,
-            bias=bias,
-            transpose_result=B.fuse_transpose_cusparselt,
-        )
-        sparse_result = (
-            sparse_result if B.fuse_transpose_cusparselt else sparse_result.t()
-        )
-        return sparse_result[:row, :]
+
+    # cuSPARSELt lacks the A and B operand scaling support, so instead we use alpha to scale the result.
+    # Note that this limits us to per-tensor scaling only.
+    sparse_result = torch._cslt_sparse_mm(
+        A.packed,
+        B,
+        alpha=A_scale * B_scale,
+        out_dtype=out_dtype,
+    )
+    return sparse_result
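
Usage sketch (illustrative only, mirroring the updated test_cslt_sparse_mm_alg_id test in this patch): after this change, torch._cslt_sparse_mm_search returns just an alg_id, which is the only tuning knob torch._cslt_sparse_mm still accepts; the split_k and split_k_one_kernel arguments and the torch._C._cusparselt.mm_search binding are gone. The shapes and the [1, 0] fill pattern below are arbitrary, and running it assumes a CUDA build with cuSPARSELt available.

import torch

# Build a 2:4-sparse fp16 matrix: every group of four elements in a row has two
# non-zeros (a simple [1, 0, 1, 0, ...] pattern), which is what _cslt_compress expects.
A = torch.tensor([1.0, 0.0] * 64, dtype=torch.float16, device="cuda").repeat(256, 1)  # (256, 128)
B = torch.ones(128, 128, dtype=torch.float16, device="cuda")

A_compressed = torch._cslt_compress(A)

# Search now returns only the best alg_id for this problem size...
alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
# ...which is passed straight back to the matmul.
sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id)

dense_result = torch.mm(A, B)
assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)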