[sparse] semi-structured sparse refactor (#117302)

Summary:

This PR is a refactor of semi-structured sparsity support.

**deprecation**:

Before `torch.sparse.to_sparse_semi_structured` had a kwarg param
`transposed=False`, which has been removed. This kwarg was unused and
now thros a deprecation warning.

Namely, I've taken the subclassing implementation that xFormers has
created and brought it over to PyTorch, as part of our plan to upstream
runtime 2:4 sparsity.

I've also copied over all the op support that Daniel implemenented that
did not depend on the fast sparsification routines, into
`_sparse_semi_structured_ops.py`

With this subclass, all of our internal tests pass, as well as those in
xFormers.

The main change is that we now define a base subclass,
`SparseSemiStructuredTensor` that is inherited from for each of the
specific backends.

We also now can arbitrarily override the sparse dispatch table with
`_load_dispatch_table()`, idea being this is still general enough
where users don't need to modify pytorch source code to get their model
working.

This also adds in padding support and stores alg_id and fuse_transpose
as flags on the tensor, instead of hardcoding them.

There still remains two components in xFormers that will need to be
ported over eventually:
- the autograd functions  (`Sparsify24`, `Sparsify24_like`)
- fast sparsification routines that they rely on

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117302
Approved by: https://github.com/alexsamardzic, https://github.com/HDCharles
This commit is contained in:
Jesse Cai 2024-02-12 17:12:20 -08:00 committed by PyTorch MergeBot
parent 2536c5186e
commit 16369816a2
4 changed files with 583 additions and 427 deletions

View File

@ -6,9 +6,10 @@ import unittest
import torch
from torch import nn
from torch.sparse.semi_structured import (
_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG,
from torch.sparse import (
SparseSemiStructuredTensor,
SparseSemiStructuredTensorCUSPARSELT,
SparseSemiStructuredTensorCUTLASS,
to_sparse_semi_structured,
)
@ -36,7 +37,7 @@ from torch.utils._triton import has_triton
CUSPARSELT_NUM_ALG_IDS = 4
CUSPARSELT_MIXED_DTYPE_SUPPORT = [torch.float16, torch.bfloat16, torch.int32]
SEMI_STRUCTURED_SUPPORTED_DTYPES = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG.keys()
SEMI_STRUCTURED_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16, torch.float32, torch.int8]
SEMI_STRUCTURED_SUPPORTED_BACKENDS = []
_IS_SM8X = False
@ -315,7 +316,7 @@ class TestSparseSemiStructured(TestCase):
with self.assertRaisesRegex(
NotImplementedError,
r"arg0: SparseSemiStructuredTensor\(.*transposed=True",
r"`SparseSemiStructuredTensor.*` matmul: operation is not supported",
):
torch.mm(A_sparse.t(), B)
@ -357,7 +358,7 @@ class TestSparseSemiStructured(TestCase):
with self.assertRaisesRegex(
NotImplementedError,
r"arg1: SparseSemiStructuredTensor\(.*transposed=False",
r"`SparseSemiStructuredTensor.*` matmul: operation is not supported",
):
sparse_result = torch.mm(A, B_sparse)
@ -438,7 +439,10 @@ class TestSparseSemiStructured(TestCase):
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
def test_min_sparse_shape(self, dtype, device, backend):
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
config = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[dtype]
if backend == "cutlass":
config = SparseSemiStructuredTensorCUTLASS._DTYPE_SHAPE_CONSTRAINTS[dtype]
elif backend == "cusparselt":
config = SparseSemiStructuredTensorCUSPARSELT._DTYPE_SHAPE_CONSTRAINTS[dtype]
A = rand_sparse_semi_structured_mask(config.sparse_min_rows, config.sparse_min_cols, dtype=dtype, device=device)
A_sparse = to_sparse_semi_structured(A)
B = torch.rand((config.sparse_min_cols, config.dense_min_cols), device=device).to(dtype)

View File

@ -6,7 +6,12 @@ from torch._C import _add_docstr, _sparse # type: ignore[attr-defined]
from torch import Tensor
# Semi structured sparsity support
from .semi_structured import SparseSemiStructuredTensor, to_sparse_semi_structured
from .semi_structured import (
SparseSemiStructuredTensor,
SparseSemiStructuredTensorCUSPARSELT,
SparseSemiStructuredTensorCUTLASS,
to_sparse_semi_structured
)
# A workaround to support both TorchScript and MyPy:
from typing import TYPE_CHECKING
@ -27,6 +32,8 @@ __all__ = [
'softmax',
'log_softmax',
'SparseSemiStructuredTensor',
'SparseSemiStructuredTensorCUTLASS',
'SparseSemiStructuredTensorCUSPARSELT',
'to_sparse_semi_structured',
'as_sparse_gradcheck',
]

View File

@ -0,0 +1,166 @@
import contextlib
import torch
__all__ = [
"fallback_dispatcher",
"semi_sparse_values",
"semi_sparse_indices",
"semi_sparse_t",
"semi_sparse_view",
"semi_sparse_detach",
"semi_sparse_mm",
"semi_sparse_addmm",
"semi_sparse_linear",
]
@contextlib.contextmanager
def no_dispatch():
guard = torch._C._DisableTorchDispatch()
try:
yield
finally:
del guard
def fallback_dispatcher(func, types, args, kwargs):
with no_dispatch():
return func(*args)
def semi_sparse_values(func, types, args=(), kwargs=None) -> torch.Tensor:
assert len(args) == 1
A = args[0]
assert isinstance(A, torch.sparse.SparseSemiStructuredTensor)
assert A.packed is not None
if A.meta is None:
m, k = A.shape
num_kept_elements = m * k // 2
return A.packed[:num_kept_elements:].view(m, -1)
else:
return A.packed.detach()
def semi_sparse_indices(func, types, args=(), kwargs=None) -> torch.Tensor:
assert len(args) == 1
A = args[0]
assert isinstance(A, torch.sparse.SparseSemiStructuredTensor)
assert A.packed is not None
if A.meta is None:
m, k = A.shape
num_kept_elements = m * k // 2
metadata = A.packed[num_kept_elements:].view(m, -1)
return metadata.view(torch.int32 if A.dtype == torch.int32 else torch.int16)
else:
return A.meta
def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor:
assert len(args) == 1
self = args[0]
assert isinstance(self, torch.sparse.SparseSemiStructuredTensor)
assert len(self.shape) == 2
# Because we cannot go from the compressed representation back to the dense representation currently,
# we just keep track of how many times we have been transposed. Depending on whether the sparse matrix
# is the first or second argument, we expect an even / odd number of calls to transpose respectively.
return self.__class__(
torch.Size([self.shape[-1], self.shape[0]]),
packed=self.packed_t,
meta=self.meta_t,
packed_t=self.packed,
meta_t=self.meta,
threads_masks=self.threads_masks.transpose(0, 1)
if self.threads_masks is not None
else None,
fuse_transpose_cusparselt=args[0].fuse_transpose_cusparselt,
alg_id_cusparselt=args[0].alg_id_cusparselt,
)
def semi_sparse_view(func, types, args=(), kwargs=None) -> torch.Tensor:
assert len(args) == 2
self, shape = args
if tuple(shape) != self.shape:
raise NotImplementedError(
f"`view` is not implemented for SparseSemiStructuredTensor, except for the dummy case (shape={shape})"
)
return self
def semi_sparse_detach(func, types, args, kwargs) -> torch.Tensor:
assert len(args) == 1
self = args[0]
return self.__class__(
shape=self.shape,
packed=self.packed,
meta=self.meta,
packed_t=self.packed_t,
meta_t=self.meta_t,
threads_masks=self.threads_masks,
requires_grad=False,
)
def semi_sparse_mm(func, types, args=(), kwargs=None) -> torch.Tensor:
assert len(args) == 2
A, B = args
if A.ndim != 2 or B.ndim != 2:
raise NotImplementedError(
"`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented"
)
if isinstance(A, torch.sparse.SparseSemiStructuredTensor):
row, col = B.shape
B_padded = A._pad_dense_input(B)
res = A._mm(B_padded)
return res[:, :col]
else:
B_t = B.t()
assert isinstance(B_t, torch.sparse.SparseSemiStructuredTensor)
row, col = A.shape
A_padded = B._pad_dense_input(A)
res = B_t._mm(A_padded.t()).t()
return res[:row, :]
def semi_sparse_addmm(func, types, args=(), kwargs=None) -> torch.Tensor:
assert len(args) == 3
bias, A, B = args
if A.ndim != 2 or B.ndim != 2:
raise NotImplementedError(
"`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented"
)
if bias.ndim != 1:
raise NotImplementedError(
f"`SparseSemiStructuredTensor` matmul: only bias dim=1 supported. Shape={bias.shape}"
)
if isinstance(A, torch.sparse.SparseSemiStructuredTensor):
raise NotImplementedError(
"`SparseSemiStructuredTensor` matmul: only operand B of `addmm` can be sparse"
)
B_t = B.t()
assert isinstance(B_t, torch.sparse.SparseSemiStructuredTensor)
row, col = A.shape
A_padded = B_t._pad_dense_input(A)
result = B_t._mm(A_padded.t(), bias=bias).t()
return result[:row, :]
def semi_sparse_linear(func, types, args=(), kwargs=None) -> torch.Tensor:
assert len(args) in [2, 3]
A, B = args[:2]
bias = args[2] if len(args) == 3 else None
shape = A.shape
A_2d = A.view(-1, shape[-1])
if bias is None:
res = A_2d @ B.t()
else:
res = semi_sparse_addmm(
func=None,
types=None,
args=[bias, A_2d, B.t()],
)
return res.view(*shape[:-1], -1)

View File

@ -1,93 +1,112 @@
import warnings
from collections import namedtuple
from typing import Any, Optional
from typing import Any, Optional, Tuple, List, Callable, Dict
import torch
from torch.sparse._semi_structured_conversions import (
sparse_semi_structured_from_dense_cutlass,
sparse_semi_structured_to_dense_cutlass,
)
from torch.sparse._semi_structured_ops import (
fallback_dispatcher,
semi_sparse_values,
semi_sparse_indices,
semi_sparse_detach,
semi_sparse_t,
semi_sparse_view,
semi_sparse_mm,
semi_sparse_addmm,
semi_sparse_linear,
)
__all__ = [
"SparseSemiStructuredTensor",
"SparseSemiStructuredTensorCUTLASS",
"SparseSemiStructuredTensorCUSPARSELT",
"to_sparse_semi_structured",
]
_SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple(
"_SEMI_STRUCTURED_SPARSE_CONFIG", "sparse_min_rows sparse_min_cols dense_min_rows dense_min_cols"
"_SEMI_STRUCTURED_SPARSE_CONFIG",
"sparse_min_rows sparse_min_cols dense_min_rows dense_min_cols",
)
_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG = {
# torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16) for CUTLASS, cuSPASRELt has a 32 x 32 min sparse shape
torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 128, 16, 16),
torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 4, 4)
}
class SparseSemiStructuredTensor(torch.Tensor):
"""This class implementes semi-structured sparsity as a Tensor subclass.
"""
This class implementes semi-structured sparsity as a Tensor subclass.
Semi-structured sparsity describes a sparsity pattern where n in every 2n elements are sparse,
depending on the datatype. It is also referred to as 2:4 sparsity or fine-grained
structured sparsity.
Currently, this class supports 2:4 sparsity for int8, float16 and bfloat16 dtypes.
We also support 1:2 sparsity for float32 dtype.
There are two backends available for semi_structred sparsity, either cuSPARSELt or CUTLASS.
This class is meant to serve as a base class for both implementations. SparseSemiStructuredCUTLASS
and SparseSemiStructuredCUSPARSELT both inherit from this class and define three backend-specific items.
Note that as such, this class cannot be insantiated directly.
This subclass stores the dense tensor in a compressed form by only storing the specified elements and corresponding metadata.
The subclass supports two backend, either CUTLASS or cuSPASRELt.
The cuSPARSELt backend expects the specified elements and the metadata to be stored in a single tensor:
compressed tensor = [ specified elements of original tensor | metadata ]
For an original tensor of size (m, k) we expect the first m * k // 2 elements to be the kept elements
The rest of the tensor is metadata.
For CUTLASS backend, elements of original tensor and metadata are kept in separate tensors.
When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_linear
and sparse_semi_structured_from_dense for conversion to the compressed format.
When PyTorch is compiled with cuSPARSELt support, this subclass will call into _cslt_sparse_mm for sparse mm and
_cslt_compress to convert into the compressed format.
-`_DTYPE_SHAPE_CONSTRAINTS` - A dictionary holding backend specific dense/sparse min shape constraints
- `def from_dense()` - backend specific compression routines
- `def _mm()` - backend specifc mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_linear)
"""
_FUSE_TRANSPOSE = False
_FORCE_CUTLASS = True
_PROTOTYPE_WARNING_SHOWN = False
_DEFAULT_ALG_ID: int = 0
_DTYPE_SHAPE_CONSTRAINTS: Dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG]
_FORCE_CUTLASS: bool = True
_FUSE_TRANSPOSE: bool = False
_PROTOTYPE_WARNING_SHOWN: bool = False
SPARSE_DISPATCH: Dict[Callable, Callable]
packed: Optional[torch.Tensor]
meta: Optional[torch.Tensor]
packed_t: Optional[torch.Tensor]
meta_t: Optional[torch.Tensor]
threads_masks: Optional[torch.Tensor]
fuse_transpose_cusparselt: bool
alg_id_cusparselt: int
__slots__ = ["packed", "meta", "packed_t", "meta_t", "threads_masks"]
@staticmethod
def __new__(
def __new__( # noqa: PYI034
cls,
original_tensor: Optional[torch.Tensor],
original_shape: Optional[torch.Size] = None,
compressed_tensor_cusparselt: Optional[torch.Tensor] = None,
sparse_tensor_cutlass: Optional[torch.Tensor] = None,
meta_tensor_cutlass: Optional[torch.Tensor] = None,
transposed: bool = False,
shape: torch.Size,
packed: Optional[torch.Tensor],
meta: Optional[torch.Tensor],
packed_t: Optional[torch.Tensor],
meta_t: Optional[torch.Tensor],
threads_masks: Optional[torch.Tensor],
fuse_transpose_cusparselt: bool = False,
alg_id_cusparselt: int = 0,
requires_grad: bool = False,
):
"""
Create a new instance of the class.
Create a new instance of the tensor subclass from the compressed sparse representation.
When original_tensor is passed in, we compress it and store the compresed representation.
We can also create new instance of the class from the compressed representation without the original tensor.
We have the option to create the subclass with the compressed representations of both X and X', for training.
For inference, we only need a single representation (either X or X'), while the corresponding other set will be None.
Depending on the backend selected, certain fields will be set to None. (CUSPARSELT vs CUTLASS)
Args:
original_tensor: The original dense tensor, or None, if we have already compressed the tensor.
original_shape: The shape of the original dense tensor
compressed_tensor_cusparselt: For cuSPARSELt backend, a flattened tensor to store the specified elements and metadata.
sparse_tensor_cutlass: For CUTLASS backend, tensor to store the speficied elements.
meta_tensor_cutlass: For CUTLASS backend, tensor to store metadata.
transposed: Whether the tensor is transposed or not.
shape: The shape of the original dense tensor
packed: The compressed representation of the original dense tensor
meta: The metadata of the original dense tensor, if it is stored separately
packed_t: The compressed representation of the transposed original dense tensor
meta_t: The metadata of the transposed original dense tensor, if it is stored separately
threads_masks: The masks used by the CUTLASS backend to determine which threads should participate in the computation.
Used for pointwise ops.
fuse_transpose_cusparselt: When running with cuSPARSELt, we have the option to fuse a transposition
with a matmul, which is useful in the case of 2:4 sparse training.
alg_id_cusparselt: The algorithm id to use when using cuSPARSELT, will have effect on performance
Returns:
torch.Tensor: A torch.Tensor wrapper subclass.
Raises:
ValueError: If all of the tensor arguments are None.
"""
assert compressed_tensor_cusparselt is None or (sparse_tensor_cutlass is None and meta_tensor_cutlass is None)
if not cls._PROTOTYPE_WARNING_SHOWN:
warnings.warn(
(
@ -100,387 +119,189 @@ class SparseSemiStructuredTensor(torch.Tensor):
)
cls._PROTOTYPE_WARNING_SHOWN = True
if original_tensor is not None:
previous_tensor = original_tensor
original_shape = original_tensor.shape
elif compressed_tensor_cusparselt is not None:
previous_tensor = compressed_tensor_cusparselt
elif sparse_tensor_cutlass is not None:
previous_tensor = sparse_tensor_cutlass
# Because this only runs onces, we also load the dispatch table here as well.
# We can't define the dispatch table explicitly because of torch.ops import errors, so we do this instead
# But this is useful since it allows users to overload the dispatch table for debugging / testing.
cls._load_dispatch_table()
if packed is not None:
previous_tensor = packed
elif packed_t is not None:
previous_tensor = packed_t
else:
raise ValueError("All of the tensor arguments are None!")
raise ValueError("At least one of packed or packed_t must be provided")
kwargs = {}
kwargs["device"] = previous_tensor.device # type: ignore[assignment]
kwargs["dtype"] = previous_tensor.dtype # type: ignore[assignment]
kwargs["layout"] = previous_tensor.layout # type: ignore[assignment]
kwargs["requires_grad"] = False # type: ignore[assignment]
kwargs = {
"device": previous_tensor.device,
"dtype": previous_tensor.dtype,
"layout": previous_tensor.layout,
"requires_grad": requires_grad,
}
tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
return torch.Tensor._make_wrapper_subclass(cls, original_shape, **kwargs) # type: ignore[attr-defined]
def __init__(
self,
original_tensor: Optional[torch.Tensor],
original_shape: Optional[torch.Size] = None,
compressed_tensor_cusparselt: Optional[torch.Tensor] = None,
sparse_tensor_cutlass: Optional[torch.Tensor] = None,
meta_tensor_cutlass: Optional[torch.Tensor] = None,
transposed: bool = False,
) -> None:
"""SparseSemiStructuredTensor constructor.
Args:
original_tensor: The original dense tensor, or None, if we have already compressed the tensor.
original_shape: The shape of the original dense tensor
compressed_tensor_cusparselt: For cuSPARSELt backend, a flattened tensor to store the specified elements and metadata.
sparse_tensor_cutlass: For CUTLASS backend, tensor to store the speficied elements.
meta_tensor_cutlass: For CUTLASS backend, tensor to store metadata.
transposed: Whether the tensor is transposed or not.
Returns:
None
Raises:
RuntimeError: If original_tensor is not a supported dtype, dim, shape, or device.
"""
# if original tensor is passed in, we need to compress it and store the compressed representation.
if original_tensor is not None:
# TODO right now we have unified checks and constraints for cuSPARSELt and CUTLASS, these are not actually the same.
# We should consolidate similar checks here and leave backend specific checks like shape in the op implementation.
# check device
if not original_tensor.is_cuda:
raise RuntimeError(
f"Error original_tensor.device= {original_tensor.device} is not supported! "
"Only CUDA tensors are currently supported."
)
# check dim
if original_tensor.dim() != 2:
raise RuntimeError(
f"Error original_tensor.dim = {original_tensor.dim()} is not supported! "
"Only 2d tensors are currently supported."
)
# check dtype
if original_tensor.dtype not in _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG:
raise RuntimeError(
f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! "
"dtype must be one of: {_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG}"
)
# check shape
m, n = original_tensor.shape
min_rows = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[
original_tensor.dtype
].sparse_min_rows
min_cols = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[
original_tensor.dtype
].sparse_min_cols
if m < min_rows or m % min_rows or n < min_cols or n % min_cols:
# TODO in the future we can add in padding to support sparse dimensions that aren't perfect multiples
raise RuntimeError(
f"Error original_tensor.shape {original_tensor.shape} is not supported! "
f"Both dimensions must be larger or equal than and a multiple of ({min_rows}, {min_cols})"
)
compressed_tensor_cusparselt = None
sparse_tensor_cutlass = None
meta_tensor_cutlass = None
if self._FORCE_CUTLASS:
from torch.sparse._semi_structured_conversions import (
sparse_semi_structured_from_dense_cutlass,
)
sparse_tensor_cutlass, meta_tensor_cutlass = sparse_semi_structured_from_dense_cutlass(original_tensor)
else:
# use cuSPARSELt
compressed_tensor_cusparselt = torch._cslt_compress(original_tensor)
# set values
self.original_tensor = None
self.compressed_tensor_cusparselt = compressed_tensor_cusparselt
self.sparse_tensor_cutlass = sparse_tensor_cutlass
self.meta_tensor_cutlass = meta_tensor_cutlass
self.transposed = transposed
self.original_shape = original_shape
def __tensor_flatten__(self):
if self.compressed_tensor_cusparselt is None:
return ['sparse_tensor_cutlass', 'meta_tensor_cutlass'], (self.original_shape, self.transposed)
else:
return ['compressed_tensor_cusparselt'], (self.original_shape, self.transposed)
@staticmethod
def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
original_shape, transposed = meta
if len(inner_tensors) == 2:
sparse_tensor_cutlass = inner_tensors['sparse_tensor_cutlass']
meta_tensor_cutlass = inner_tensors['meta_tensor_cutlass']
compressed_tensor_cusparselt = None
elif len(inner_tensors) == 1:
sparse_tensor_cutlass = None
meta_tensor_cutlass = None
compressed_tensor_cusparselt = inner_tensors['compressed_tensor_cusparselt']
else:
raise RuntimeError(f"Expected 1 or 2 inner tensors but got {len(inner_tensors)}")
return SparseSemiStructuredTensor(
None,
original_shape=original_shape,
compressed_tensor_cusparselt=compressed_tensor_cusparselt,
sparse_tensor_cutlass=sparse_tensor_cutlass,
meta_tensor_cutlass=meta_tensor_cutlass,
transposed=transposed,
)
@staticmethod
def __get_indices_dtype(values_dtype):
if values_dtype == torch.int8:
return torch.int32
elif values_dtype in (torch.float16, torch.bfloat16, torch.float32):
return torch.int16
else:
raise RuntimeError(f"Datatype {values_dtype} is not supported!")
return None
tensor.packed = packed
tensor.meta = meta
tensor.packed_t = packed_t
tensor.meta_t = meta_t
tensor.threads_masks = threads_masks
tensor.fuse_transpose_cusparselt = fuse_transpose_cusparselt
tensor.alg_id_cusparselt = alg_id_cusparselt
return tensor
def __repr__(self) -> str: # type: ignore[override]
"""Return string representation of SparseSemiStructuredTensor
assert hasattr(self, "shape")
return f"{self.__class__.__name__}(shape={self.shape})"
Returns:
str: String representation
def __tensor_flatten__(
self,
) -> Tuple[List[str], Tuple[torch.Size, bool, int, bool]]:
inner_tensors = list(
filter(lambda x: getattr(self, x) is not None, self.__slots__)
)
tensor_meta = (
self.shape,
self.fuse_transpose_cusparselt,
self.alg_id_cusparselt,
self.requires_grad,
)
return inner_tensors, tensor_meta
Raises:
None
"""
return (
f"SparseSemiStructuredTensor(shape={self.shape}, "
f"transposed={self.transposed}"
f"values={self.values()}"
f"metadata={self.indices()})"
@classmethod
def __tensor_unflatten__(
cls,
inner_tensors,
tensor_meta : Tuple[torch.Size, bool, int, bool],
outer_size,
outer_stride,
) -> torch.Tensor:
shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta
return cls(
shape=shape,
packed=inner_tensors.get("packed", None),
meta=inner_tensors.get("meta", None),
packed_t=inner_tensors.get("packed_t", None),
meta_t=inner_tensors.get("meta_t", None),
threads_masks=inner_tensors.get("threads_masks", None),
fuse_transpose_cusparselt=fuse_transpose_cusparselt,
alg_id_cusparselt=alg_id_cusparselt,
requires_grad=requires_grad,
)
__torch_function__ = torch._C._disabled_torch_function_impl
def _pad_tensor_for_matmul(self, original_tensor : torch.Tensor) -> torch.Tensor:
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
if func._overloadpacket not in cls.SPARSE_DISPATCH:
raise NotImplementedError(
f"{cls.__name__} only supports a specific set of operations, "
f"can't perform requested op ({func.__name__})"
)
return cls.SPARSE_DISPATCH[func._overloadpacket](func, types, args, kwargs)
@classmethod
def _load_dispatch_table(cls, custom_dispatch_table=None) -> None:
"""
Loads the op overload sparse dispatch table for the current class.
"""
if getattr(cls, "SPARSE_DISPATCH", None) is None:
cls.SPARSE_DISPATCH = {
torch.ops.aten.values: semi_sparse_values,
torch.ops.aten.indices: semi_sparse_indices,
torch.ops.aten.is_same_size: fallback_dispatcher,
torch.ops.aten.detach_: fallback_dispatcher,
torch.ops.aten.detach: semi_sparse_detach,
torch.ops.aten.t: semi_sparse_t,
torch.ops.aten.view: semi_sparse_view,
torch.ops.aten.mm: semi_sparse_mm,
torch.ops.aten.matmul: semi_sparse_mm,
torch.ops.aten.addmm: semi_sparse_addmm,
torch.ops.aten.linear: semi_sparse_linear,
}
if custom_dispatch_table is not None:
cls.SPARSE_DISPATCH.update(custom_dispatch_table)
@classmethod
def _validate_device_dim_dtype_shape(cls, original_tensor : torch.Tensor) -> None:
"""
Assert that the given tensor is valid for semi-structured sparse compression.
"""
# check device
if not original_tensor.is_cuda:
raise RuntimeError(
f"Error original_tensor.device= {original_tensor.device} is not supported! "
"Only CUDA tensors are currently supported."
)
# check dim
if original_tensor.dim() != 2:
raise RuntimeError(
f"Error original_tensor.dim = {original_tensor.dim()} is not supported! "
"Only 2d tensors are currently supported."
)
# check contiguous
if not original_tensor.is_contiguous():
raise RuntimeError(
"Error original_tensor is not contiguous!"
"Only contiguous tensors are currently supported."
)
# check dtype
if original_tensor.dtype not in cls._DTYPE_SHAPE_CONSTRAINTS:
raise RuntimeError(
f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! "
"dtype must be one of: {cls._DTYPE_SHAPE_CONSTRAINTS}"
)
# check shape
m, n = original_tensor.shape
min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_rows
min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_cols
if m < min_rows or m % min_rows or n < min_cols or n % min_cols:
# TODO in the future we can add in padding to support sparse dimensions that aren't perfect multiples
raise RuntimeError(
f"Error original_tensor.shape {original_tensor.shape} is not supported! "
f"Both dimensions must be larger or equal than and a multiple of ({min_rows}, {min_cols})"
)
@classmethod
def _pad_dense_input(cls, dense_input: torch.Tensor) -> torch.Tensor:
"""
Calculates padding for dense tensor and pads tensor if necessary.
If padding is not required, this function returns the original tensor.
"""
# only 2d matmul
assert original_tensor.dim() == 2
assert dense_input.dim() == 2
# check shape
m, n = original_tensor.shape
min_rows = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[original_tensor.dtype].dense_min_rows
min_cols = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[original_tensor.dtype].dense_min_cols
m, n = dense_input.shape
min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_rows
min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_cols
# calculate padding
to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0
to_pad_n = -n % min_cols if n < min_cols or n % min_rows else 0
if to_pad_m or to_pad_n:
return torch.nn.functional.pad(original_tensor, (0, to_pad_n, 0, to_pad_m))
return torch.nn.functional.pad(dense_input, (0, to_pad_n, 0, to_pad_m))
else:
return original_tensor
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
"""Overload __torch_dispatch__ to use torch._sparse_semi_structured_linear.
`torch.structured_sparse_linear` uses accelerated sparse CUTLASS kernels.
In the future we plan to also add in support for cuSPARSELt kernels.
Args:
func: The function being dispatched.
types: The types of the arguments.
args: The arguments passed to the function.
kwargs: The keyword arguments passed to the function.
Returns:
Any: The result of the dispatched operation.
Raises:
NotImplementedError: If the dispatched operation is not implemented.
"""
# Since this code runs below autograd, a detach corresponds to only returning a new object
if func is torch.ops.aten.detach.default:
return SparseSemiStructuredTensor(
args[0].original_tensor,
original_shape=args[0].shape,
compressed_tensor_cusparselt=args[0].compressed_tensor_cusparselt,
sparse_tensor_cutlass=args[0].sparse_tensor_cutlass,
meta_tensor_cutlass=args[0].meta_tensor_cutlass,
transposed=args[0].transposed,
)
# Because we cannot go from the compressed representation back to the dense representation currently,
# we just keep track of how many times we have been transposed. Depending on whether the sparse matrix
# is the first or second argument, we expect an even / odd number of calls to transpose respectively.
if func is torch.ops.aten.t.default:
return SparseSemiStructuredTensor(
args[0].original_tensor,
# transpose shape
original_shape=torch.Size([args[0].shape[1], args[0].shape[0]]),
compressed_tensor_cusparselt=args[0].compressed_tensor_cusparselt,
sparse_tensor_cutlass=args[0].sparse_tensor_cutlass,
meta_tensor_cutlass=args[0].meta_tensor_cutlass,
transposed=not args[0].transposed,
)
# handle addmm
if func is torch.ops.aten.addmm.default:
bias, input_A, input_B = args
# Currently, we only support the first matrix being sparse for addmm/mm in cuSPARSELT and CUTLASS.
# CUTLASS only supports the first input to be sparse for a given matmul.
# cuSPARSELt does not have this limitation, although our implementation is only for sparse first.
# We support second matrix sparse matmul by taking advantage of some transpose properties:
# This is also why we want an odd number of transposed for second matrix sparse vs an even number
# of transpose calss for first matrix sparse.
# F.linear(x) = addmm(bias, input, weight.t()) = b + xW' = (b + xW')''
# = (W''x' + b')' = (Wx' + b')' = addmm(bias.T, weight, input).T
if isinstance(input_B, cls) and input_B.transposed:
row, col = input_A.shape
input_A_padded = input_B._pad_tensor_for_matmul(input_A)
if input_B.compressed_tensor_cusparselt is None:
assert input_B.sparse_tensor_cutlass is not None and input_B.meta_tensor_cutlass is not None
res = torch._sparse_semi_structured_linear(
input_A_padded,
input_B.sparse_tensor_cutlass,
input_B.meta_tensor_cutlass,
bias=bias
)
else:
res = torch._cslt_sparse_mm(
input_B.compressed_tensor_cusparselt,
input_A_padded.t(),
bias=bias, # type: ignore[arg-type]
transpose_result=cls._FUSE_TRANSPOSE,
)
res = res if cls._FUSE_TRANSPOSE else res.t()
return res[:row, :]
# handle mm
if func is torch.ops.aten.mm.default:
input_A, input_B = args
# first element sparse
if isinstance(input_A, cls) and not input_A.transposed:
row, col = input_B.shape
input_B_padded = input_A._pad_tensor_for_matmul(input_B)
if input_A.compressed_tensor_cusparselt is None:
assert input_A.sparse_tensor_cutlass is not None and input_A.meta_tensor_cutlass is not None
res = torch._sparse_semi_structured_linear(
input_B_padded.t(),
input_A.sparse_tensor_cutlass,
input_A.meta_tensor_cutlass
).t()
else:
res = torch._cslt_sparse_mm(
input_A.compressed_tensor_cusparselt,
input_B_padded,
bias=None, # type: ignore[arg-type]
)
return res[:, :col]
# second element sparse
elif isinstance(input_B, cls) and input_B.transposed:
row, col = input_A.shape
input_A_padded = input_B._pad_tensor_for_matmul(input_A)
if input_B.compressed_tensor_cusparselt is None:
assert input_B.sparse_tensor_cutlass is not None and input_B.meta_tensor_cutlass is not None
res = torch._sparse_semi_structured_linear(
input_A_padded,
input_B.sparse_tensor_cutlass,
input_B.meta_tensor_cutlass,
)
else:
res = torch._cslt_sparse_mm(
input_B.compressed_tensor_cusparselt,
input_A_padded.t(),
bias=None, # type: ignore[arg-type]
transpose_result=cls._FUSE_TRANSPOSE,
)
res = res if cls._FUSE_TRANSPOSE else res.t()
return res[:row, :]
# When torch is run with inference mode, pytorch does not decompose torch.ops.aten.linear into a .t() and addmm(),
# so we must match the aten.linear op. In this case, we need to explicitly handle collapsing to 2d matmul
# TODO see if there's a way to force pytorch to decompose the op so we don't have to handle this here.
if func is torch.ops.aten.linear.default:
input_tensor, weight, bias = args
# squash input_tensor to 2d
shape = input_tensor.shape
input_tensor_2d = input_tensor.view(-1, shape[-1])
row, col = input_tensor_2d.shape
# this is a noop if already padded
input_tensor_2d_padded = weight._pad_tensor_for_matmul(input_tensor_2d)
if isinstance(weight, cls):
if weight.compressed_tensor_cusparselt is None:
assert weight.sparse_tensor_cutlass is not None and weight.meta_tensor_cutlass is not None
res = torch._sparse_semi_structured_linear(
input_tensor_2d_padded,
weight.sparse_tensor_cutlass,
weight.meta_tensor_cutlass,
bias=bias
)
else:
res = torch._cslt_sparse_mm(
weight.compressed_tensor_cusparselt, # type: ignore[arg-type]
input_tensor_2d_padded.t(),
bias=bias,
transpose_result=cls._FUSE_TRANSPOSE
)
res = res if cls._FUSE_TRANSPOSE else res.t()
return res[:row, :].view(*shape[:-1], -1)
# handle values
if func is torch.ops.aten.values.default:
if args[0].compressed_tensor_cusparselt is None:
return args[0].sparse_tensor_cutlass.detach()
else:
m, k = args[0].shape
num_kept_elements = m * k // 2
return args[0].compressed_tensor_cusparselt[:num_kept_elements].view(m, k // 2)
# handle indices
if func is torch.ops.aten.indices.default:
if args[0].compressed_tensor_cusparselt is None:
return args[0].meta_tensor_cutlass
else:
m, k = args[0].shape
num_kept_elements = m * k // 2
metadata = args[0].compressed_tensor_cusparselt[num_kept_elements:].view(m, -1)
indices_dtype = SparseSemiStructuredTensor.__get_indices_dtype(
args[0].dtype
)
return metadata.view(indices_dtype)
error_string = "\n".join(
[f"func {func} with args: "]
+ [f"arg{i}: {arg}" for i, arg in enumerate(args)]
)
raise NotImplementedError(error_string)
return dense_input
def to_dense(self):
if self.compressed_tensor_cusparselt is not None:
raise RuntimeError("Converting to dense is not yet supported by cuSPARSELt backend!")
col = self.shape[-1]
return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device))
from torch.sparse._semi_structured_conversions import (
sparse_semi_structured_to_dense_cutlass,
)
@classmethod
def from_dense(cls, original_tensor : torch.Tensor) -> "SparseSemiStructuredTensor":
raise NotImplementedError
return sparse_semi_structured_to_dense_cutlass(
self.sparse_tensor_cutlass,
self.meta_tensor_cutlass,
)
def _mm(
self,
B: torch.Tensor,
*,
bias: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
raise NotImplementedError
def to_sparse_semi_structured(
@ -493,18 +314,14 @@ def to_sparse_semi_structured(
This function will check to ensure the dense tensor has the right dtype, size, dims, and device.
We currently only support semi-structured sparse tensors for 2d CUDA tensors.
Additionally, your tensor must be a positive multiple of a block size given the dtype
- torch.float16 (r, c) must be >= and a multiple of 64
- torch.int8 (r, c) must be >= and a multiple of 128
Additionally, your tensor must be a positive multiple of the mininum sparse block size, given in
`_DTYPE_TO_SHAPE_CONSTRAINTS` for each dtype (float32, float16, bfloat16, int8).
Args:
original_tensor (Tensor): the dense tensor to convert
transposed (bool, optional): whether the dense tensor is transposed
transposed (bool, optional): deprecated arg to be removed in another release. Do not use.
Returns:
SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor
Raises:
None
Example:
@ -518,22 +335,184 @@ def to_sparse_semi_structured(
[0., 0., 1., ..., 0., 1., 1.],
[0., 0., 1., ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)
>>> A_sparse = to_sparse_semi_structured(A)
SparseSemiStructuredTensor(shape=torch.Size([128, 128]), transposed=False, values=tensor([[1., 1., 1., ..., 1., 1., 1.],
SparseSemiStructuredTensor(shape=torch.Size([128, 128]))
>>> A_sparse.values()
tensor([[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
...,
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16),
metadata=tensor([[-4370, -4370, -4370, ..., -4370, -4370, -4370],
>>> A_sparse.indices()
tensor([[-4370, -4370, -4370, ..., -4370, -4370, -4370],
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
...,
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
[-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0',
dtype=torch.int16))
[-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0', dtype=torch.int16))
"""
return SparseSemiStructuredTensor(
original_tensor, original_shape=original_tensor.shape, transposed=transposed
if transposed:
raise DeprecationWarning(
"Setting transpose from to_sparse_semi_structured is deprecated and will be removed in a future release."
"SparseSemiStructuredTensor only support contiguous input tensors. "
)
sparse_subclass = (
torch.sparse.SparseSemiStructuredTensorCUTLASS
if SparseSemiStructuredTensor._FORCE_CUTLASS
else torch.sparse.SparseSemiStructuredTensorCUSPARSELT
)
return sparse_subclass.from_dense(original_tensor)
class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
"""
This class implements semi-structured sparsity for the CUTLASS backend.
In this implementation, the specified elements and metadata are stored seprately,
in packed and meta respectively.
When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_linear
and sparse_semi_structured_from_dense for conversion to the compressed format.
"""
_DTYPE_SHAPE_CONSTRAINTS = {
torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16),
torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 4, 4),
}
@classmethod
def from_dense(
cls, original_tensor: torch.Tensor
) -> "SparseSemiStructuredTensorCUTLASS":
cls._validate_device_dim_dtype_shape(original_tensor)
(
sparse_tensor_cutlass,
meta_tensor_cutlass,
) = sparse_semi_structured_from_dense_cutlass(original_tensor)
return cls(
original_tensor.shape,
packed=sparse_tensor_cutlass,
meta=meta_tensor_cutlass,
packed_t=None,
meta_t=None,
threads_masks=None,
requires_grad=original_tensor.requires_grad,
)
def to_dense(self):
assert self.meta is not None and self.packed is not None
return (
sparse_semi_structured_to_dense_cutlass(
self.packed,
self.meta,
)
if self.meta.ndim == 2
else super().to_dense()
)
def _mm(
self,
B: torch.Tensor,
*,
bias: Optional[torch.Tensor] = None,
**kwargs
) -> torch.Tensor:
if isinstance(B, SparseSemiStructuredTensor):
raise ValueError(
"`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
)
cls_name = self.__class__.__name__
if self.ndim != 2 or B.ndim != 2:
raise NotImplementedError(
f"`{cls_name}` matmul: Broadcasting is not implemented"
)
if self.packed is None or self.meta is None:
raise NotImplementedError(
f"`{cls_name}` matmul: operation is not supported"
)
else:
res = torch._sparse_semi_structured_linear(
B.t(), self.packed, self.meta, bias=bias
).t()
return res[: self.shape[0]]
class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
"""
The cuSPARSELt backend expects the specified elements and the metadata to be stored in a single tensor:
packed = [ specified elements of original tensor | metadata ]
For an original tensor of size (m, k) we expect the first m * k // 2 elements to be the kept elements
The rest of the tensor is metadata. Since there is only one tensor, we only use the packed and packed_t
attributes respectively.
cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well
as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes.
"""
_DTYPE_SHAPE_CONSTRAINTS = {
torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(8, 8, 4, 4),
}
@classmethod
def from_dense(cls, original_tensor : torch.Tensor) -> "SparseSemiStructuredTensorCUSPARSELT":
cls._validate_device_dim_dtype_shape(original_tensor)
return cls(
shape=original_tensor.shape,
packed=torch._cslt_compress(original_tensor),
meta=None,
packed_t=None,
meta_t=None,
threads_masks=None,
fuse_transpose_cusparselt=SparseSemiStructuredTensor._FUSE_TRANSPOSE,
alg_id_cusparselt=SparseSemiStructuredTensor._DEFAULT_ALG_ID,
requires_grad=original_tensor.requires_grad,
)
def _mm(
self,
B: torch.Tensor,
*,
bias: Optional[torch.Tensor] = None,
**kwargs
) -> torch.Tensor:
if isinstance(B, SparseSemiStructuredTensor):
raise ValueError(
"`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
)
if self.ndim != 2 or B.ndim != 2:
raise NotImplementedError(
f"`{self.__class__.__name__}` matmul: Broadcasting is not implemented"
)
if B.dtype != self.dtype:
raise NotImplementedError(
f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, "
f"with A.dtype={self.dtype} and B.dtype={B.dtype}. "
"This operation is only supported when A and B have the same data type."
)
if bias is not None and bias.dtype != self.dtype:
raise NotImplementedError(
f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)} + C`, "
"with A.dtype=B.dtype={self.dtype} and C.dtype={B.dtype}. "
"This operation is only supported when A, B and C have the same data type."
)
if self.packed is None:
raise NotImplementedError(
f"`{self.__class__.__name__}` matmul: operation is not supported"
)
else:
res = torch._cslt_sparse_mm(
self.packed,
B,
bias=bias,
transpose_result=self.fuse_transpose_cusparselt,
alg_id=self.alg_id_cusparselt,
)
return res.t() if self.fuse_transpose_cusparselt else res