mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[sparse] semi-structured sparse refactor (#117302)
Summary: This PR is a refactor of semi-structured sparsity support. **deprecation**: Before, `torch.sparse.to_sparse_semi_structured` had a kwarg param `transposed=False`, which has been removed. This kwarg was unused and now throws a deprecation warning. Namely, I've taken the subclassing implementation that xFormers has created and brought it over to PyTorch, as part of our plan to upstream runtime 2:4 sparsity. I've also copied over all the op support that Daniel implemented that did not depend on the fast sparsification routines, into `_semi_structured_ops.py`. With this subclass, all of our internal tests pass, as well as those in xFormers. The main change is that we now define a base subclass, `SparseSemiStructuredTensor` that is inherited from for each of the specific backends. We also now can arbitrarily override the sparse dispatch table with `_load_dispatch_table()`, idea being this is still general enough where users don't need to modify pytorch source code to get their model working. This also adds in padding support and stores alg_id and fuse_transpose as flags on the tensor, instead of hardcoding them. There still remain two components in xFormers that will need to be ported over eventually: - the autograd functions (`Sparsify24`, `Sparsify24_like`) - fast sparsification routines that they rely on Test Plan: Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/117302 Approved by: https://github.com/alexsamardzic, https://github.com/HDCharles
This commit is contained in:
parent
2536c5186e
commit
16369816a2
|
|
@ -6,9 +6,10 @@ import unittest
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
from torch.sparse.semi_structured import (
|
||||
_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG,
|
||||
from torch.sparse import (
|
||||
SparseSemiStructuredTensor,
|
||||
SparseSemiStructuredTensorCUSPARSELT,
|
||||
SparseSemiStructuredTensorCUTLASS,
|
||||
to_sparse_semi_structured,
|
||||
)
|
||||
|
||||
|
|
@ -36,7 +37,7 @@ from torch.utils._triton import has_triton
|
|||
CUSPARSELT_NUM_ALG_IDS = 4
|
||||
CUSPARSELT_MIXED_DTYPE_SUPPORT = [torch.float16, torch.bfloat16, torch.int32]
|
||||
|
||||
SEMI_STRUCTURED_SUPPORTED_DTYPES = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG.keys()
|
||||
SEMI_STRUCTURED_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16, torch.float32, torch.int8]
|
||||
SEMI_STRUCTURED_SUPPORTED_BACKENDS = []
|
||||
|
||||
_IS_SM8X = False
|
||||
|
|
@ -315,7 +316,7 @@ class TestSparseSemiStructured(TestCase):
|
|||
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
r"arg0: SparseSemiStructuredTensor\(.*transposed=True",
|
||||
r"`SparseSemiStructuredTensor.*` matmul: operation is not supported",
|
||||
):
|
||||
torch.mm(A_sparse.t(), B)
|
||||
|
||||
|
|
@ -357,7 +358,7 @@ class TestSparseSemiStructured(TestCase):
|
|||
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
r"arg1: SparseSemiStructuredTensor\(.*transposed=False",
|
||||
r"`SparseSemiStructuredTensor.*` matmul: operation is not supported",
|
||||
):
|
||||
sparse_result = torch.mm(A, B_sparse)
|
||||
|
||||
|
|
@ -438,7 +439,10 @@ class TestSparseSemiStructured(TestCase):
|
|||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
def test_min_sparse_shape(self, dtype, device, backend):
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
config = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[dtype]
|
||||
if backend == "cutlass":
|
||||
config = SparseSemiStructuredTensorCUTLASS._DTYPE_SHAPE_CONSTRAINTS[dtype]
|
||||
elif backend == "cusparselt":
|
||||
config = SparseSemiStructuredTensorCUSPARSELT._DTYPE_SHAPE_CONSTRAINTS[dtype]
|
||||
A = rand_sparse_semi_structured_mask(config.sparse_min_rows, config.sparse_min_cols, dtype=dtype, device=device)
|
||||
A_sparse = to_sparse_semi_structured(A)
|
||||
B = torch.rand((config.sparse_min_cols, config.dense_min_cols), device=device).to(dtype)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,12 @@ from torch._C import _add_docstr, _sparse # type: ignore[attr-defined]
|
|||
from torch import Tensor
|
||||
|
||||
# Semi structured sparsity support
|
||||
from .semi_structured import SparseSemiStructuredTensor, to_sparse_semi_structured
|
||||
from .semi_structured import (
|
||||
SparseSemiStructuredTensor,
|
||||
SparseSemiStructuredTensorCUSPARSELT,
|
||||
SparseSemiStructuredTensorCUTLASS,
|
||||
to_sparse_semi_structured
|
||||
)
|
||||
|
||||
# A workaround to support both TorchScript and MyPy:
|
||||
from typing import TYPE_CHECKING
|
||||
|
|
@ -27,6 +32,8 @@ __all__ = [
|
|||
'softmax',
|
||||
'log_softmax',
|
||||
'SparseSemiStructuredTensor',
|
||||
'SparseSemiStructuredTensorCUTLASS',
|
||||
'SparseSemiStructuredTensorCUSPARSELT',
|
||||
'to_sparse_semi_structured',
|
||||
'as_sparse_gradcheck',
|
||||
]
|
||||
|
|
|
|||
166
torch/sparse/_semi_structured_ops.py
Normal file
166
torch/sparse/_semi_structured_ops.py
Normal file
|
|
@ -0,0 +1,166 @@
|
|||
import contextlib
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = [
|
||||
"fallback_dispatcher",
|
||||
"semi_sparse_values",
|
||||
"semi_sparse_indices",
|
||||
"semi_sparse_t",
|
||||
"semi_sparse_view",
|
||||
"semi_sparse_detach",
|
||||
"semi_sparse_mm",
|
||||
"semi_sparse_addmm",
|
||||
"semi_sparse_linear",
|
||||
]
|
||||
|
||||
|
||||
@contextlib.contextmanager
def no_dispatch():
    """Context manager that temporarily disables ``__torch_dispatch__``.

    A ``torch._C._DisableTorchDispatch`` guard object is kept alive for the
    duration of the ``with`` block and dropped on exit, restoring normal
    subclass dispatch.
    """
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        del guard


def fallback_dispatcher(func, types, args, kwargs):
    """Re-invoke ``func`` eagerly with torch dispatch disabled.

    Used as the dispatch-table entry for ops that need no sparse-specific
    handling (e.g. ``aten.is_same_size``, ``aten.detach_``): the op simply
    runs on the underlying tensors with subclass dispatch turned off.

    Args:
        func: The op being dispatched.
        types: Subclass types involved in the call (unused).
        args: Positional arguments for ``func``.
        kwargs: Keyword arguments for ``func``; may be ``None``.

    Returns:
        Whatever ``func`` returns.
    """
    with no_dispatch():
        # Forward kwargs too — the previous version dropped them silently,
        # which would mis-handle any op dispatched with keyword arguments.
        return func(*args, **(kwargs or {}))
|
||||
|
||||
|
||||
def semi_sparse_values(func, types, args=(), kwargs=None) -> torch.Tensor:
    """Handler for ``aten.values``: return the kept (specified) elements.

    When the metadata is packed into the same flat buffer as the values
    (``meta is None``), the first ``m * k // 2`` entries of ``packed`` are the
    kept values; otherwise ``packed`` already holds just the values and is
    returned detached.
    """
    assert len(args) == 1
    sparse = args[0]
    assert isinstance(sparse, torch.sparse.SparseSemiStructuredTensor)
    assert sparse.packed is not None
    if sparse.meta is not None:
        return sparse.packed.detach()
    rows, cols = sparse.shape
    kept = rows * cols // 2
    return sparse.packed[:kept].view(rows, -1)
|
||||
|
||||
|
||||
def semi_sparse_indices(func, types, args=(), kwargs=None) -> torch.Tensor:
    """Handler for ``aten.indices``: return the sparsity metadata.

    When metadata is stored separately (``meta`` is set) it is returned
    directly.  Otherwise it is sliced out of the tail of the flat ``packed``
    buffer and reinterpreted as int32 when the tensor dtype is int32, or as
    int16 for every other dtype.
    """
    assert len(args) == 1
    sparse = args[0]
    assert isinstance(sparse, torch.sparse.SparseSemiStructuredTensor)
    assert sparse.packed is not None
    if sparse.meta is not None:
        return sparse.meta
    rows, cols = sparse.shape
    kept = rows * cols // 2
    index_dtype = torch.int32 if sparse.dtype == torch.int32 else torch.int16
    return sparse.packed[kept:].view(rows, -1).view(index_dtype)
|
||||
|
||||
|
||||
def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor:
    """Handler for ``aten.t``: transpose a 2d semi-structured sparse tensor.

    The compressed representation cannot currently be expanded back to dense
    and transposed, so transposition is implemented by swapping the
    ``(packed, meta)`` pair with the pre-computed ``(packed_t, meta_t)`` pair
    and flipping the shape; the thread masks, when present, are transposed
    instead of being swapped.
    """
    assert len(args) == 1
    sparse = args[0]
    assert isinstance(sparse, torch.sparse.SparseSemiStructuredTensor)
    assert len(sparse.shape) == 2
    masks = sparse.threads_masks
    if masks is not None:
        masks = masks.transpose(0, 1)
    flipped_shape = torch.Size([sparse.shape[-1], sparse.shape[0]])
    return sparse.__class__(
        flipped_shape,
        packed=sparse.packed_t,
        meta=sparse.meta_t,
        packed_t=sparse.packed,
        meta_t=sparse.meta,
        threads_masks=masks,
        fuse_transpose_cusparselt=sparse.fuse_transpose_cusparselt,
        alg_id_cusparselt=sparse.alg_id_cusparselt,
    )
|
||||
|
||||
|
||||
def semi_sparse_view(func, types, args=(), kwargs=None) -> torch.Tensor:
    """Handler for ``aten.view``: only the shape-preserving view is supported.

    A view that keeps the current shape returns the tensor itself; any real
    reshape of the compressed representation is unimplemented.

    Raises:
        NotImplementedError: If ``shape`` differs from the current shape.
    """
    assert len(args) == 2
    tensor, shape = args
    if tuple(shape) == tensor.shape:
        return tensor
    raise NotImplementedError(
        f"`view` is not implemented for SparseSemiStructuredTensor, except for the dummy case (shape={shape})"
    )
|
||||
|
||||
|
||||
def semi_sparse_detach(func, types, args, kwargs) -> torch.Tensor:
    """Handler for ``aten.detach``: rewrap the same buffers without autograd.

    Dispatch runs below autograd, so detaching reduces to constructing a
    fresh wrapper around the identical compressed tensors with
    ``requires_grad=False``.
    """
    assert len(args) == 1
    sparse = args[0]
    cls = sparse.__class__
    # NOTE(review): fuse_transpose_cusparselt / alg_id_cusparselt are not
    # propagated here and fall back to the constructor defaults — confirm
    # this is intended.
    return cls(
        shape=sparse.shape,
        packed=sparse.packed,
        meta=sparse.meta,
        packed_t=sparse.packed_t,
        meta_t=sparse.meta_t,
        threads_masks=sparse.threads_masks,
        requires_grad=False,
    )
|
||||
|
||||
|
||||
def semi_sparse_mm(func, types, args=(), kwargs=None) -> torch.Tensor:
    """Handler for ``aten.mm`` / ``aten.matmul`` with one sparse operand.

    The dense operand is padded to the backend's minimum shape before the
    backend ``_mm`` kernel runs, and the padding is sliced off the result.
    When the dense operand comes first the product is computed as
    ``(B^T @ A^T)^T`` so the sparse matrix is always the kernel's left
    operand.

    Raises:
        NotImplementedError: If either operand is not 2d (no broadcasting).
    """
    assert len(args) == 2
    lhs, rhs = args
    if lhs.ndim != 2 or rhs.ndim != 2:
        raise NotImplementedError(
            "`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented"
        )
    if isinstance(lhs, torch.sparse.SparseSemiStructuredTensor):
        out_cols = rhs.shape[1]
        padded = lhs._pad_dense_input(rhs)
        return lhs._mm(padded)[:, :out_cols]
    rhs_t = rhs.t()
    assert isinstance(rhs_t, torch.sparse.SparseSemiStructuredTensor)
    out_rows = lhs.shape[0]
    padded = rhs._pad_dense_input(lhs)
    return rhs_t._mm(padded.t()).t()[:out_rows, :]
|
||||
|
||||
|
||||
def semi_sparse_addmm(func, types, args=(), kwargs=None) -> torch.Tensor:
    """Handler for ``aten.addmm``: ``bias + A @ B`` with sparse ``B`` only.

    Computed as ``(B^T @ A^T + bias)^T`` so the sparse operand feeds the
    backend ``_mm`` kernel on the left; the dense ``A`` is padded first and
    the padding rows are sliced off the result.

    Raises:
        NotImplementedError: If an operand is not 2d, the bias is not 1d, or
            the sparse operand is ``A`` rather than ``B``.
    """
    assert len(args) == 3
    bias, dense, sparse_rhs = args
    if dense.ndim != 2 or sparse_rhs.ndim != 2:
        raise NotImplementedError(
            "`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented"
        )
    if bias.ndim != 1:
        raise NotImplementedError(
            f"`SparseSemiStructuredTensor` matmul: only bias dim=1 supported. Shape={bias.shape}"
        )
    if isinstance(dense, torch.sparse.SparseSemiStructuredTensor):
        raise NotImplementedError(
            "`SparseSemiStructuredTensor` matmul: only operand B of `addmm` can be sparse"
        )
    rhs_t = sparse_rhs.t()
    assert isinstance(rhs_t, torch.sparse.SparseSemiStructuredTensor)
    out_rows = dense.shape[0]
    padded = rhs_t._pad_dense_input(dense)
    return rhs_t._mm(padded.t(), bias=bias).t()[:out_rows, :]
|
||||
|
||||
|
||||
def semi_sparse_linear(func, types, args=(), kwargs=None) -> torch.Tensor:
    """Handler for ``aten.linear``: ``input @ weight^T (+ bias)``.

    The input is flattened to 2d over its leading dimensions, multiplied with
    the transposed weight (routing through ``semi_sparse_addmm`` when a bias
    is present), then reshaped back to the original leading dimensions.
    """
    assert len(args) in [2, 3]
    inputs, weight = args[:2]
    bias = args[2] if len(args) == 3 else None

    lead_shape = inputs.shape
    flat_inputs = inputs.view(-1, lead_shape[-1])

    if bias is None:
        out = flat_inputs @ weight.t()
    else:
        out = semi_sparse_addmm(
            func=None,
            types=None,
            args=[bias, flat_inputs, weight.t()],
        )

    return out.view(*lead_shape[:-1], -1)
|
||||
|
|
@ -1,93 +1,112 @@
|
|||
import warnings
|
||||
from collections import namedtuple
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Tuple, List, Callable, Dict
|
||||
|
||||
import torch
|
||||
from torch.sparse._semi_structured_conversions import (
|
||||
sparse_semi_structured_from_dense_cutlass,
|
||||
sparse_semi_structured_to_dense_cutlass,
|
||||
)
|
||||
from torch.sparse._semi_structured_ops import (
|
||||
fallback_dispatcher,
|
||||
semi_sparse_values,
|
||||
semi_sparse_indices,
|
||||
semi_sparse_detach,
|
||||
semi_sparse_t,
|
||||
semi_sparse_view,
|
||||
semi_sparse_mm,
|
||||
semi_sparse_addmm,
|
||||
semi_sparse_linear,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SparseSemiStructuredTensor",
|
||||
"SparseSemiStructuredTensorCUTLASS",
|
||||
"SparseSemiStructuredTensorCUSPARSELT",
|
||||
"to_sparse_semi_structured",
|
||||
]
|
||||
|
||||
_SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple(
|
||||
"_SEMI_STRUCTURED_SPARSE_CONFIG", "sparse_min_rows sparse_min_cols dense_min_rows dense_min_cols"
|
||||
"_SEMI_STRUCTURED_SPARSE_CONFIG",
|
||||
"sparse_min_rows sparse_min_cols dense_min_rows dense_min_cols",
|
||||
)
|
||||
_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG = {
|
||||
# torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16) for CUTLASS, cuSPASRELt has a 32 x 32 min sparse shape
|
||||
torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 128, 16, 16),
|
||||
torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
|
||||
torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
|
||||
torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 4, 4)
|
||||
}
|
||||
|
||||
|
||||
class SparseSemiStructuredTensor(torch.Tensor):
|
||||
"""This class implementes semi-structured sparsity as a Tensor subclass.
|
||||
"""
|
||||
This class implementes semi-structured sparsity as a Tensor subclass.
|
||||
|
||||
Semi-structured sparsity describes a sparsity pattern where n in every 2n elements are sparse,
|
||||
depending on the datatype. It is also referred to as 2:4 sparsity or fine-grained
|
||||
structured sparsity.
|
||||
|
||||
Currently, this class supports 2:4 sparsity for int8, float16 and bfloat16 dtypes.
|
||||
We also support 1:2 sparsity for float32 dtype.
|
||||
There are two backends available for semi_structred sparsity, either cuSPARSELt or CUTLASS.
|
||||
This class is meant to serve as a base class for both implementations. SparseSemiStructuredCUTLASS
|
||||
and SparseSemiStructuredCUSPARSELT both inherit from this class and define three backend-specific items.
|
||||
Note that as such, this class cannot be insantiated directly.
|
||||
|
||||
This subclass stores the dense tensor in a compressed form by only storing the specified elements and corresponding metadata.
|
||||
|
||||
The subclass supports two backend, either CUTLASS or cuSPASRELt.
|
||||
|
||||
The cuSPARSELt backend expects the specified elements and the metadata to be stored in a single tensor:
|
||||
|
||||
compressed tensor = [ specified elements of original tensor | metadata ]
|
||||
|
||||
For an original tensor of size (m, k) we expect the first m * k // 2 elements to be the kept elements
|
||||
The rest of the tensor is metadata.
|
||||
|
||||
For CUTLASS backend, elements of original tensor and metadata are kept in separate tensors.
|
||||
|
||||
When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_linear
|
||||
and sparse_semi_structured_from_dense for conversion to the compressed format.
|
||||
|
||||
When PyTorch is compiled with cuSPARSELt support, this subclass will call into _cslt_sparse_mm for sparse mm and
|
||||
_cslt_compress to convert into the compressed format.
|
||||
-`_DTYPE_SHAPE_CONSTRAINTS` - A dictionary holding backend specific dense/sparse min shape constraints
|
||||
- `def from_dense()` - backend specific compression routines
|
||||
- `def _mm()` - backend specifc mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_linear)
|
||||
"""
|
||||
|
||||
_FUSE_TRANSPOSE = False
|
||||
_FORCE_CUTLASS = True
|
||||
_PROTOTYPE_WARNING_SHOWN = False
|
||||
_DEFAULT_ALG_ID: int = 0
|
||||
_DTYPE_SHAPE_CONSTRAINTS: Dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG]
|
||||
_FORCE_CUTLASS: bool = True
|
||||
_FUSE_TRANSPOSE: bool = False
|
||||
_PROTOTYPE_WARNING_SHOWN: bool = False
|
||||
|
||||
SPARSE_DISPATCH: Dict[Callable, Callable]
|
||||
|
||||
packed: Optional[torch.Tensor]
|
||||
meta: Optional[torch.Tensor]
|
||||
packed_t: Optional[torch.Tensor]
|
||||
meta_t: Optional[torch.Tensor]
|
||||
threads_masks: Optional[torch.Tensor]
|
||||
fuse_transpose_cusparselt: bool
|
||||
alg_id_cusparselt: int
|
||||
|
||||
__slots__ = ["packed", "meta", "packed_t", "meta_t", "threads_masks"]
|
||||
|
||||
@staticmethod
|
||||
def __new__(
|
||||
def __new__( # noqa: PYI034
|
||||
cls,
|
||||
original_tensor: Optional[torch.Tensor],
|
||||
original_shape: Optional[torch.Size] = None,
|
||||
compressed_tensor_cusparselt: Optional[torch.Tensor] = None,
|
||||
sparse_tensor_cutlass: Optional[torch.Tensor] = None,
|
||||
meta_tensor_cutlass: Optional[torch.Tensor] = None,
|
||||
transposed: bool = False,
|
||||
shape: torch.Size,
|
||||
packed: Optional[torch.Tensor],
|
||||
meta: Optional[torch.Tensor],
|
||||
packed_t: Optional[torch.Tensor],
|
||||
meta_t: Optional[torch.Tensor],
|
||||
threads_masks: Optional[torch.Tensor],
|
||||
fuse_transpose_cusparselt: bool = False,
|
||||
alg_id_cusparselt: int = 0,
|
||||
requires_grad: bool = False,
|
||||
):
|
||||
"""
|
||||
Create a new instance of the class.
|
||||
Create a new instance of the tensor subclass from the compressed sparse representation.
|
||||
|
||||
When original_tensor is passed in, we compress it and store the compresed representation.
|
||||
We can also create new instance of the class from the compressed representation without the original tensor.
|
||||
We have the option to create the subclass with the compressed representations of both X and X', for training.
|
||||
For inference, we only need a single representation (either X or X'), while the corresponding other set will be None.
|
||||
|
||||
Depending on the backend selected, certain fields will be set to None. (CUSPARSELT vs CUTLASS)
|
||||
|
||||
Args:
|
||||
original_tensor: The original dense tensor, or None, if we have already compressed the tensor.
|
||||
original_shape: The shape of the original dense tensor
|
||||
compressed_tensor_cusparselt: For cuSPARSELt backend, a flattened tensor to store the specified elements and metadata.
|
||||
sparse_tensor_cutlass: For CUTLASS backend, tensor to store the speficied elements.
|
||||
meta_tensor_cutlass: For CUTLASS backend, tensor to store metadata.
|
||||
transposed: Whether the tensor is transposed or not.
|
||||
shape: The shape of the original dense tensor
|
||||
packed: The compressed representation of the original dense tensor
|
||||
meta: The metadata of the original dense tensor, if it is stored separately
|
||||
packed_t: The compressed representation of the transposed original dense tensor
|
||||
meta_t: The metadata of the transposed original dense tensor, if it is stored separately
|
||||
threads_masks: The masks used by the CUTLASS backend to determine which threads should participate in the computation.
|
||||
Used for pointwise ops.
|
||||
fuse_transpose_cusparselt: When running with cuSPARSELt, we have the option to fuse a transposition
|
||||
with a matmul, which is useful in the case of 2:4 sparse training.
|
||||
alg_id_cusparselt: The algorithm id to use when using cuSPARSELT, will have effect on performance
|
||||
|
||||
Returns:
|
||||
torch.Tensor: A torch.Tensor wrapper subclass.
|
||||
|
||||
Raises:
|
||||
ValueError: If all of the tensor arguments are None.
|
||||
|
||||
"""
|
||||
assert compressed_tensor_cusparselt is None or (sparse_tensor_cutlass is None and meta_tensor_cutlass is None)
|
||||
|
||||
if not cls._PROTOTYPE_WARNING_SHOWN:
|
||||
warnings.warn(
|
||||
(
|
||||
|
|
@ -100,387 +119,189 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
|||
)
|
||||
cls._PROTOTYPE_WARNING_SHOWN = True
|
||||
|
||||
if original_tensor is not None:
|
||||
previous_tensor = original_tensor
|
||||
original_shape = original_tensor.shape
|
||||
elif compressed_tensor_cusparselt is not None:
|
||||
previous_tensor = compressed_tensor_cusparselt
|
||||
elif sparse_tensor_cutlass is not None:
|
||||
previous_tensor = sparse_tensor_cutlass
|
||||
# Because this only runs onces, we also load the dispatch table here as well.
|
||||
# We can't define the dispatch table explicitly because of torch.ops import errors, so we do this instead
|
||||
# But this is useful since it allows users to overload the dispatch table for debugging / testing.
|
||||
cls._load_dispatch_table()
|
||||
|
||||
if packed is not None:
|
||||
previous_tensor = packed
|
||||
elif packed_t is not None:
|
||||
previous_tensor = packed_t
|
||||
else:
|
||||
raise ValueError("All of the tensor arguments are None!")
|
||||
raise ValueError("At least one of packed or packed_t must be provided")
|
||||
|
||||
kwargs = {}
|
||||
kwargs["device"] = previous_tensor.device # type: ignore[assignment]
|
||||
kwargs["dtype"] = previous_tensor.dtype # type: ignore[assignment]
|
||||
kwargs["layout"] = previous_tensor.layout # type: ignore[assignment]
|
||||
kwargs["requires_grad"] = False # type: ignore[assignment]
|
||||
kwargs = {
|
||||
"device": previous_tensor.device,
|
||||
"dtype": previous_tensor.dtype,
|
||||
"layout": previous_tensor.layout,
|
||||
"requires_grad": requires_grad,
|
||||
}
|
||||
tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
|
||||
|
||||
return torch.Tensor._make_wrapper_subclass(cls, original_shape, **kwargs) # type: ignore[attr-defined]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
original_tensor: Optional[torch.Tensor],
|
||||
original_shape: Optional[torch.Size] = None,
|
||||
compressed_tensor_cusparselt: Optional[torch.Tensor] = None,
|
||||
sparse_tensor_cutlass: Optional[torch.Tensor] = None,
|
||||
meta_tensor_cutlass: Optional[torch.Tensor] = None,
|
||||
transposed: bool = False,
|
||||
) -> None:
|
||||
"""SparseSemiStructuredTensor constructor.
|
||||
|
||||
Args:
|
||||
original_tensor: The original dense tensor, or None, if we have already compressed the tensor.
|
||||
original_shape: The shape of the original dense tensor
|
||||
compressed_tensor_cusparselt: For cuSPARSELt backend, a flattened tensor to store the specified elements and metadata.
|
||||
sparse_tensor_cutlass: For CUTLASS backend, tensor to store the speficied elements.
|
||||
meta_tensor_cutlass: For CUTLASS backend, tensor to store metadata.
|
||||
transposed: Whether the tensor is transposed or not.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
RuntimeError: If original_tensor is not a supported dtype, dim, shape, or device.
|
||||
"""
|
||||
# if original tensor is passed in, we need to compress it and store the compressed representation.
|
||||
if original_tensor is not None:
|
||||
# TODO right now we have unified checks and constraints for cuSPARSELt and CUTLASS, these are not actually the same.
|
||||
# We should consolidate similar checks here and leave backend specific checks like shape in the op implementation.
|
||||
|
||||
# check device
|
||||
if not original_tensor.is_cuda:
|
||||
raise RuntimeError(
|
||||
f"Error original_tensor.device= {original_tensor.device} is not supported! "
|
||||
"Only CUDA tensors are currently supported."
|
||||
)
|
||||
|
||||
# check dim
|
||||
if original_tensor.dim() != 2:
|
||||
raise RuntimeError(
|
||||
f"Error original_tensor.dim = {original_tensor.dim()} is not supported! "
|
||||
"Only 2d tensors are currently supported."
|
||||
)
|
||||
|
||||
# check dtype
|
||||
if original_tensor.dtype not in _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG:
|
||||
raise RuntimeError(
|
||||
f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! "
|
||||
"dtype must be one of: {_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG}"
|
||||
)
|
||||
|
||||
# check shape
|
||||
m, n = original_tensor.shape
|
||||
min_rows = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[
|
||||
original_tensor.dtype
|
||||
].sparse_min_rows
|
||||
min_cols = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[
|
||||
original_tensor.dtype
|
||||
].sparse_min_cols
|
||||
if m < min_rows or m % min_rows or n < min_cols or n % min_cols:
|
||||
# TODO in the future we can add in padding to support sparse dimensions that aren't perfect multiples
|
||||
raise RuntimeError(
|
||||
f"Error original_tensor.shape {original_tensor.shape} is not supported! "
|
||||
f"Both dimensions must be larger or equal than and a multiple of ({min_rows}, {min_cols})"
|
||||
)
|
||||
|
||||
compressed_tensor_cusparselt = None
|
||||
sparse_tensor_cutlass = None
|
||||
meta_tensor_cutlass = None
|
||||
if self._FORCE_CUTLASS:
|
||||
from torch.sparse._semi_structured_conversions import (
|
||||
sparse_semi_structured_from_dense_cutlass,
|
||||
)
|
||||
|
||||
sparse_tensor_cutlass, meta_tensor_cutlass = sparse_semi_structured_from_dense_cutlass(original_tensor)
|
||||
else:
|
||||
# use cuSPARSELt
|
||||
compressed_tensor_cusparselt = torch._cslt_compress(original_tensor)
|
||||
|
||||
# set values
|
||||
self.original_tensor = None
|
||||
self.compressed_tensor_cusparselt = compressed_tensor_cusparselt
|
||||
self.sparse_tensor_cutlass = sparse_tensor_cutlass
|
||||
self.meta_tensor_cutlass = meta_tensor_cutlass
|
||||
self.transposed = transposed
|
||||
self.original_shape = original_shape
|
||||
|
||||
def __tensor_flatten__(self):
|
||||
if self.compressed_tensor_cusparselt is None:
|
||||
return ['sparse_tensor_cutlass', 'meta_tensor_cutlass'], (self.original_shape, self.transposed)
|
||||
else:
|
||||
return ['compressed_tensor_cusparselt'], (self.original_shape, self.transposed)
|
||||
|
||||
@staticmethod
|
||||
def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
|
||||
original_shape, transposed = meta
|
||||
|
||||
if len(inner_tensors) == 2:
|
||||
sparse_tensor_cutlass = inner_tensors['sparse_tensor_cutlass']
|
||||
meta_tensor_cutlass = inner_tensors['meta_tensor_cutlass']
|
||||
compressed_tensor_cusparselt = None
|
||||
elif len(inner_tensors) == 1:
|
||||
sparse_tensor_cutlass = None
|
||||
meta_tensor_cutlass = None
|
||||
compressed_tensor_cusparselt = inner_tensors['compressed_tensor_cusparselt']
|
||||
else:
|
||||
raise RuntimeError(f"Expected 1 or 2 inner tensors but got {len(inner_tensors)}")
|
||||
|
||||
return SparseSemiStructuredTensor(
|
||||
None,
|
||||
original_shape=original_shape,
|
||||
compressed_tensor_cusparselt=compressed_tensor_cusparselt,
|
||||
sparse_tensor_cutlass=sparse_tensor_cutlass,
|
||||
meta_tensor_cutlass=meta_tensor_cutlass,
|
||||
transposed=transposed,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def __get_indices_dtype(values_dtype):
|
||||
if values_dtype == torch.int8:
|
||||
return torch.int32
|
||||
elif values_dtype in (torch.float16, torch.bfloat16, torch.float32):
|
||||
return torch.int16
|
||||
else:
|
||||
raise RuntimeError(f"Datatype {values_dtype} is not supported!")
|
||||
return None
|
||||
tensor.packed = packed
|
||||
tensor.meta = meta
|
||||
tensor.packed_t = packed_t
|
||||
tensor.meta_t = meta_t
|
||||
tensor.threads_masks = threads_masks
|
||||
tensor.fuse_transpose_cusparselt = fuse_transpose_cusparselt
|
||||
tensor.alg_id_cusparselt = alg_id_cusparselt
|
||||
return tensor
|
||||
|
||||
def __repr__(self) -> str: # type: ignore[override]
|
||||
"""Return string representation of SparseSemiStructuredTensor
|
||||
assert hasattr(self, "shape")
|
||||
return f"{self.__class__.__name__}(shape={self.shape})"
|
||||
|
||||
Returns:
|
||||
str: String representation
|
||||
def __tensor_flatten__(
|
||||
self,
|
||||
) -> Tuple[List[str], Tuple[torch.Size, bool, int, bool]]:
|
||||
inner_tensors = list(
|
||||
filter(lambda x: getattr(self, x) is not None, self.__slots__)
|
||||
)
|
||||
tensor_meta = (
|
||||
self.shape,
|
||||
self.fuse_transpose_cusparselt,
|
||||
self.alg_id_cusparselt,
|
||||
self.requires_grad,
|
||||
)
|
||||
return inner_tensors, tensor_meta
|
||||
|
||||
Raises:
|
||||
None
|
||||
"""
|
||||
return (
|
||||
f"SparseSemiStructuredTensor(shape={self.shape}, "
|
||||
f"transposed={self.transposed}"
|
||||
f"values={self.values()}"
|
||||
f"metadata={self.indices()})"
|
||||
@classmethod
|
||||
def __tensor_unflatten__(
|
||||
cls,
|
||||
inner_tensors,
|
||||
tensor_meta : Tuple[torch.Size, bool, int, bool],
|
||||
outer_size,
|
||||
outer_stride,
|
||||
) -> torch.Tensor:
|
||||
shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta
|
||||
return cls(
|
||||
shape=shape,
|
||||
packed=inner_tensors.get("packed", None),
|
||||
meta=inner_tensors.get("meta", None),
|
||||
packed_t=inner_tensors.get("packed_t", None),
|
||||
meta_t=inner_tensors.get("meta_t", None),
|
||||
threads_masks=inner_tensors.get("threads_masks", None),
|
||||
fuse_transpose_cusparselt=fuse_transpose_cusparselt,
|
||||
alg_id_cusparselt=alg_id_cusparselt,
|
||||
requires_grad=requires_grad,
|
||||
)
|
||||
|
||||
__torch_function__ = torch._C._disabled_torch_function_impl
|
||||
|
||||
def _pad_tensor_for_matmul(self, original_tensor : torch.Tensor) -> torch.Tensor:
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
|
||||
if func._overloadpacket not in cls.SPARSE_DISPATCH:
|
||||
raise NotImplementedError(
|
||||
f"{cls.__name__} only supports a specific set of operations, "
|
||||
f"can't perform requested op ({func.__name__})"
|
||||
)
|
||||
return cls.SPARSE_DISPATCH[func._overloadpacket](func, types, args, kwargs)
|
||||
|
||||
@classmethod
|
||||
def _load_dispatch_table(cls, custom_dispatch_table=None) -> None:
|
||||
"""
|
||||
Loads the op overload sparse dispatch table for the current class.
|
||||
"""
|
||||
if getattr(cls, "SPARSE_DISPATCH", None) is None:
|
||||
cls.SPARSE_DISPATCH = {
|
||||
torch.ops.aten.values: semi_sparse_values,
|
||||
torch.ops.aten.indices: semi_sparse_indices,
|
||||
torch.ops.aten.is_same_size: fallback_dispatcher,
|
||||
torch.ops.aten.detach_: fallback_dispatcher,
|
||||
torch.ops.aten.detach: semi_sparse_detach,
|
||||
torch.ops.aten.t: semi_sparse_t,
|
||||
torch.ops.aten.view: semi_sparse_view,
|
||||
torch.ops.aten.mm: semi_sparse_mm,
|
||||
torch.ops.aten.matmul: semi_sparse_mm,
|
||||
torch.ops.aten.addmm: semi_sparse_addmm,
|
||||
torch.ops.aten.linear: semi_sparse_linear,
|
||||
}
|
||||
if custom_dispatch_table is not None:
|
||||
cls.SPARSE_DISPATCH.update(custom_dispatch_table)
|
||||
|
||||
@classmethod
|
||||
def _validate_device_dim_dtype_shape(cls, original_tensor : torch.Tensor) -> None:
|
||||
"""
|
||||
Assert that the given tensor is valid for semi-structured sparse compression.
|
||||
"""
|
||||
# check device
|
||||
if not original_tensor.is_cuda:
|
||||
raise RuntimeError(
|
||||
f"Error original_tensor.device= {original_tensor.device} is not supported! "
|
||||
"Only CUDA tensors are currently supported."
|
||||
)
|
||||
|
||||
# check dim
|
||||
if original_tensor.dim() != 2:
|
||||
raise RuntimeError(
|
||||
f"Error original_tensor.dim = {original_tensor.dim()} is not supported! "
|
||||
"Only 2d tensors are currently supported."
|
||||
)
|
||||
|
||||
# check contiguous
|
||||
if not original_tensor.is_contiguous():
|
||||
raise RuntimeError(
|
||||
"Error original_tensor is not contiguous!"
|
||||
"Only contiguous tensors are currently supported."
|
||||
)
|
||||
|
||||
# check dtype
|
||||
if original_tensor.dtype not in cls._DTYPE_SHAPE_CONSTRAINTS:
|
||||
raise RuntimeError(
|
||||
f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! "
|
||||
"dtype must be one of: {cls._DTYPE_SHAPE_CONSTRAINTS}"
|
||||
)
|
||||
|
||||
# check shape
|
||||
m, n = original_tensor.shape
|
||||
min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_rows
|
||||
min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_cols
|
||||
if m < min_rows or m % min_rows or n < min_cols or n % min_cols:
|
||||
# TODO in the future we can add in padding to support sparse dimensions that aren't perfect multiples
|
||||
raise RuntimeError(
|
||||
f"Error original_tensor.shape {original_tensor.shape} is not supported! "
|
||||
f"Both dimensions must be larger or equal than and a multiple of ({min_rows}, {min_cols})"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _pad_dense_input(cls, dense_input: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculates padding for dense tensor and pads tensor if necessary.
|
||||
If padding is not required, this function returns the original tensor.
|
||||
"""
|
||||
# only 2d matmul
|
||||
assert original_tensor.dim() == 2
|
||||
assert dense_input.dim() == 2
|
||||
|
||||
# check shape
|
||||
m, n = original_tensor.shape
|
||||
min_rows = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[original_tensor.dtype].dense_min_rows
|
||||
min_cols = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[original_tensor.dtype].dense_min_cols
|
||||
m, n = dense_input.shape
|
||||
min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_rows
|
||||
min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_cols
|
||||
|
||||
# calculate padding
|
||||
to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0
|
||||
to_pad_n = -n % min_cols if n < min_cols or n % min_rows else 0
|
||||
if to_pad_m or to_pad_n:
|
||||
return torch.nn.functional.pad(original_tensor, (0, to_pad_n, 0, to_pad_m))
|
||||
return torch.nn.functional.pad(dense_input, (0, to_pad_n, 0, to_pad_m))
|
||||
else:
|
||||
return original_tensor
|
||||
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
|
||||
"""Overload __torch_dispatch__ to use torch._sparse_semi_structured_linear.
|
||||
|
||||
`torch.structured_sparse_linear` uses accelerated sparse CUTLASS kernels.
|
||||
In the future we plan to also add in support for cuSPARSELt kernels.
|
||||
|
||||
Args:
|
||||
func: The function being dispatched.
|
||||
types: The types of the arguments.
|
||||
args: The arguments passed to the function.
|
||||
kwargs: The keyword arguments passed to the function.
|
||||
|
||||
Returns:
|
||||
Any: The result of the dispatched operation.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the dispatched operation is not implemented.
|
||||
"""
|
||||
# Since this code runs below autograd, a detach corresponds to only returning a new object
|
||||
if func is torch.ops.aten.detach.default:
|
||||
return SparseSemiStructuredTensor(
|
||||
args[0].original_tensor,
|
||||
original_shape=args[0].shape,
|
||||
compressed_tensor_cusparselt=args[0].compressed_tensor_cusparselt,
|
||||
sparse_tensor_cutlass=args[0].sparse_tensor_cutlass,
|
||||
meta_tensor_cutlass=args[0].meta_tensor_cutlass,
|
||||
transposed=args[0].transposed,
|
||||
)
|
||||
|
||||
# Because we cannot go from the compressed representation back to the dense representation currently,
|
||||
# we just keep track of how many times we have been transposed. Depending on whether the sparse matrix
|
||||
# is the first or second argument, we expect an even / odd number of calls to transpose respectively.
|
||||
if func is torch.ops.aten.t.default:
|
||||
return SparseSemiStructuredTensor(
|
||||
args[0].original_tensor,
|
||||
# transpose shape
|
||||
original_shape=torch.Size([args[0].shape[1], args[0].shape[0]]),
|
||||
compressed_tensor_cusparselt=args[0].compressed_tensor_cusparselt,
|
||||
sparse_tensor_cutlass=args[0].sparse_tensor_cutlass,
|
||||
meta_tensor_cutlass=args[0].meta_tensor_cutlass,
|
||||
transposed=not args[0].transposed,
|
||||
)
|
||||
|
||||
# handle addmm
|
||||
if func is torch.ops.aten.addmm.default:
|
||||
bias, input_A, input_B = args
|
||||
|
||||
# Currently, we only support the first matrix being sparse for addmm/mm in cuSPARSELT and CUTLASS.
|
||||
# CUTLASS only supports the first input to be sparse for a given matmul.
|
||||
# cuSPARSELt does not have this limitation, although our implementation is only for sparse first.
|
||||
|
||||
# We support second matrix sparse matmul by taking advantage of some transpose properties:
|
||||
# This is also why we want an odd number of transposed for second matrix sparse vs an even number
|
||||
# of transpose calss for first matrix sparse.
|
||||
# F.linear(x) = addmm(bias, input, weight.t()) = b + xW' = (b + xW')''
|
||||
# = (W''x' + b')' = (Wx' + b')' = addmm(bias.T, weight, input).T
|
||||
if isinstance(input_B, cls) and input_B.transposed:
|
||||
row, col = input_A.shape
|
||||
input_A_padded = input_B._pad_tensor_for_matmul(input_A)
|
||||
|
||||
if input_B.compressed_tensor_cusparselt is None:
|
||||
assert input_B.sparse_tensor_cutlass is not None and input_B.meta_tensor_cutlass is not None
|
||||
res = torch._sparse_semi_structured_linear(
|
||||
input_A_padded,
|
||||
input_B.sparse_tensor_cutlass,
|
||||
input_B.meta_tensor_cutlass,
|
||||
bias=bias
|
||||
)
|
||||
else:
|
||||
res = torch._cslt_sparse_mm(
|
||||
input_B.compressed_tensor_cusparselt,
|
||||
input_A_padded.t(),
|
||||
bias=bias, # type: ignore[arg-type]
|
||||
transpose_result=cls._FUSE_TRANSPOSE,
|
||||
)
|
||||
res = res if cls._FUSE_TRANSPOSE else res.t()
|
||||
return res[:row, :]
|
||||
|
||||
# handle mm
|
||||
if func is torch.ops.aten.mm.default:
|
||||
input_A, input_B = args
|
||||
|
||||
# first element sparse
|
||||
if isinstance(input_A, cls) and not input_A.transposed:
|
||||
row, col = input_B.shape
|
||||
input_B_padded = input_A._pad_tensor_for_matmul(input_B)
|
||||
|
||||
if input_A.compressed_tensor_cusparselt is None:
|
||||
assert input_A.sparse_tensor_cutlass is not None and input_A.meta_tensor_cutlass is not None
|
||||
res = torch._sparse_semi_structured_linear(
|
||||
input_B_padded.t(),
|
||||
input_A.sparse_tensor_cutlass,
|
||||
input_A.meta_tensor_cutlass
|
||||
).t()
|
||||
else:
|
||||
res = torch._cslt_sparse_mm(
|
||||
input_A.compressed_tensor_cusparselt,
|
||||
input_B_padded,
|
||||
bias=None, # type: ignore[arg-type]
|
||||
)
|
||||
return res[:, :col]
|
||||
|
||||
# second element sparse
|
||||
elif isinstance(input_B, cls) and input_B.transposed:
|
||||
row, col = input_A.shape
|
||||
input_A_padded = input_B._pad_tensor_for_matmul(input_A)
|
||||
|
||||
if input_B.compressed_tensor_cusparselt is None:
|
||||
assert input_B.sparse_tensor_cutlass is not None and input_B.meta_tensor_cutlass is not None
|
||||
res = torch._sparse_semi_structured_linear(
|
||||
input_A_padded,
|
||||
input_B.sparse_tensor_cutlass,
|
||||
input_B.meta_tensor_cutlass,
|
||||
)
|
||||
else:
|
||||
res = torch._cslt_sparse_mm(
|
||||
input_B.compressed_tensor_cusparselt,
|
||||
input_A_padded.t(),
|
||||
bias=None, # type: ignore[arg-type]
|
||||
transpose_result=cls._FUSE_TRANSPOSE,
|
||||
)
|
||||
res = res if cls._FUSE_TRANSPOSE else res.t()
|
||||
return res[:row, :]
|
||||
|
||||
# When torch is run with inference mode, pytorch does not decompose torch.ops.aten.linear into a .t() and addmm(),
|
||||
# so we must match the aten.linear op. In this case, we need to explicitly handle collapsing to 2d matmul
|
||||
# TODO see if there's a way to force pytorch to decompose the op so we don't have to handle this here.
|
||||
if func is torch.ops.aten.linear.default:
|
||||
input_tensor, weight, bias = args
|
||||
# squash input_tensor to 2d
|
||||
shape = input_tensor.shape
|
||||
input_tensor_2d = input_tensor.view(-1, shape[-1])
|
||||
row, col = input_tensor_2d.shape
|
||||
# this is a noop if already padded
|
||||
input_tensor_2d_padded = weight._pad_tensor_for_matmul(input_tensor_2d)
|
||||
|
||||
if isinstance(weight, cls):
|
||||
if weight.compressed_tensor_cusparselt is None:
|
||||
assert weight.sparse_tensor_cutlass is not None and weight.meta_tensor_cutlass is not None
|
||||
res = torch._sparse_semi_structured_linear(
|
||||
input_tensor_2d_padded,
|
||||
weight.sparse_tensor_cutlass,
|
||||
weight.meta_tensor_cutlass,
|
||||
bias=bias
|
||||
)
|
||||
else:
|
||||
res = torch._cslt_sparse_mm(
|
||||
weight.compressed_tensor_cusparselt, # type: ignore[arg-type]
|
||||
input_tensor_2d_padded.t(),
|
||||
bias=bias,
|
||||
transpose_result=cls._FUSE_TRANSPOSE
|
||||
)
|
||||
res = res if cls._FUSE_TRANSPOSE else res.t()
|
||||
return res[:row, :].view(*shape[:-1], -1)
|
||||
|
||||
|
||||
# handle values
|
||||
if func is torch.ops.aten.values.default:
|
||||
if args[0].compressed_tensor_cusparselt is None:
|
||||
return args[0].sparse_tensor_cutlass.detach()
|
||||
else:
|
||||
m, k = args[0].shape
|
||||
num_kept_elements = m * k // 2
|
||||
return args[0].compressed_tensor_cusparselt[:num_kept_elements].view(m, k // 2)
|
||||
|
||||
# handle indices
|
||||
if func is torch.ops.aten.indices.default:
|
||||
if args[0].compressed_tensor_cusparselt is None:
|
||||
return args[0].meta_tensor_cutlass
|
||||
else:
|
||||
m, k = args[0].shape
|
||||
num_kept_elements = m * k // 2
|
||||
metadata = args[0].compressed_tensor_cusparselt[num_kept_elements:].view(m, -1)
|
||||
indices_dtype = SparseSemiStructuredTensor.__get_indices_dtype(
|
||||
args[0].dtype
|
||||
)
|
||||
return metadata.view(indices_dtype)
|
||||
|
||||
error_string = "\n".join(
|
||||
[f"func {func} with args: "]
|
||||
+ [f"arg{i}: {arg}" for i, arg in enumerate(args)]
|
||||
)
|
||||
raise NotImplementedError(error_string)
|
||||
|
||||
return dense_input
|
||||
|
||||
def to_dense(self):
|
||||
if self.compressed_tensor_cusparselt is not None:
|
||||
raise RuntimeError("Converting to dense is not yet supported by cuSPARSELt backend!")
|
||||
col = self.shape[-1]
|
||||
return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device))
|
||||
|
||||
from torch.sparse._semi_structured_conversions import (
|
||||
sparse_semi_structured_to_dense_cutlass,
|
||||
)
|
||||
@classmethod
|
||||
def from_dense(cls, original_tensor : torch.Tensor) -> "SparseSemiStructuredTensor":
|
||||
raise NotImplementedError
|
||||
|
||||
return sparse_semi_structured_to_dense_cutlass(
|
||||
self.sparse_tensor_cutlass,
|
||||
self.meta_tensor_cutlass,
|
||||
)
|
||||
def _mm(
|
||||
self,
|
||||
B: torch.Tensor,
|
||||
*,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def to_sparse_semi_structured(
|
||||
|
|
@ -493,18 +314,14 @@ def to_sparse_semi_structured(
|
|||
|
||||
This function will check to ensure the dense tensor has the right dtype, size, dims, and device.
|
||||
We currently only support semi-structured sparse tensors for 2d CUDA tensors.
|
||||
Additionally, your tensor must be a positive multiple of a block size given the dtype
|
||||
|
||||
- torch.float16 (r, c) must be >= and a multiple of 64
|
||||
- torch.int8 (r, c) must be >= and a multiple of 128
|
||||
Additionally, your tensor must be a positive multiple of the mininum sparse block size, given in
|
||||
`_DTYPE_TO_SHAPE_CONSTRAINTS` for each dtype (float32, float16, bfloat16, int8).
|
||||
|
||||
Args:
|
||||
original_tensor (Tensor): the dense tensor to convert
|
||||
transposed (bool, optional): whether the dense tensor is transposed
|
||||
|
||||
transposed (bool, optional): deprecated arg to be removed in another release. Do not use.
|
||||
Returns:
|
||||
SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor
|
||||
|
||||
Raises:
|
||||
None
|
||||
Example:
|
||||
|
|
@ -518,22 +335,184 @@ def to_sparse_semi_structured(
|
|||
[0., 0., 1., ..., 0., 1., 1.],
|
||||
[0., 0., 1., ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)
|
||||
>>> A_sparse = to_sparse_semi_structured(A)
|
||||
SparseSemiStructuredTensor(shape=torch.Size([128, 128]), transposed=False, values=tensor([[1., 1., 1., ..., 1., 1., 1.],
|
||||
SparseSemiStructuredTensor(shape=torch.Size([128, 128]))
|
||||
>>> A_sparse.values()
|
||||
tensor([[1., 1., 1., ..., 1., 1., 1.],
|
||||
[1., 1., 1., ..., 1., 1., 1.],
|
||||
[1., 1., 1., ..., 1., 1., 1.],
|
||||
...,
|
||||
[1., 1., 1., ..., 1., 1., 1.],
|
||||
[1., 1., 1., ..., 1., 1., 1.],
|
||||
[1., 1., 1., ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16),
|
||||
metadata=tensor([[-4370, -4370, -4370, ..., -4370, -4370, -4370],
|
||||
>>> A_sparse.indices()
|
||||
tensor([[-4370, -4370, -4370, ..., -4370, -4370, -4370],
|
||||
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
|
||||
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
|
||||
...,
|
||||
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
|
||||
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
|
||||
[-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0',
|
||||
dtype=torch.int16))
|
||||
[-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0', dtype=torch.int16))
|
||||
"""
|
||||
return SparseSemiStructuredTensor(
|
||||
original_tensor, original_shape=original_tensor.shape, transposed=transposed
|
||||
if transposed:
|
||||
raise DeprecationWarning(
|
||||
"Setting transpose from to_sparse_semi_structured is deprecated and will be removed in a future release."
|
||||
"SparseSemiStructuredTensor only support contiguous input tensors. "
|
||||
)
|
||||
|
||||
sparse_subclass = (
|
||||
torch.sparse.SparseSemiStructuredTensorCUTLASS
|
||||
if SparseSemiStructuredTensor._FORCE_CUTLASS
|
||||
else torch.sparse.SparseSemiStructuredTensorCUSPARSELT
|
||||
)
|
||||
return sparse_subclass.from_dense(original_tensor)
|
||||
|
||||
|
||||
class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
|
||||
"""
|
||||
This class implements semi-structured sparsity for the CUTLASS backend.
|
||||
|
||||
In this implementation, the specified elements and metadata are stored seprately,
|
||||
in packed and meta respectively.
|
||||
|
||||
When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_linear
|
||||
and sparse_semi_structured_from_dense for conversion to the compressed format.
|
||||
"""
|
||||
|
||||
_DTYPE_SHAPE_CONSTRAINTS = {
|
||||
torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16),
|
||||
torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
|
||||
torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
|
||||
torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 4, 4),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dense(
|
||||
cls, original_tensor: torch.Tensor
|
||||
) -> "SparseSemiStructuredTensorCUTLASS":
|
||||
cls._validate_device_dim_dtype_shape(original_tensor)
|
||||
(
|
||||
sparse_tensor_cutlass,
|
||||
meta_tensor_cutlass,
|
||||
) = sparse_semi_structured_from_dense_cutlass(original_tensor)
|
||||
return cls(
|
||||
original_tensor.shape,
|
||||
packed=sparse_tensor_cutlass,
|
||||
meta=meta_tensor_cutlass,
|
||||
packed_t=None,
|
||||
meta_t=None,
|
||||
threads_masks=None,
|
||||
requires_grad=original_tensor.requires_grad,
|
||||
)
|
||||
|
||||
def to_dense(self):
|
||||
assert self.meta is not None and self.packed is not None
|
||||
return (
|
||||
sparse_semi_structured_to_dense_cutlass(
|
||||
self.packed,
|
||||
self.meta,
|
||||
)
|
||||
if self.meta.ndim == 2
|
||||
else super().to_dense()
|
||||
)
|
||||
|
||||
def _mm(
|
||||
self,
|
||||
B: torch.Tensor,
|
||||
*,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
**kwargs
|
||||
) -> torch.Tensor:
|
||||
if isinstance(B, SparseSemiStructuredTensor):
|
||||
raise ValueError(
|
||||
"`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
|
||||
)
|
||||
cls_name = self.__class__.__name__
|
||||
if self.ndim != 2 or B.ndim != 2:
|
||||
raise NotImplementedError(
|
||||
f"`{cls_name}` matmul: Broadcasting is not implemented"
|
||||
)
|
||||
if self.packed is None or self.meta is None:
|
||||
raise NotImplementedError(
|
||||
f"`{cls_name}` matmul: operation is not supported"
|
||||
)
|
||||
else:
|
||||
res = torch._sparse_semi_structured_linear(
|
||||
B.t(), self.packed, self.meta, bias=bias
|
||||
).t()
|
||||
return res[: self.shape[0]]
|
||||
|
||||
|
||||
class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
|
||||
"""
|
||||
The cuSPARSELt backend expects the specified elements and the metadata to be stored in a single tensor:
|
||||
packed = [ specified elements of original tensor | metadata ]
|
||||
For an original tensor of size (m, k) we expect the first m * k // 2 elements to be the kept elements
|
||||
The rest of the tensor is metadata. Since there is only one tensor, we only use the packed and packed_t
|
||||
attributes respectively.
|
||||
|
||||
cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well
|
||||
as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes.
|
||||
"""
|
||||
|
||||
_DTYPE_SHAPE_CONSTRAINTS = {
|
||||
torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
|
||||
torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
|
||||
torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
|
||||
torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(8, 8, 4, 4),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dense(cls, original_tensor : torch.Tensor) -> "SparseSemiStructuredTensorCUSPARSELT":
|
||||
cls._validate_device_dim_dtype_shape(original_tensor)
|
||||
return cls(
|
||||
shape=original_tensor.shape,
|
||||
packed=torch._cslt_compress(original_tensor),
|
||||
meta=None,
|
||||
packed_t=None,
|
||||
meta_t=None,
|
||||
threads_masks=None,
|
||||
fuse_transpose_cusparselt=SparseSemiStructuredTensor._FUSE_TRANSPOSE,
|
||||
alg_id_cusparselt=SparseSemiStructuredTensor._DEFAULT_ALG_ID,
|
||||
requires_grad=original_tensor.requires_grad,
|
||||
)
|
||||
|
||||
def _mm(
|
||||
self,
|
||||
B: torch.Tensor,
|
||||
*,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
**kwargs
|
||||
) -> torch.Tensor:
|
||||
if isinstance(B, SparseSemiStructuredTensor):
|
||||
raise ValueError(
|
||||
"`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
|
||||
)
|
||||
if self.ndim != 2 or B.ndim != 2:
|
||||
raise NotImplementedError(
|
||||
f"`{self.__class__.__name__}` matmul: Broadcasting is not implemented"
|
||||
)
|
||||
if B.dtype != self.dtype:
|
||||
raise NotImplementedError(
|
||||
f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, "
|
||||
f"with A.dtype={self.dtype} and B.dtype={B.dtype}. "
|
||||
"This operation is only supported when A and B have the same data type."
|
||||
)
|
||||
if bias is not None and bias.dtype != self.dtype:
|
||||
raise NotImplementedError(
|
||||
f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)} + C`, "
|
||||
"with A.dtype=B.dtype={self.dtype} and C.dtype={B.dtype}. "
|
||||
"This operation is only supported when A, B and C have the same data type."
|
||||
)
|
||||
if self.packed is None:
|
||||
raise NotImplementedError(
|
||||
f"`{self.__class__.__name__}` matmul: operation is not supported"
|
||||
)
|
||||
else:
|
||||
res = torch._cslt_sparse_mm(
|
||||
self.packed,
|
||||
B,
|
||||
bias=bias,
|
||||
transpose_result=self.fuse_transpose_cusparselt,
|
||||
alg_id=self.alg_id_cusparselt,
|
||||
)
|
||||
return res.t() if self.fuse_transpose_cusparselt else res
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user