[sparse] Add padding for dense matrices in semi-structured sparse (#110583)
Summary: Currently we have shape constraints in semi-structured sparsity for both CUTLASS and cuSPARSELt. These shape constraints unfortunately apply to both the dense and sparse matrices in sparse @ dense matmul. This PR adds support for calling `F.pad` to zero-pad the dense matrix to the required size and then pull the corresponding rows out of the result matrix. We also throw a warning in this case. The tests have also been updated to take a dense_input_shape parameter.

Test Plan:
```
python test/test_sparse_semi_structured.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110583
Approved by: https://github.com/alexsamardzic, https://github.com/cpuhrsch
This commit is contained in:
parent 2b6f281e5c
commit 8db72a430d
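A minimal sketch of the padding approach described in the summary (an illustration only, not the code added by this commit: the helper name and the `min_rows`/`min_cols` defaults are placeholders, the real per-dtype minima come from `_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG`, and the padded rows/columns are sliced back off the matmul result, as the diff below does):

```python
import torch
import torch.nn.functional as F

def pad_dense_for_sparse_mm(dense: torch.Tensor, min_rows: int = 32, min_cols: int = 64) -> torch.Tensor:
    """Zero-pad a 2-D dense operand up to the backend's minimum row/column multiples.

    min_rows/min_cols are illustrative defaults; the actual minima depend on dtype and backend.
    """
    m, n = dense.shape
    to_pad_m = -m % min_rows  # rows to add so m becomes a multiple of min_rows
    to_pad_n = -n % min_cols  # columns to add so n becomes a multiple of min_cols
    # F.pad on a 2-D tensor takes (left, right, top, bottom)
    return F.pad(dense, (0, to_pad_n, 0, to_pad_m))

# A (128, 1) activation is padded to (128, 64); after the sparse matmul only the
# first column of the padded result would be handed back to the caller.
x = torch.randn(128, 1)
print(pad_dense_for_sparse_mm(x).shape)  # torch.Size([128, 64])
```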
@@ -5,7 +5,7 @@ import pandas as pd
 import torch
 import torch.utils.benchmark as benchmark
 from torch import nn
-from torch.sparse import to_sparse_semi_structured
+from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured
 from tqdm import tqdm


@@ -49,7 +49,7 @@ def rand_sparse_semi_structured_mask(


 def test_linear(m, k, n, dtype, contiguous, backend):
-    SparseSemiStructuredTensor.fuse_transpose = contiguous
+    SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
     mask = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
     sparse_weight = torch.rand(m, k).to(dtype).cuda() * mask
     input_tensor = torch.zeros(n, k).to(dtype).cuda()
@@ -61,6 +61,7 @@ def test_linear(m, k, n, dtype, contiguous, backend):
     ).blocked_autorange()

     dense_output = model(input_tensor)
+    print(dense_output.shape)

     # sparsify weights
     model.linear.weight = nn.Parameter(
@@ -70,6 +71,7 @@ def test_linear(m, k, n, dtype, contiguous, backend):
     )

     sparse_output = model(input_tensor)
+    print(sparse_output.shape)

     sparse_measurement = benchmark.Timer(
         stmt="model(input_tensor)",

@@ -43,7 +43,7 @@ if torch.cuda.is_available():
     # check if cslt is available for now using this:
     # TODO when we add cusparselt as a backend, we can update this to be use torch.cusparselt.is_available()
     try:
-        torch._cslt_compress(torch.ones(128, 128).cuda())
+        torch._cslt_compress(torch.ones(128, 256).cuda())
         SEMI_STRUCTURED_SUPPORTED_BACKENDS.append("cusparselt")
     except Exception:
         pass
@@ -127,7 +127,7 @@ class TestSparseSemiStructured(TestCase):
     def test_to_sparse_semi_structured(self, dtype, backend):
         SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")

-        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
+        A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype)
         A_sparse = to_sparse_semi_structured(A)

         assert A.shape == A_sparse.shape
@@ -139,18 +139,18 @@ class TestSparseSemiStructured(TestCase):


     @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
+    @parametrize("dense_input_shape", [(128, 1), (128, 64), (128, 128)])
     @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
-    def test_mm_sparse_first_NT(self, dtype, device, backend):
+    def test_mm_sparse_first_NN(self, dense_input_shape, dtype, device, backend):
         """
         Ensure torch.mm(A_sparse, B) is correct for float16 and will throw error for int8
-        Ensure torch.mm(A_sparse, B.t()) is correct
         """
         SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")

-        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
+        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
         A_sparse = to_sparse_semi_structured(A)

-        B = torch.rand((128, 128), device=A_sparse.device).to(dtype)
+        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

         # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
         if dtype is torch.int8:
@@ -162,7 +162,38 @@ class TestSparseSemiStructured(TestCase):
             with self.assertRaisesRegex(RuntimeError,
                                         "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
                 sparse_result = torch.mm(A_sparse, B)
         else:
             dense_result = torch.mm(A, B)
             sparse_result = torch.mm(A_sparse, B)
             assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+
+    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
+    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
+    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
+    def test_mm_sparse_first_NT(self, dense_input_shape, dtype, device, backend):
+        """
+        Ensure torch.mm(A_sparse, B.t()) is correct for float16/bfloat16
+        and will throw an error for int8 + padding
+        """
+        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
+
+        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
+        A_sparse = to_sparse_semi_structured(A)
+
+        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)
+
+        # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
+        if dtype is torch.int8 and dense_input_shape in {(1, 128), (64, 128)}:
+            # padding with int8 throws an error because transposing B yields a contiguous output
+            # and row-row 2:4 sparse @ dense with NN is not supported by cuSPARSELt or CUTLASS.
+            if backend == "cutlass":
+                with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_cutlass_dispatch_layouts"):
+                    sparse_result = torch.mm(A_sparse, B.t())
+            else:
+                with self.assertRaisesRegex(RuntimeError,
+                                            "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
+                    sparse_result = torch.mm(A_sparse, B.t())
+        elif dtype is torch.int8:
+            # test transpose
+            # NOTE: CUTLASS and cuSPARSELt have slightly different int8 behavior.
+            # CUTLASS will output to an int32 tensor while cuSPARSELt will output to a int8 tensor
@@ -170,25 +201,23 @@ class TestSparseSemiStructured(TestCase):
             sparse_result = torch.mm(A_sparse, B.t())
             assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
         else:
-            dense_result = torch.mm(A, B)
-            sparse_result = torch.mm(A_sparse, B)
-            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
             # test transpose
             dense_result = torch.mm(A, B.t())
             sparse_result = torch.mm(A_sparse, B.t())
             assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

     @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
+    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
     @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
-    def test_mm_sparse_first_T(self, dtype, device, backend):
+    def test_mm_sparse_first_TN(self, dtype, dense_input_shape, device, backend):
         """
         Ensure torch.mm(A_sparse.t(), B) throws error
         """
         SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
-        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
+        A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype)
         A_sparse = to_sparse_semi_structured(A)

-        B = torch.rand((128, 128), device=A_sparse.device).to(dtype)
+        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

         with self.assertRaisesRegex(
             NotImplementedError,
@@ -197,16 +226,17 @@ class TestSparseSemiStructured(TestCase):
             torch.mm(A_sparse.t(), B)

     @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
+    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
     @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
-    def test_mm_sparse_second_T(self, dtype, device, backend):
+    def test_mm_sparse_second_NT(self, dense_input_shape, dtype, device, backend):
         """
         Ensure torch.mm(A, B_sparse.t()) is correct
         """
         SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
-        B = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
+        B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
         B_sparse = to_sparse_semi_structured(B)

-        A = torch.rand((128, 128), device=B_sparse.device).to(dtype)
+        A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)

         # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
         if dtype is torch.int8:
@@ -218,32 +248,18 @@ class TestSparseSemiStructured(TestCase):

         assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

-    def test_cslt_sparse_mm_int8_in_fp16_out(self, device):
-        """
-        This test is only needed for cuSPARSELt
-        """
-        if "cusparselt" in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
-            SparseSemiStructuredTensor._FORCE_CUTLASS = False
-            A = rand_sparse_semi_structured_mask(128, 128, dtype=torch.int8)
-            A_sparse = to_sparse_semi_structured(A)
-
-            B = torch.rand((128, 128), device=A_sparse.device).to(torch.int8)
-
-            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.float16)
-            sparse_result = torch._cslt_sparse_mm(A_sparse.compressed_tensor_cusparselt, B.t(), out_dtype=torch.float16)
-            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
-
     @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
+    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
     @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
-    def test_mm_sparse_second_NT(self, dtype, device, backend):
+    def test_mm_sparse_second_NN(self, dense_input_shape, dtype, device, backend):
         """
         Ensure torch.mm(A, B_sparse) throws error
         """
         SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
-        B = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
+        B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
         B_sparse = to_sparse_semi_structured(B)

-        A = torch.rand((128, 128), device=B_sparse.device).to(dtype)
+        A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)

         with self.assertRaisesRegex(
             NotImplementedError,
@@ -251,15 +267,32 @@ class TestSparseSemiStructured(TestCase):
         ):
             sparse_result = torch.mm(A, B_sparse)

+    @parametrize("dense_input_shape", [(128, 128)])
+    def test_cslt_sparse_mm_int8_in_fp16_out(self, dense_input_shape, device):
+        """
+        This test is only needed for cuSPARSELt
+        """
+        if "cusparselt" in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
+            SparseSemiStructuredTensor._FORCE_CUTLASS = False
+            A = rand_sparse_semi_structured_mask(128, 256, dtype=torch.int8)
+            A_sparse = to_sparse_semi_structured(A)
+
+            B = torch.rand(dense_input_shape, device=A_sparse.device).to(torch.int8)
+
+            dense_result = torch.mm(A.cpu().to(torch.int64), B.t().cpu().to(torch.int64)).to(device, dtype=torch.float16)
+            sparse_result = torch._cslt_sparse_mm(A_sparse.compressed_tensor_cusparselt, B.t(), out_dtype=torch.float16)
+            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+
+    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
     @parametrize("inference_mode", [subtest(True), subtest(False)])
     @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
-    def test_linear(self, inference_mode, device, backend):
+    def test_linear(self, dense_input_shape, inference_mode, device, backend):
         """
         Test nn.Linear has the same numerics
         """
         SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
-        input = torch.rand(64, 128, 128, device=device).half()
-        model = nn.Linear(128, 128).to(device).half()
+        input = torch.rand((dense_input_shape), device=device).half()
+        model = nn.Linear(128, 256).to(device).half()
         m, n = model.weight.shape
         mask = rand_sparse_semi_structured_mask(m, n, device=device, dtype=torch.bool)
         # set masked weight
@@ -277,14 +310,15 @@ class TestSparseSemiStructured(TestCase):

         assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

+    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
     @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
-    def test_mlp(self, device, backend):
+    def test_mlp(self, device, dense_input_shape, backend):
         SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
-        input = torch.rand(64, 768, 768, device=device).half()
+        input = torch.rand(dense_input_shape, device=device).half()
         model = (
             nn.Sequential(
-                nn.Linear(768, 3072),
-                nn.Linear(3072, 768),
+                nn.Linear(128, 256),
+                nn.Linear(256, 128),
             )
             .half()
             .to(device)

@@ -53,7 +53,7 @@ class SparseSemiStructuredTensor(torch.Tensor):

     _FUSE_TRANSPOSE = False
     _FORCE_CUTLASS = True
-    _WARNING_SHOWN = False
+    _PROTOTYPE_WARNING_SHOWN = False

     @staticmethod
     def __new__(
@@ -88,7 +88,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
         """
         assert compressed_tensor_cusparselt is None or (sparse_tensor_cutlass is None and meta_tensor_cutlass is None)

-        if not cls._WARNING_SHOWN:
+        if not cls._PROTOTYPE_WARNING_SHOWN:
             warnings.warn(
                 (
                     "The PyTorch API of SparseSemiStructuredTensor is in prototype stage "
@@ -98,7 +98,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
                 ),
                 UserWarning,
             )
-            cls._WARNING_SHOWN = True
+            cls._PROTOTYPE_WARNING_SHOWN = True

         if original_tensor is not None:
             previous_tensor = original_tensor
@@ -232,6 +232,26 @@ class SparseSemiStructuredTensor(torch.Tensor):

     __torch_function__ = torch._C._disabled_torch_function_impl

+    def _pad_tensor_for_matmul(self, original_tensor : torch.Tensor) -> torch.Tensor:
+        """
+        Calculates padding for dense tensor and pads tensor if necessary.
+        If padding is not required, this function returns the original tensor.
+        """
+        # only 2d matmul
+        assert original_tensor.dim() == 2
+
+        # check shape
+        m, n = original_tensor.shape
+        min_rows = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[original_tensor.dtype].min_rows
+        min_cols = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[original_tensor.dtype].min_cols
+        to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0
+        to_pad_n = -n % min_cols if n < min_cols or n % min_rows else 0
+        if to_pad_m or to_pad_n:
+            return torch.nn.functional.pad(original_tensor, (0, to_pad_n, 0, to_pad_m))
+        else:
+            return original_tensor
+
+
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
         """Overload __torch_dispatch__ to use torch._sparse_semi_structured_linear.
@@ -290,38 +310,52 @@
             # F.linear(x) = addmm(bias, input, weight.t()) = b + xW' = (b + xW')''
             # = (W''x' + b')' = (Wx' + b')' = addmm(bias.T, weight, input).T
             if isinstance(input_B, cls) and input_B.transposed:
+                row, col = input_A.shape
+                input_A_padded = input_B._pad_tensor_for_matmul(input_A)
                 if input_B.compressed_tensor_cusparselt is None:
                     assert input_B.sparse_tensor_cutlass is not None and input_B.meta_tensor_cutlass is not None
-                    return torch._sparse_semi_structured_linear(
-                        input_A, input_B.sparse_tensor_cutlass, input_B.meta_tensor_cutlass, bias=bias
+                    res = torch._sparse_semi_structured_linear(
+                        input_A_padded, input_B.sparse_tensor_cutlass, input_B.meta_tensor_cutlass, bias=bias
                     )
                 else:
-                    return torch._cslt_sparse_mm(
-                        input_B.compressed_tensor_cusparselt, input_A.T, bias  # type: ignore[arg-type]
+                    res = torch._cslt_sparse_mm(
+                        input_B.compressed_tensor_cusparselt, input_A_padded.t(), bias  # type: ignore[arg-type]
                     ).t()
+                return res[:row, :]

         # handle mm
         if func is torch.ops.aten.mm.default:
             input_A, input_B = args

             # first element sparse
             if isinstance(input_A, cls) and not input_A.transposed:
+                row, col = input_B.shape
+                input_B_padded = input_A._pad_tensor_for_matmul(input_B)
                 if input_A.compressed_tensor_cusparselt is None:
                     assert input_A.sparse_tensor_cutlass is not None and input_A.meta_tensor_cutlass is not None
-                    return torch._sparse_semi_structured_linear(
-                        input_B.t(), input_A.sparse_tensor_cutlass, input_A.meta_tensor_cutlass
+                    res = torch._sparse_semi_structured_linear(
+                        input_B_padded.t(), input_A.sparse_tensor_cutlass, input_A.meta_tensor_cutlass
                     ).t()
                 else:
-                    return torch._cslt_sparse_mm(
-                        input_A.compressed_tensor_cusparselt, input_B, None  # type: ignore[arg-type]
+                    res = torch._cslt_sparse_mm(
+                        input_A.compressed_tensor_cusparselt, input_B_padded, None  # type: ignore[arg-type]
                     )
+                return res[:, :col]

             # second element sparse
             elif isinstance(input_B, cls) and input_B.transposed:
+                row, col = input_A.shape
+                input_A_padded = input_B._pad_tensor_for_matmul(input_A)
+
                 if input_B.compressed_tensor_cusparselt is None:
                     assert input_B.sparse_tensor_cutlass is not None and input_B.meta_tensor_cutlass is not None
-                    return torch._sparse_semi_structured_linear(
-                        input_A, input_B.sparse_tensor_cutlass, input_B.meta_tensor_cutlass
+                    res = torch._sparse_semi_structured_linear(
+                        input_A_padded, input_B.sparse_tensor_cutlass, input_B.meta_tensor_cutlass
                     )
                 else:
-                    return torch._cslt_sparse_mm(input_B.compressed_tensor_cusparselt, input_A.T, None).t()  # type: ignore[arg-type]
+                    res = torch._cslt_sparse_mm(input_B.compressed_tensor_cusparselt, input_A_padded.t(), None).t()  # type: ignore[arg-type]
+
+                return res[:row, :]

         # When torch is run with inference mode, pytorch does not decompose torch.ops.aten.linear into a .t() and addmm(),
         # so we must match the aten.linear op. In this case, we need to explicitly handle collapsing to 2d matmul
@@ -329,21 +363,29 @@
         if func is torch.ops.aten.linear.default:
             input_tensor, weight, bias = args
             shape = input_tensor.shape
+
+            input_tensor_2d = input_tensor.view(-1, shape[-1])
+            row, col = input_tensor_2d.shape
+            # this is a noop if already padded
+            input_tensor_2d_padded = weight._pad_tensor_for_matmul(input_tensor_2d)
+
             if isinstance(weight, cls):
                 if weight.compressed_tensor_cusparselt is None:
                     assert weight.sparse_tensor_cutlass is not None and weight.meta_tensor_cutlass is not None
-                    return torch._sparse_semi_structured_linear(
-                        input_tensor,
+                    res = torch._sparse_semi_structured_linear(
+                        input_tensor_2d_padded,
                         weight.sparse_tensor_cutlass,
                         weight.meta_tensor_cutlass,
                         bias=bias
                     )
                 else:
-                    return torch._cslt_sparse_mm(
+                    res = torch._cslt_sparse_mm(
                         weight.compressed_tensor_cusparselt,  # type: ignore[arg-type]
-                        input_tensor.view(-1, shape[-1]).t(),
+                        input_tensor_2d_padded.t(),
                         bias
-                    ).t().view(*shape[:-1], -1)
+                    ).t()
+                return res[:row, :].view(*shape[:-1], -1)

+
         # handle values
         if func is torch.ops.aten.values.default:
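As a usage-level sketch of the behavior the new tests exercise (it assumes a CUDA device with 2:4 semi-structured sparse kernels available and the prototype `to_sparse_semi_structured` API): the dense operand no longer has to satisfy the minimum-shape constraint on its own, since it is zero-padded internally and the result is sliced back to the expected shape.

```python
import torch
from torch.sparse import to_sparse_semi_structured

# 256x128 fp16 matrix with a 2:4 sparsity pattern in every row
A = torch.Tensor([0, 0, 1, 1]).tile((256, 32)).half().cuda()
A_sparse = to_sparse_semi_structured(A)

# A dense operand with only one column, well below the backend's minimum width
B = torch.rand(128, 1, dtype=torch.float16, device="cuda")

out = torch.mm(A_sparse, B)  # B is padded internally; the extra columns are sliced off
print(out.shape)             # torch.Size([256, 1])
```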