[sparse] add search for optimal alg_id to torch.compile (#137427)
Summary: This PR adds a lowering for `torch._cslt_sparse_mm` that searches for the optimal alg_id and caches it when running under `torch.compile`. We see speedups on both bfloat16 and float8 dtypes (benchmark screenshots):
https://github.com/user-attachments/assets/b928cd11-32a3-43e5-b209-8e4028896f0b
https://github.com/user-attachments/assets/d9edd684-a8ec-46fd-b3da-2e76dbcb7bb6

* `torch._cslt_sparse_mm_search` has been modified to return the optimal split-k parameters as well as the max alg_id.
* max_alg_id is now available in `torch.backends.cusparselt` via `torch.backends.cusparselt.get_max_alg_id()`.
* Fixed meta registrations for float8.

Test Plan: python test/test_sparse_semi_structured.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137427
Approved by: https://github.com/cpuhrsch
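For illustration, a minimal sketch of the new search flow, pieced together from the tests updated in this PR (these are private, underscore-prefixed APIs; the 2:4 mask construction below is an assumption made for the example, and a CUDA GPU with cuSPARSELt support is required):

import torch

# Build a 2:4-sparse fp16 matrix (two non-zeros in every group of four) and compress it.
# The tiled [0, 0, 1, 1] pattern is just one convenient way to satisfy the 2:4 constraint.
A = torch.tensor([[0, 0, 1, 1]], dtype=torch.float16, device="cuda").tile(256, 32)
A_compressed = torch._cslt_compress(A)
B = torch.ones(128, 128, dtype=torch.float16, device="cuda")

# mm_search returns the best alg_id, the split-k parameters, and the max alg_id.
alg_id, split_k, split_k_one_kernel, max_alg_id = torch._C._cusparselt.mm_search(
    A_compressed, B.t(), None, None, None, False
)
assert alg_id in range(torch.backends.cusparselt.get_max_alg_id())

# Reuse the searched parameters for the actual sparse matmul.
out = torch._cslt_sparse_mm(
    A_compressed,
    B.t(),
    alg_id=alg_id,
    split_k=split_k,
    split_k_one_kernel=split_k_one_kernel,
)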
This commit is contained in: parent b4cfb9c014, commit 39bfba3f56
@@ -3357,7 +3357,7 @@
  dispatch:
    CUDA: _cslt_compress

- func: _cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0) -> Tensor
- func: _cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0, int split_k=1, bool split_k_one_kernel=True) -> Tensor
  dispatch:
    CUDA: _cslt_sparse_mm
@@ -1,20 +1,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDADataType.h>
#include <ATen/cuda/CUDASparse.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Functions.h>
#include <c10/core/ScalarType.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/Half.h>
#include <cusparse.h>
#include <cstdint>
#include <ATen/native/sparse/cuda/cuSPARSELtOps.h>

#if AT_CUSPARSELT_ENABLED()

#include <cusparseLt.h>

namespace at::native {

// Ideally we would use the same DeviceThreadHandlePool mechanism as used in aten/src/ATen/cuda/CuSparseHandlePool.cpp
@@ -56,6 +43,7 @@ at::Tensor _cslt_compress(const Tensor& sparse_input)
#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
    case at::ScalarType::Float8_e4m3fn:
      type = CUDA_R_8F_E4M3;
      compression_factor = 10;
      break;
#endif
    default:
@@ -103,7 +91,7 @@ at::Tensor _cslt_compress(const Tensor& sparse_input)
  return compressed_tensor;
}

std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
std::tuple<at::Tensor, int64_t, int64_t, bool, int64_t> _cslt_sparse_mm_impl(
    const Tensor& compressed_A,
    const Tensor& dense_B,
    const std::optional<Tensor>& bias_opt,
@@ -111,6 +99,8 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
    const std::optional<c10::ScalarType> out_dtype_opt,
    bool transpose_result,
    int alg_id,
    int split_k,
    bool split_k_one_kernel,
    bool search_alg_id
)
{
@@ -169,6 +159,7 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
      output_type = CUDA_R_8F_E4M3;
      C_type = CUDA_R_16F;
      compute_type = CUSPARSE_COMPUTE_32F;
      compression_factor = 10;
      break;
#endif
    // cuSPARSELt <= v0.5.2 uses CUSPARSE_COMPUTE_TF32, CUSPARSE_COMPUTE_16F
@@ -335,10 +326,21 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
  TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSelectionInit(
      &handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT));

  // set alg_id
  // set matmul search params
  TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute(
      &handle, &alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg_id, sizeof(alg_id)));

  cusparseLtSplitKMode_t splitKMode;
  int max_alg_id;
  if (split_k != 1) {
    TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute(
        &handle, &alg_sel, CUSPARSELT_MATMUL_SPLIT_K, &split_k, sizeof(split_k)));

    splitKMode = split_k_one_kernel ? CUSPARSELT_SPLIT_K_MODE_ONE_KERNEL : CUSPARSELT_SPLIT_K_MODE_TWO_KERNELS;
    TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute(
        &handle, &alg_sel, CUSPARSELT_MATMUL_SPLIT_K_MODE, &splitKMode, sizeof(splitKMode)));
  }

  // set tensor_alpha_mode and alpha pointer for matmul
  const auto alpha_tensor = alpha_opt.has_value() ? *alpha_opt: Tensor{};
  auto alpha_ptr = &alpha;
@@ -381,9 +383,23 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
      &stream,
      1));

  // get alg_id used
  // get matmul params used
  TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgGetAttribute(
      &handle, &alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg_id, sizeof(alg_id)));

  TORCH_CUDASPARSE_CHECK( cusparseLtMatmulAlgGetAttribute(&handle, &alg_sel,
                                                          CUSPARSELT_MATMUL_SPLIT_K,
                                                          &split_k, sizeof(split_k)));

  TORCH_CUDASPARSE_CHECK( cusparseLtMatmulAlgGetAttribute(&handle, &alg_sel,
                                                          CUSPARSELT_MATMUL_SPLIT_K_MODE,
                                                          &splitKMode, sizeof(splitKMode)));

  TORCH_CUDASPARSE_CHECK( cusparseLtMatmulAlgGetAttribute(&handle, &alg_sel,
                                                          CUSPARSELT_MATMUL_ALG_CONFIG_MAX_ID,
                                                          &max_alg_id, sizeof(max_alg_id)));

  }
  else {
    // do normal matmul
@@ -411,7 +427,7 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
  // destroy plan
  TORCH_CUDASPARSE_CHECK(cusparseLtMatmulPlanDestroy(&plan));

  return {alg_id, res};
  return {res, alg_id, split_k, splitKMode == CUSPARSELT_SPLIT_K_MODE_ONE_KERNEL, max_alg_id};
}

at::Tensor _cslt_sparse_mm(

@@ -421,7 +437,9 @@ at::Tensor _cslt_sparse_mm(
    const std::optional<Tensor>& alpha_opt,
    const std::optional<c10::ScalarType> out_dtype_opt,
    bool transpose_result,
    int64_t alg_id
    int64_t alg_id,
    int64_t split_k,
    bool split_k_one_kernel
)
{
  auto result = _cslt_sparse_mm_impl(

@@ -432,8 +450,10 @@ at::Tensor _cslt_sparse_mm(
      out_dtype_opt,
      transpose_result,
      (int) alg_id,
      (int) split_k,
      split_k_one_kernel,
      false);
  return std::get<1>(result);
  return std::get<0>(result);
}

int64_t _cslt_sparse_mm_search(

@@ -445,7 +465,10 @@ int64_t _cslt_sparse_mm_search(
    bool transpose_result
)
{
  TORCH_WARN_ONCE("torch._cslt_sparse_mm_search is deprecated and will be removed in a future PyTorch release. Please use torch._C._cusparselt.mm_search instead.");
  int alg_id_int = 0;
  int split_k = 1;
  bool split_k_one_kernel= true;
  auto result = _cslt_sparse_mm_impl(
      compressed_A,
      dense_B,

@@ -454,11 +477,12 @@ int64_t _cslt_sparse_mm_search(
      out_dtype_opt,
      transpose_result,
      alg_id_int,
      split_k,
      split_k_one_kernel,
      true);
  return (int64_t) std::get<0>(result);
  return (int64_t) std::get<1>(result);
}

} // namespace at::native

#else // No cuSPARSELt support, throw error if these functions are called.

@@ -476,7 +500,9 @@ at::Tensor _cslt_sparse_mm(
    const std::optional<Tensor>& alpha_opt,
    const std::optional<c10::ScalarType> out_dtype,
    bool transpose_result,
    int64_t alg_id)
    int64_t alg_id,
    int64_t split_k,
    bool split_k_one_kernel)
{
  TORCH_CHECK(false, "cuSPARSELt not supported on your machine.");
}

aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.h (new file, 58 lines)
@@ -0,0 +1,58 @@
#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDADataType.h>
#include <ATen/cuda/CUDASparse.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Functions.h>
#include <c10/core/ScalarType.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/Half.h>
#include <cusparse.h>
#include <cstdint>

#if AT_CUSPARSELT_ENABLED()
#include <cusparseLt.h>
#endif

namespace at::native {

at::Tensor _cslt_compress(const Tensor& sparse_input);

TORCH_CUDA_CPP_API std::tuple<at::Tensor, int64_t, int64_t, bool, int64_t> _cslt_sparse_mm_impl(
    const Tensor& compressed_A,
    const Tensor& dense_B,
    const std::optional<Tensor>& bias_opt,
    const std::optional<Tensor>& alpha_opt,
    const std::optional<c10::ScalarType> out_dtype_opt,
    bool transpose_result,
    int alg_id,
    int split_k,
    bool split_k_one_kernel,
    bool search_alg_id
);

at::Tensor _cslt_sparse_mm(
    const Tensor& compressed_A,
    const Tensor& dense_B,
    const std::optional<Tensor>& bias_opt,
    const std::optional<Tensor>& alpha_opt,
    const std::optional<c10::ScalarType> out_dtype_opt,
    bool transpose_result,
    int64_t alg_id,
    int64_t split_k,
    bool split_k_one_kernel
);

int64_t _cslt_sparse_mm_search(
    const Tensor& compressed_A,
    const Tensor& dense_B,
    const std::optional<Tensor>& bias_opt,
    const std::optional<Tensor>& alpha_opt,
    const std::optional<c10::ScalarType> out_dtype_opt,
    bool transpose_result
);

} // namespace at::native
@@ -1,253 +0,0 @@
import argparse
import random

import pandas as pd
from tqdm import tqdm

import torch
import torch.utils.benchmark as benchmark
from torch import nn
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured


torch.set_printoptions(
    precision=2,
    threshold=None,
    edgeitems=16,
    linewidth=480,
    profile=None,
    sci_mode=False,
)


# helper model definition for pruner
class Model(nn.Module):
    def __init__(self, m, k, dtype=None):
        super().__init__()
        # transposed so reversed
        self.linear = nn.Linear(k, m)

    def forward(self, x):
        return self.linear(x)


def rand_sparse_semi_structured_mask(
    r, c, dtype=torch.float16, device="cuda", choice=None
):
    """
    This function returns a 1:2 sparse matrix of size (r, c).
    Note that this means this matrix will also be 2:4 and 4:8 sparse as well.
    """

    choices = [[0, 1], [1, 0]]
    mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)]

    return (
        torch.tensor(mask_entries, dtype=dtype, device=device)
        .reshape(r, c)
        .contiguous()
    )


def test_linear(m, k, n, dtype, contiguous, backend):
    SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
    mask = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
    sparse_weight = torch.rand(m, k).to(dtype).cuda() * mask
    input_tensor = torch.zeros(n, k).to(dtype).cuda()
    model = Model(m, k).to(dtype).cuda().eval()

    dense_measurement = benchmark.Timer(
        stmt="model(input_tensor)",
        globals=locals(),
    ).blocked_autorange()

    dense_output = model(input_tensor)
    print(dense_output.shape)

    # sparsify weights
    model.linear.weight = nn.Parameter(
        to_sparse_semi_structured(
            sparse_weight,
        )
    )

    sparse_output = model(input_tensor)
    print(sparse_output.shape)

    sparse_measurement = benchmark.Timer(
        stmt="model(input_tensor)",
        globals=locals(),
    ).blocked_autorange()

    correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)

    return {
        "test_function": "linear",
        "m": m,
        "k": k,
        "n": n,
        "dtype": str(dtype),
        "backend": backend,
        "sparse_latency (ms)": sparse_measurement.median * 1000,
        "dense_latency (ms)": dense_measurement.median * 1000,
        "speedup (d/s)": dense_measurement.median / sparse_measurement.median,
        "correct": correct,
        "contiguous": sparse_output.is_contiguous(),
    }


def test_tensor(m, k, n, dtype, contiguous, backend):
    A = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
    B = torch.zeros(k, n).to(dtype).cuda()
    bias = torch.rand(n).to(dtype).cuda()

    sA = to_sparse_semi_structured(A)

    # torch.mm calculation
    if dtype is not torch.int8:
        dense_output = torch.mm(A, B)

        dense_measurement = benchmark.Timer(
            stmt="torch.mm(A, B)",
            globals=locals(),
        ).blocked_autorange()

    else:
        print("int8 baseline not supported")
        dense_output = torch.mm(sA, B)

        dense_measurement = benchmark.Timer(
            stmt="torch.mm(sA, B)",
            globals=locals(),
        ).blocked_autorange()

    sparse_output = torch.mm(sA, B)
    sparse_measurement = benchmark.Timer(
        stmt="torch.mm(sA, B)",
        globals=locals(),
    ).blocked_autorange()

    correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)

    return {
        "test_function": "tensor",
        "m": m,
        "k": k,
        "n": n,
        "dtype": str(dtype),
        "backend": backend,
        "sparse_latency (ms)": sparse_measurement.median * 1000,
        "dense_latency (ms)": dense_measurement.median * 1000,
        "speedup (d/s)": dense_measurement.median / sparse_measurement.median,
        "correct": correct,
        "contiguous": sparse_output.is_contiguous(),
    }


if __name__ == "__main__":
    dtype_lookup = {
        "int8": torch.int8,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }

    parser = argparse.ArgumentParser(description="Semi-Structured Sparsity Benchmarks")
    parser.add_argument(
        "--mode",
        type=str,
        choices=[
            "nvidia-bert",
            "nvidia-fixed-k",
            "nvidia-fixed-mn",
        ],
    )
    parser.add_argument(
        "--dtype",
        type=str,
        choices=dtype_lookup.keys(),
        default="fp16",
    )
    parser.add_argument(
        "--backend", type=str, choices=["cutlass", "cusparselt"], default="cusparselt"
    )
    parser.add_argument("-contiguous", action="store_true")
    parser.add_argument("-e2e", action="store_true")
    parser.add_argument("-save", action="store_true")
    args = parser.parse_args()

    if args.e2e:
        eval_fn = test_linear
    else:
        eval_fn = test_tensor

    print(f"Started benchmark: {args.mode} | dtype: {args.dtype}")
    dtype = dtype_lookup[args.dtype]

    if args.mode == "nvidia-bert":
        bert_shapes = [
            (3072, 1024, 16384),
            (4096, 1024, 16384),
            (1024, 1024, 16384),
            (1024, 4096, 16384),
        ]
        results = (
            eval_fn(m, k, n, dtype, args.contiguous, args.backend)
            for (m, k, n) in tqdm(bert_shapes)
        )

    elif args.mode == "nvidia-fixed-k":
        mn_vals = [
            3072,
            4096,
            5120,
            6144,
            7168,
            8192,
            9216,
            10240,
            11264,
            12288,
            13312,
            14336,
            15360,
            16384,
            17408,
            18432,
            19456,
            20480,
        ]
        results = (
            eval_fn(mn, 10240, mn, dtype, args.contiguous, args.backend)
            for mn in tqdm(mn_vals)
        )

    elif args.mode == "nvidia-fixed-mn":
        k_vals = [
            2560,
            3840,
            5120,
            6400,
            7680,
            8960,
            10240,
            11520,
            12800,
            14080,
            15360,
            16640,
            17920,
            19200,
            20480,
        ]
        results = (
            eval_fn(10240, k, 10240, dtype, args.contiguous, args.backend)
            for k in tqdm(k_vals)
        )

    df = pd.DataFrame.from_records(results)
    if args.save:
        save_file = f"{args.mode}_{args.dtype}_{args.backend}.csv"
        df.to_csv(save_file)
        print(f"Finished benchmark: {args.mode} saved results to {save_file}")
    print(df)
@@ -244,18 +244,17 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
    @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
    def test_sp24_compile(self) -> None:
        x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
        e = torch.eye(x.shape[0], x.shape[0], device="cuda", dtype=torch.float16)

        def fn(x, e):
        def fn(x):
            y = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(x)
            y = y.t()
            return x @ y

        # Eager
        output = fn(x, e)
        output = fn(x)
        output.backward(output)
        # Torch compile
        output = torch.compile(fn)(x, e)
        output = torch.compile(fn)(x)
        output.backward(output)

class TestSparseSemiStructured(TestCase):
@@ -1156,8 +1155,9 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
        B = torch.ones((128, 128), device=device).to(dtype)

        A_compressed = torch._cslt_compress(A)
        alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
        sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id)
        alg_id, split_k, split_k_one_kernel, _ = torch._C._cusparselt.mm_search(A_compressed, B.t(), None, None, None, False)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(),
                                              alg_id=alg_id, split_k=split_k, split_k_one_kernel=split_k_one_kernel)

        dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
        dense_result = dense_result.to(dtype)
@@ -1174,6 +1174,16 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
        alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
        assert alg_id in range(torch.backends.cusparselt.get_max_alg_id())

    @inference_dtypes
    def test_csrc_cslt_sparse_mm_search(self, device, dtype):
        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        A_compressed = torch._cslt_compress(A)
        B = torch.ones((128, 128), device=device).to(dtype)

        A_compressed = torch._cslt_compress(A)
        alg_id, _, _, _ = torch._C._cusparselt.mm_search(A_compressed, B.t(), None, None, None, False)
        assert alg_id in range(torch.backends.cusparselt.get_max_alg_id())

    def test_cusparselt_backend(self):
        version = _get_torch_cuda_version()
        assert torch.backends.cusparselt.is_available()
@@ -1181,9 +1191,11 @@
        # CUDA 11.8 has cuSPARSELt v0.4.0 support
        if version == (11, 8):
            assert torch.backends.cusparselt.version() == 400
            assert torch.backends.cusparselt.get_max_alg_id() == 4
        # CUDA 12.1 has cuSPARSELt v0.5.2 support
        elif version == (12, 1):
            assert torch.backends.cusparselt.version() == 502
            assert torch.backends.cusparselt.get_max_alg_id() == 4
        # CUDA 12.4+ has cuSPARSELt v0.6.2 support
        elif version >= (12, 4):
            assert torch.backends.cusparselt.version() == 602
@@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import functools
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch._inductor.autoheuristic.autoheuristic import AutoHeuristicSelectAlgorithm
@@ -144,6 +144,12 @@ aten__sparse_semi_structured_mm = ExternKernelChoice(
    has_out_variant=False,
)

aten__cslt_sparse_mm = ExternKernelChoice(
    torch._cslt_sparse_mm,
    "at::_cslt_sparse_mm",
    has_out_variant=False,
)


def _is_int8_mat(mat):
    return mat.get_dtype() in (torch.int8, torch.uint8)
@@ -521,6 +527,124 @@ def tuned_sparse_semi_structured_mm(
    )


@register_lowering(aten._cslt_sparse_mm, type_promotion_kind=None)
def tuned_cslt_sparse_mm(
    mat1_compressed,
    mat2,
    bias=None,
    alpha=None,
    out_dtype=None,
    transpose_result=False,
    alg_id=0,
    split_k=1,
    split_k_one_kernel=True,
    layout=None,
):
    from torch._inductor.select_algorithm import AlgorithmSelectorCache, realize_inputs

    mat1_compressed, mat2 = realize_inputs(mat1_compressed, mat2)
    input_nodes: Tuple[Any, ...] = (mat1_compressed, mat2)
    k, n = mat2.get_size()

    is_8bit_input_type = mat1_compressed.dtype in [torch.int8, torch.float8_e4m3fn]
    compression_factor = 10 if is_8bit_input_type else 9
    m = (mat1_compressed.get_numel() * 16) // (compression_factor * k)

    from torch._inductor.ir import FixedLayout

    if transpose_result:
        layout = FixedLayout(
            mat2.get_device(),
            out_dtype if out_dtype else mat2.get_dtype(),
            [n, m],
            [m, 1],
        )
    else:
        layout = FixedLayout(
            mat2.get_device(),
            out_dtype if out_dtype else mat2.get_dtype(),
            [m, n],
            [n, 1],
        )
    # workaround for Inductor not supporting optional tensor input arguments
    if bias is not None:
        bias = realize_inputs(bias)
        input_nodes = input_nodes + (bias,)

    if alpha is not None:
        alpha = realize_inputs(alpha)
        input_nodes = input_nodes + (alpha,)

    # cuSPARSELt alg_id search, note that we cannot use
    # AlgorithmSelectorCache.benchmark_example_value() because this will return the base view
    # and mat2 needs to have transpose properties preserved for cslt mm
    (
        searched_alg_id,
        searched_split_k,
        searched_split_k_one_kernel,
        _,
    ) = torch._C._cusparselt.mm_search(  # type: ignore[attr-defined]
        AlgorithmSelectorCache.generate_example_value(
            V.graph.sizevars.size_hints(mat1_compressed.get_size()),
            V.graph.sizevars.size_hints(mat1_compressed.get_stride()),
            mat1_compressed.get_device(),
            mat1_compressed.dtype,
            mat1_compressed.layout.offset,
        ),
        AlgorithmSelectorCache.generate_example_value(
            V.graph.sizevars.size_hints(mat2.get_size()),
            V.graph.sizevars.size_hints(mat2.get_stride()),
            mat2.get_device(),
            mat2.dtype,
            mat2.layout.offset,
        ),
        AlgorithmSelectorCache.generate_example_value(
            V.graph.sizevars.size_hints(bias.get_size()),
            V.graph.sizevars.size_hints(bias.get_stride()),
            bias.get_device(),
            bias.dtype,
            bias.layout.offset,
        )
        if bias is not None
        else None,
        AlgorithmSelectorCache.generate_example_value(
            V.graph.sizevars.size_hints(alpha.get_size()),
            V.graph.sizevars.size_hints(alpha.get_stride()),
            alpha.get_device(),
            alpha.dtype,
            alpha.layout.offset,
        )
        if alpha is not None
        else None,
        out_dtype,
        transpose_result,
    )

    baseline = aten__cslt_sparse_mm.bind(
        input_nodes,
        layout,
        out_dtype=out_dtype,
        alg_id=0,
        split_k=1,
        split_k_one_kernel=True,
        transpose_result=transpose_result,
    )
    baseline.description = f"ALG_ID: 0 SPLIT_K: 1 SPLIT_K_ONE_KERNEL: True TRANSPOSE_RESULT: {transpose_result}"
    searched = aten__cslt_sparse_mm.bind(
        input_nodes,
        layout,
        out_dtype=out_dtype,
        alg_id=searched_alg_id,
        split_k=searched_split_k,
        split_k_one_kernel=searched_split_k_one_kernel,
        transpose_result=transpose_result,
    )
    searched.description = f"ALG_ID: {searched_alg_id} SPLIT_K: {searched_split_k} SPLIT_K_ONE_KERNEL: {searched_split_k_one_kernel} TRANSPOSE_RESULT: {transpose_result}"  # noqa: B950
    choices = [baseline, searched]

    return autotune_select_algorithm("cslt_sparse_mm", choices, input_nodes, layout)


def fallback_mixed_mm(mat1, mat2, *, out):
    return torch.mm(mat1, mat2.to(mat1.dtype), out=out)
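For context, a rough sketch of how this lowering would be exercised, assuming a compiled function that calls the op directly (the 2:4 input construction is illustrative and not part of the PR; whether the searched choice wins is decided by Inductor's autotuning at compile time):

import torch

# 2:4-sparse bfloat16 operand, compressed ahead of time (construction is an assumption).
A = torch.tensor([[0, 0, 1, 1]], dtype=torch.bfloat16, device="cuda").tile(256, 32)
A_compressed = torch._cslt_compress(A)
B = torch.randn(128, 128, dtype=torch.bfloat16, device="cuda")

@torch.compile
def sparse_mm(A_compressed, B):
    # Lowered through tuned_cslt_sparse_mm: Inductor benchmarks the default
    # (alg_id=0, split_k=1) against the searched parameters and caches the winner.
    return torch._cslt_sparse_mm(A_compressed, B.t())

out = sparse_mm(A_compressed, B)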
@@ -520,32 +520,44 @@ def meta__cslt_sparse_mm(
    alpha: Optional[Tensor] = None,
    out_dtype: Optional[torch.dtype] = None,
    transpose_result: bool = False,
    alg_id: int = 0,
    split_k: int = 1,
    split_k_one_kernel: bool = False,
):
    assert dense_B.dtype in {
        torch.float32,
        torch.float16,
        torch.bfloat16,
        torch.int8,
    }, "_cslt_sparse_mm only supports fp16, bf16, and int8"
        torch.float8_e4m3fn,
    }, "_cslt_sparse_mm only supports fp16, bf16, int8, and fp8e4m3"
    assert compressed_A.dtype == dense_B.dtype, "inputs must have the same dtype"
    assert len(dense_B.shape) == 2, "_cslt_sparse_mm only supports 2d inputs"

    is_int8_input_type = compressed_A.dtype == torch.int8
    compression_factor = 10 if is_int8_input_type else 9
    is_8bit_input_type = compressed_A.dtype in [torch.int8, torch.float8_e4m3fn]
    compression_factor = 10 if is_8bit_input_type else 9
    k = dense_B.size(0)
    n = dense_B.size(1)
    m = (compressed_A.numel() * 16) // (compression_factor * k)
    if bias is not None:
        assert m == bias.size(0)

    if is_8bit_input_type:
        assert not dense_B.is_contiguous()

    if out_dtype is not None:
        assert is_int8_input_type and out_dtype in {
            torch.float16,
            torch.bfloat16,
            torch.int32,
        }, "out_dtype is only supported for i8i8->fp16, bf16, or i32 matmul"
        assert (
            is_8bit_input_type
            and out_dtype
            in {
                torch.float16,
                torch.bfloat16,
                torch.int32,
                torch.float8_e4m3fn,
            }
        ), "out_dtype is not supported for {compressed_A.dtype} x {dense_B.dtype} -> {out_dtype} matmul!"
    output_shape = (n, m) if transpose_result else (m, n)
    result = dense_B.new_empty(output_shape, dtype=out_dtype)
    result = torch.empty(output_shape, dtype=out_dtype, device=compressed_A.device)
    return result
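As a quick sanity check of the shape arithmetic above (an illustration, not code from the PR): for 16-bit dtypes the compressed buffer holds roughly the kept half of the values plus sparsity metadata, which is where compression_factor = 9 (that is, 9/16 of the dense element count) comes from, so the meta function can recover m from the compressed numel:

# Hypothetical numbers, chosen to match the 256x128 fp16 case used in the tests.
m, k = 256, 128
compression_factor = 9  # 16-bit dtypes; 10 for int8 / float8_e4m3fn
compressed_numel = m * k * compression_factor // 16
assert (compressed_numel * 16) // (compression_factor * k) == m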
@@ -25,12 +25,12 @@ if _cusparselt is not None:
        global __MAX_ALG_ID
        if __cusparselt_version is None:
            __cusparselt_version = _cusparselt.getVersionInt()
            if __cusparselt_version == 400:
                __MAX_ALG_ID = 4
            elif __cusparselt_version == 502:
                __MAX_ALG_ID = 5
            elif __cusparselt_version == 602:
                __MAX_ALG_ID = 37

            # only way to get MAX_ALG_ID is to run a matmul
            A = torch.zeros(128, 128, dtype=torch.float16).cuda()
            A = torch._cslt_compress(A)
            B = torch.zeros(128, 128, dtype=torch.float16).cuda()
            _, _, _, __MAX_ALG_ID = _cusparselt.mm_search(A, B, None, None, None, False)  # type: ignore[attr-defined]
        return True

else:

@@ -52,6 +52,7 @@ def is_available() -> bool:


def get_max_alg_id() -> Optional[int]:
    r"""Return the maximum algorithm id supported by the current version of cuSPARSELt"""
    if not _init():
        return None
    return __MAX_ALG_ID
@@ -1,7 +1,7 @@
#include <torch/csrc/utils/pybind.h>

#ifdef USE_CUSPARSELT
#include <cusparseLt.h>
#include <ATen/native/sparse/cuda/cuSPARSELtOps.h>

namespace {
@@ -9,6 +9,34 @@ size_t getVersionInt() {
  return CUSPARSELT_VERSION;
}

std::tuple<int64_t, int64_t, bool, int64_t> mmSearch(
    const at::Tensor& compressed_A,
    const at::Tensor& dense_B,
    const std::optional<at::Tensor>& bias_opt,
    const std::optional<at::Tensor>& alpha_opt,
    const std::optional<c10::ScalarType> out_dtype_opt,
    bool transpose_result) {
  int alg_id_int = 0;
  int split_k = 1;
  bool split_k_one_kernel = true;
  auto result = at::native::_cslt_sparse_mm_impl(
      compressed_A,
      dense_B,
      bias_opt,
      alpha_opt,
      out_dtype_opt,
      transpose_result,
      alg_id_int,
      split_k,
      split_k_one_kernel,
      true);
  return {
      (int64_t)std::get<1>(result),
      (int64_t)std::get<2>(result),
      (bool)std::get<3>(result),
      (int64_t)std::get<4>(result)};
}

} // namespace

namespace torch::cuda::shared {
@@ -17,6 +45,7 @@ void initCusparseltBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  auto cusparselt = m.def_submodule("_cusparselt", "libcusparselt.so bindings");
  cusparselt.def("getVersionInt", getVersionInt);
  cusparselt.def("mm_search", mmSearch);
}

} // namespace torch::cuda::shared
@@ -103,6 +103,8 @@ def semi_sparse_detach(func, types, args, kwargs) -> torch.Tensor:
        packed_t=self.packed_t,
        meta_t=self.meta_t,
        compressed_swizzled_bitmask=self.compressed_swizzled_bitmask,
        fuse_transpose_cusparselt=self.fuse_transpose_cusparselt,
        alg_id_cusparselt=self.alg_id_cusparselt,
        requires_grad=False,
    )
@@ -177,19 +179,37 @@ def semi_sparse_scaled_mm(func, types, args=(), kwargs=None) -> torch.Tensor:

    assert A.dtype == torch.float8_e4m3fn
    assert B.dtype == torch.float8_e4m3fn
    # only cuSPARSELt supports float8_e4m3fn currently
    assert isinstance(A, torch.sparse.SparseSemiStructuredTensorCUSPARSELT)
    assert A.packed is not None
    # Currently we only support per-tensor scaling, with float32 scales
    assert A_scale.numel() == 1 and B_scale.numel() == 1
    assert A_scale.dtype == torch.float32 and B_scale.dtype == torch.float32

    # cuSPARSELt lacks the A and B operand scaling support, so instead we use alpha to scale the result.
    # Note that this limits us to per-tensor scaling only.
    sparse_result = torch._cslt_sparse_mm(
        A.packed,
        B,
        alpha=A_scale * B_scale,
        out_dtype=out_dtype,
    )
    return sparse_result
    assert A_scale.numel() == 1 and B_scale.numel() == 1
    assert A_scale.dtype == torch.float32 and B_scale.dtype == torch.float32
    # only cuSPARSELt supports float8_e4m3fn currently
    if isinstance(A, torch.sparse.SparseSemiStructuredTensorCUSPARSELT):
        assert A.packed is not None
        row, col = B.shape
        B_padded = A._pad_dense_input(B).contiguous().t()
        sparse_result = torch._cslt_sparse_mm(
            A.packed,
            B_padded,
            alpha=A_scale * B_scale,
            out_dtype=out_dtype,
            bias=bias,
        )
        return sparse_result[:, :col]
    else:
        assert isinstance(B, torch.sparse.SparseSemiStructuredTensor)
        assert B.packed is not None
        row, col = A.shape
        A_padded = B._pad_dense_input(A)
        sparse_result = torch._cslt_sparse_mm(
            B.packed,
            A_padded.t(),
            alpha=A_scale * B_scale,
            out_dtype=out_dtype,
            bias=bias,
            transpose_result=B.fuse_transpose_cusparselt,
        )
        sparse_result = (
            sparse_result if B.fuse_transpose_cusparselt else sparse_result.t()
        )
        return sparse_result[:row, :]