import warnings
from collections import namedtuple
from typing import Any, Optional

import torch

__all__ = [
    "SparseSemiStructuredTensor",
    "to_sparse_semi_structured",
]

_SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple(
    "_SEMI_STRUCTURED_SPARSE_CONFIG", "sparse_min_rows sparse_min_cols dense_min_rows dense_min_cols"
)
_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG = {
    # torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16) for CUTLASS; cuSPARSELt has a 32 x 32 min sparse shape
    torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 128, 16, 16),
    torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
    torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
    torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 4, 4)
}
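
# For example, a torch.float16 sparse operand must have both dimensions >= and a
# multiple of (32, 64), while a dense operand in the same matmul only needs (8, 8).
# Illustrative lookup:
#   >>> cfg = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[torch.float16]
#   >>> (cfg.sparse_min_rows, cfg.sparse_min_cols, cfg.dense_min_rows, cfg.dense_min_cols)
#   (32, 64, 8, 8)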


class SparseSemiStructuredTensor(torch.Tensor):
    """This class implements semi-structured sparsity as a Tensor subclass.

    Semi-structured sparsity describes a sparsity pattern where n out of every 2n elements are zero,
    with n depending on the datatype. It is also referred to as 2:4 sparsity or fine-grained
    structured sparsity.

    Currently, this class supports 2:4 sparsity for the int8, float16 and bfloat16 dtypes.
    We also support 1:2 sparsity for the float32 dtype.

    This subclass stores the dense tensor in a compressed form by only storing the specified elements and corresponding metadata.

    The subclass supports two backends, CUTLASS and cuSPARSELt.

    The cuSPARSELt backend expects the specified elements and the metadata to be stored in a single tensor:

    compressed tensor = [ specified elements of original tensor | metadata ]

    For an original tensor of size (m, k), we expect the first m * k // 2 elements to be the kept elements.
    The rest of the tensor is metadata.

    For the CUTLASS backend, the specified elements of the original tensor and the metadata are kept in separate tensors.

    When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_linear
    for sparse mm and sparse_semi_structured_from_dense_cutlass for conversion to the compressed format.

    When PyTorch is compiled with cuSPARSELt support, this subclass will call into _cslt_sparse_mm for sparse mm and
    _cslt_compress to convert into the compressed format.
    """

    _FUSE_TRANSPOSE = False
    _FORCE_CUTLASS = True
    _PROTOTYPE_WARNING_SHOWN = False
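
    # These class attributes select backend behavior (illustrative toggles; set them
    # before converting any tensors):
    #   >>> SparseSemiStructuredTensor._FORCE_CUTLASS = False  # use the cuSPARSELt path instead of CUTLASS
    #   >>> SparseSemiStructuredTensor._FUSE_TRANSPOSE = True  # let _cslt_sparse_mm return the transposed result directly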

    @staticmethod
    def __new__(
        cls,
        original_tensor: Optional[torch.Tensor],
        original_shape: Optional[torch.Size] = None,
        compressed_tensor_cusparselt: Optional[torch.Tensor] = None,
        sparse_tensor_cutlass: Optional[torch.Tensor] = None,
        meta_tensor_cutlass: Optional[torch.Tensor] = None,
        transposed: bool = False,
    ):
        """
        Create a new instance of the class.

        When original_tensor is passed in, we compress it and store the compressed representation.
        We can also create a new instance of the class from the compressed representation without the original tensor.

        Args:
            original_tensor: The original dense tensor, or None, if we have already compressed the tensor.
            original_shape: The shape of the original dense tensor
            compressed_tensor_cusparselt: For cuSPARSELt backend, a flattened tensor to store the specified elements and metadata.
            sparse_tensor_cutlass: For CUTLASS backend, tensor to store the specified elements.
            meta_tensor_cutlass: For CUTLASS backend, tensor to store metadata.
            transposed: Whether the tensor is transposed or not.

        Returns:
            torch.Tensor: A torch.Tensor wrapper subclass.

        Raises:
            ValueError: If all of the tensor arguments are None.
        """
        assert compressed_tensor_cusparselt is None or (sparse_tensor_cutlass is None and meta_tensor_cutlass is None)

        if not cls._PROTOTYPE_WARNING_SHOWN:
            warnings.warn(
                (
                    "The PyTorch API of SparseSemiStructuredTensor is in prototype stage "
                    "and will change in the near future. Please open a Github issue "
                    "for feature requests and see our documentation on the torch.sparse "
                    "module for further information about the project."
                ),
                UserWarning,
            )
            cls._PROTOTYPE_WARNING_SHOWN = True

        if original_tensor is not None:
            previous_tensor = original_tensor
            original_shape = original_tensor.shape
        elif compressed_tensor_cusparselt is not None:
            previous_tensor = compressed_tensor_cusparselt
        elif sparse_tensor_cutlass is not None:
            previous_tensor = sparse_tensor_cutlass
        else:
            raise ValueError("All of the tensor arguments are None!")

        kwargs = {}
        kwargs["device"] = previous_tensor.device  # type: ignore[assignment]
        kwargs["dtype"] = previous_tensor.dtype  # type: ignore[assignment]
        kwargs["layout"] = previous_tensor.layout  # type: ignore[assignment]
        kwargs["requires_grad"] = False  # type: ignore[assignment]

        return torch.Tensor._make_wrapper_subclass(cls, original_shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        original_tensor: Optional[torch.Tensor],
        original_shape: Optional[torch.Size] = None,
        compressed_tensor_cusparselt: Optional[torch.Tensor] = None,
        sparse_tensor_cutlass: Optional[torch.Tensor] = None,
        meta_tensor_cutlass: Optional[torch.Tensor] = None,
        transposed: bool = False,
    ) -> None:
        """SparseSemiStructuredTensor constructor.

        Args:
            original_tensor: The original dense tensor, or None, if we have already compressed the tensor.
            original_shape: The shape of the original dense tensor
            compressed_tensor_cusparselt: For cuSPARSELt backend, a flattened tensor to store the specified elements and metadata.
            sparse_tensor_cutlass: For CUTLASS backend, tensor to store the specified elements.
            meta_tensor_cutlass: For CUTLASS backend, tensor to store metadata.
            transposed: Whether the tensor is transposed or not.

        Returns:
            None

        Raises:
            RuntimeError: If original_tensor is not a supported dtype, dim, shape, or device.
        """
        # if original tensor is passed in, we need to compress it and store the compressed representation.
        if original_tensor is not None:
            # TODO right now we have unified checks and constraints for cuSPARSELt and CUTLASS, these are not actually the same.
            # We should consolidate similar checks here and leave backend specific checks like shape in the op implementation.

            # check device
            if not original_tensor.is_cuda:
                raise RuntimeError(
                    f"Error original_tensor.device= {original_tensor.device} is not supported! "
                    "Only CUDA tensors are currently supported."
                )

            # check dim
            if original_tensor.dim() != 2:
                raise RuntimeError(
                    f"Error original_tensor.dim = {original_tensor.dim()} is not supported! "
                    "Only 2d tensors are currently supported."
                )

            # check dtype
            if original_tensor.dtype not in _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG:
                raise RuntimeError(
                    f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! "
                    f"dtype must be one of: {_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG}"
                )

            # check shape
            m, n = original_tensor.shape
            min_rows = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[
                original_tensor.dtype
            ].sparse_min_rows
            min_cols = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[
                original_tensor.dtype
            ].sparse_min_cols
            if m < min_rows or m % min_rows or n < min_cols or n % min_cols:
                # TODO in the future we can add in padding to support sparse dimensions that aren't perfect multiples
                raise RuntimeError(
                    f"Error original_tensor.shape {original_tensor.shape} is not supported! "
                    f"Both dimensions must be at least, and a multiple of, ({min_rows}, {min_cols})"
                )

            compressed_tensor_cusparselt = None
            sparse_tensor_cutlass = None
            meta_tensor_cutlass = None
            if self._FORCE_CUTLASS:
                from torch.sparse._semi_structured_conversions import (
                    sparse_semi_structured_from_dense_cutlass,
                )

                sparse_tensor_cutlass, meta_tensor_cutlass = sparse_semi_structured_from_dense_cutlass(original_tensor)
            else:
                # use cuSPARSELt
                compressed_tensor_cusparselt = torch._cslt_compress(original_tensor)

        # set values
        self.original_tensor = None
        self.compressed_tensor_cusparselt = compressed_tensor_cusparselt
        self.sparse_tensor_cutlass = sparse_tensor_cutlass
        self.meta_tensor_cutlass = meta_tensor_cutlass
        self.transposed = transposed
        self.original_shape = original_shape

    def __tensor_flatten__(self):
        if self.compressed_tensor_cusparselt is None:
            return ['sparse_tensor_cutlass', 'meta_tensor_cutlass'], (self.original_shape, self.transposed)
        else:
            return ['compressed_tensor_cusparselt'], (self.original_shape, self.transposed)

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        original_shape, transposed = meta

        if len(inner_tensors) == 2:
            sparse_tensor_cutlass = inner_tensors['sparse_tensor_cutlass']
            meta_tensor_cutlass = inner_tensors['meta_tensor_cutlass']
            compressed_tensor_cusparselt = None
        elif len(inner_tensors) == 1:
            sparse_tensor_cutlass = None
            meta_tensor_cutlass = None
            compressed_tensor_cusparselt = inner_tensors['compressed_tensor_cusparselt']
        else:
            raise RuntimeError(f"Expected 1 or 2 inner tensors but got {len(inner_tensors)}")

        return SparseSemiStructuredTensor(
            None,
            original_shape=original_shape,
            compressed_tensor_cusparselt=compressed_tensor_cusparselt,
            sparse_tensor_cutlass=sparse_tensor_cutlass,
            meta_tensor_cutlass=meta_tensor_cutlass,
            transposed=transposed,
        )
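
    # __tensor_flatten__ / __tensor_unflatten__ are inverses; a rough sketch of the
    # round trip for a CUTLASS-backed tensor (illustrative only):
    #   >>> attrs, meta = A_sparse.__tensor_flatten__()
    #   >>> attrs
    #   ['sparse_tensor_cutlass', 'meta_tensor_cutlass']
    #   >>> inner = {name: getattr(A_sparse, name) for name in attrs}
    #   >>> rebuilt = SparseSemiStructuredTensor.__tensor_unflatten__(inner, meta, None, None)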

    @staticmethod
    def __get_indices_dtype(values_dtype):
        if values_dtype == torch.int8:
            return torch.int32
        elif values_dtype in (torch.float16, torch.bfloat16, torch.float32):
            return torch.int16
        else:
            raise RuntimeError(f"Datatype {values_dtype} is not supported!")
        return None

    def __repr__(self) -> str:  # type: ignore[override]
        """Return string representation of SparseSemiStructuredTensor

        Returns:
            str: String representation

        Raises:
            None
        """
        return (
            f"SparseSemiStructuredTensor(shape={self.shape}, "
            f"transposed={self.transposed}, "
            f"values={self.values()}, "
            f"metadata={self.indices()})"
        )

    __torch_function__ = torch._C._disabled_torch_function_impl

    def _pad_tensor_for_matmul(self, original_tensor: torch.Tensor) -> torch.Tensor:
        """
        Calculates padding for dense tensor and pads tensor if necessary.
        If padding is not required, this function returns the original tensor.
        """
        # only 2d matmul
        assert original_tensor.dim() == 2

        # check shape
        m, n = original_tensor.shape
        min_rows = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[original_tensor.dtype].dense_min_rows
        min_cols = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[original_tensor.dtype].dense_min_cols
        to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0
        to_pad_n = -n % min_cols if n < min_cols or n % min_cols else 0
        if to_pad_m or to_pad_n:
            return torch.nn.functional.pad(original_tensor, (0, to_pad_n, 0, to_pad_m))
        else:
            return original_tensor
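
    # Padding example (illustrative): for a float16 dense operand the minima above are
    # (8, 8), so a (20, 30) activation is padded by (-20 % 8, -30 % 8) = (4, 2) up to
    # (24, 32); the matmul handlers below slice the result back to the original rows/cols.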

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
        """Overload __torch_dispatch__ to use torch._sparse_semi_structured_linear.

        `torch._sparse_semi_structured_linear` uses accelerated sparse CUTLASS kernels.
        In the future we plan to also add in support for cuSPARSELt kernels.

        Args:
            func: The function being dispatched.
            types: The types of the arguments.
            args: The arguments passed to the function.
            kwargs: The keyword arguments passed to the function.

        Returns:
            Any: The result of the dispatched operation.

        Raises:
            NotImplementedError: If the dispatched operation is not implemented.
        """
        # Since this code runs below autograd, a detach corresponds to only returning a new object
        if func is torch.ops.aten.detach.default:
            return SparseSemiStructuredTensor(
                args[0].original_tensor,
                original_shape=args[0].shape,
                compressed_tensor_cusparselt=args[0].compressed_tensor_cusparselt,
                sparse_tensor_cutlass=args[0].sparse_tensor_cutlass,
                meta_tensor_cutlass=args[0].meta_tensor_cutlass,
                transposed=args[0].transposed,
            )

        # Because we cannot go from the compressed representation back to the dense representation currently,
        # we just keep track of how many times we have been transposed. Depending on whether the sparse matrix
        # is the first or second argument, we expect an even / odd number of calls to transpose respectively.
        if func is torch.ops.aten.t.default:
            return SparseSemiStructuredTensor(
                args[0].original_tensor,
                # transpose shape
                original_shape=torch.Size([args[0].shape[1], args[0].shape[0]]),
                compressed_tensor_cusparselt=args[0].compressed_tensor_cusparselt,
                sparse_tensor_cutlass=args[0].sparse_tensor_cutlass,
                meta_tensor_cutlass=args[0].meta_tensor_cutlass,
                transposed=not args[0].transposed,
            )

        # handle addmm
        if func is torch.ops.aten.addmm.default:
            bias, input_A, input_B = args

            # Currently, we only support the first matrix being sparse for addmm/mm in cuSPARSELt and CUTLASS.
            # CUTLASS only supports the first input to be sparse for a given matmul.
            # cuSPARSELt does not have this limitation, although our implementation is only for sparse first.

            # We support second-matrix-sparse matmul by taking advantage of some transpose properties.
            # This is also why we expect an odd number of transpose calls for second-matrix sparse vs an even
            # number of transpose calls for first-matrix sparse:
            # F.linear(x) = addmm(bias, input, weight.t()) = b + xW' = (b + xW')''
            #             = (W''x' + b')' = (Wx' + b')' = addmm(bias.T, weight, input).T
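            #
            # A dense-tensor sanity check of that identity (illustrative only):
            #   >>> x, W, b = torch.randn(4, 8), torch.randn(6, 8), torch.randn(4, 6)
            #   >>> torch.allclose(torch.addmm(b, x, W.t()), torch.addmm(b.t(), W, x.t()).t())
            #   True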
            if isinstance(input_B, cls) and input_B.transposed:
                row, col = input_A.shape
                input_A_padded = input_B._pad_tensor_for_matmul(input_A)

                if input_B.compressed_tensor_cusparselt is None:
                    assert input_B.sparse_tensor_cutlass is not None and input_B.meta_tensor_cutlass is not None
                    res = torch._sparse_semi_structured_linear(
                        input_A_padded,
                        input_B.sparse_tensor_cutlass,
                        input_B.meta_tensor_cutlass,
                        bias=bias
                    )
                else:
                    res = torch._cslt_sparse_mm(
                        input_B.compressed_tensor_cusparselt,
                        input_A_padded.t(),
                        bias=bias,  # type: ignore[arg-type]
                        transpose_result=cls._FUSE_TRANSPOSE
                    )
                    res = res if cls._FUSE_TRANSPOSE else res.t()
                return res[:row, :]

        # handle mm
        if func is torch.ops.aten.mm.default:
            input_A, input_B = args

            # first element sparse
            if isinstance(input_A, cls) and not input_A.transposed:
                row, col = input_B.shape
                input_B_padded = input_A._pad_tensor_for_matmul(input_B)

                if input_A.compressed_tensor_cusparselt is None:
                    assert input_A.sparse_tensor_cutlass is not None and input_A.meta_tensor_cutlass is not None
                    res = torch._sparse_semi_structured_linear(
                        input_B_padded.t(),
                        input_A.sparse_tensor_cutlass,
                        input_A.meta_tensor_cutlass
                    ).t()
                else:
                    res = torch._cslt_sparse_mm(
                        input_A.compressed_tensor_cusparselt,
                        input_B_padded,
                        bias=None  # type: ignore[arg-type]
                    )
                return res[:, :col]

            # second element sparse
            elif isinstance(input_B, cls) and input_B.transposed:
                row, col = input_A.shape
                input_A_padded = input_B._pad_tensor_for_matmul(input_A)

                if input_B.compressed_tensor_cusparselt is None:
                    assert input_B.sparse_tensor_cutlass is not None and input_B.meta_tensor_cutlass is not None
                    res = torch._sparse_semi_structured_linear(
                        input_A_padded,
                        input_B.sparse_tensor_cutlass,
                        input_B.meta_tensor_cutlass,
                    )
                else:
                    res = torch._cslt_sparse_mm(
                        input_B.compressed_tensor_cusparselt,
                        input_A_padded.t(),
                        bias=None,  # type: ignore[arg-type]
                        transpose_result=cls._FUSE_TRANSPOSE
                    )
                    res = res if cls._FUSE_TRANSPOSE else res.t()
                return res[:row, :]

        # When torch is run with inference mode, pytorch does not decompose torch.ops.aten.linear into a .t() and addmm(),
        # so we must match the aten.linear op. In this case, we need to explicitly handle collapsing to 2d matmul
        # TODO see if there's a way to force pytorch to decompose the op so we don't have to handle this here.
        if func is torch.ops.aten.linear.default:
            input_tensor, weight, bias = args
            # squash input_tensor to 2d
            shape = input_tensor.shape
            input_tensor_2d = input_tensor.view(-1, shape[-1])
            row, col = input_tensor_2d.shape
            # this is a noop if already padded
            input_tensor_2d_padded = weight._pad_tensor_for_matmul(input_tensor_2d)

            if isinstance(weight, cls):
                if weight.compressed_tensor_cusparselt is None:
                    assert weight.sparse_tensor_cutlass is not None and weight.meta_tensor_cutlass is not None
                    res = torch._sparse_semi_structured_linear(
                        input_tensor_2d_padded,
                        weight.sparse_tensor_cutlass,
                        weight.meta_tensor_cutlass,
                        bias=bias
                    )
                else:
                    res = torch._cslt_sparse_mm(
                        weight.compressed_tensor_cusparselt,  # type: ignore[arg-type]
                        input_tensor_2d_padded.t(),
                        bias=bias,
                        transpose_result=cls._FUSE_TRANSPOSE
                    )
                    res = res if cls._FUSE_TRANSPOSE else res.t()
                return res[:row, :].view(*shape[:-1], -1)

        # handle values
        if func is torch.ops.aten.values.default:
            if args[0].compressed_tensor_cusparselt is None:
                return args[0].sparse_tensor_cutlass.detach()
            else:
                m, k = args[0].shape
                num_kept_elements = m * k // 2
                return args[0].compressed_tensor_cusparselt[:num_kept_elements].view(m, k // 2)

        # handle indices
        if func is torch.ops.aten.indices.default:
            if args[0].compressed_tensor_cusparselt is None:
                return args[0].meta_tensor_cutlass
            else:
                m, k = args[0].shape
                num_kept_elements = m * k // 2
                metadata = args[0].compressed_tensor_cusparselt[num_kept_elements:].view(m, -1)
                indices_dtype = SparseSemiStructuredTensor.__get_indices_dtype(
                    args[0].dtype
                )
                return metadata.view(indices_dtype)

        error_string = "\n".join(
            [f"func {func} with args: "]
            + [f"arg{i}: {arg}" for i, arg in enumerate(args)]
        )
        raise NotImplementedError(error_string)

    def to_dense(self):
        if self.compressed_tensor_cusparselt is not None:
            raise RuntimeError("Converting to dense is not yet supported by cuSPARSELt backend!")

        if self.sparse_tensor_cutlass is not None and self.sparse_tensor_cutlass.dtype == torch.float32:
            raise RuntimeError("Converting to dense for torch.float32 datatype is not yet supported by CUTLASS backend!")

        from torch.sparse._semi_structured_conversions import (
            sparse_semi_structured_to_dense_cutlass,
        )

        return sparse_semi_structured_to_dense_cutlass(
            self.sparse_tensor_cutlass,
            self.meta_tensor_cutlass,
        )
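
    # Illustrative CUTLASS round trip (a sketch; assumes a CUDA device and a float16
    # input with a valid 2:4 pattern and supported shape):
    #   >>> A = torch.Tensor([0, 0, 1, 1]).tile((64, 16)).half().cuda()
    #   >>> torch.equal(to_sparse_semi_structured(A).to_dense(), A)
    #   True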


def to_sparse_semi_structured(
    original_tensor: torch.Tensor,
    transposed: bool = False,
) -> SparseSemiStructuredTensor:
    """
    This function converts a dense tensor into a sparse semi-structured tensor.
    It will return a SparseSemiStructuredTensor, a subclass of torch.Tensor.

    This function will check to ensure the dense tensor has the right dtype, size, dims, and device.
    We currently only support semi-structured sparse tensors for 2d CUDA tensors.
    Additionally, the tensor dimensions must be positive multiples of the minimum block size for the given dtype:

    - torch.float16 / torch.bfloat16: (rows, cols) must be >= and a multiple of (32, 64)
    - torch.int8: (rows, cols) must be >= and a multiple of (32, 128)
    - torch.float32: (rows, cols) must be >= and a multiple of (32, 32)

    Args:
        original_tensor (Tensor): the dense tensor to convert
        transposed (bool, optional): whether the dense tensor is transposed

    Returns:
        SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor

    Raises:
        None
    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
        >>> A
        tensor([[0., 0., 1.,  ..., 0., 1., 1.],
                [0., 0., 1.,  ..., 0., 1., 1.],
                [0., 0., 1.,  ..., 0., 1., 1.],
                ...,
                [0., 0., 1.,  ..., 0., 1., 1.],
                [0., 0., 1.,  ..., 0., 1., 1.],
                [0., 0., 1.,  ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)
        >>> A_sparse = to_sparse_semi_structured(A)
        >>> A_sparse
        SparseSemiStructuredTensor(shape=torch.Size([128, 128]), transposed=False, values=tensor([[1., 1., 1.,  ..., 1., 1., 1.],
                [1., 1., 1.,  ..., 1., 1., 1.],
                [1., 1., 1.,  ..., 1., 1., 1.],
                ...,
                [1., 1., 1.,  ..., 1., 1., 1.],
                [1., 1., 1.,  ..., 1., 1., 1.],
                [1., 1., 1.,  ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16),
            metadata=tensor([[-4370, -4370, -4370,  ..., -4370, -4370, -4370],
                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
                ...,
                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
                [-4370, -4370, -4370,  ..., -4370, -4370, -4370]], device='cuda:0',
            dtype=torch.int16))
    """
    return SparseSemiStructuredTensor(
        original_tensor, original_shape=original_tensor.shape, transposed=transposed
    )
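
# A minimal end-to-end sketch (illustrative; assumes a CUDA device with semi-structured
# sparse kernel support and shapes that satisfy the constraints above):
#   >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()  # 2:4-pruned (128, 128) matrix
#   >>> B = torch.rand(128, 64).half().cuda()
#   >>> A_sparse = to_sparse_semi_structured(A)
#   >>> torch.allclose(torch.mm(A_sparse, B), torch.mm(A, B), atol=1e-3)
#   True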