import warnings
from collections import namedtuple
from typing import Any, Optional

import torch

__all__ = [
    "SparseSemiStructuredTensor",
    "to_sparse_semi_structured",
]

_SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple(
    "_SEMI_STRUCTURED_SPARSE_CONFIG",
    "sparse_min_rows sparse_min_cols dense_min_rows dense_min_cols",
)
# Per-dtype block sizes. ``sparse_min_*`` is the minimum (and required multiple) shape of the
# tensor being compressed (checked in SparseSemiStructuredTensor.__init__); ``dense_min_*`` is
# the block size the dense matmul operand is zero-padded up to (_pad_tensor_for_matmul).
_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG = {
    # torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16) for CUTLASS; cuSPARSELt has a 32 x 32 min sparse shape
    torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 128, 16, 16),
    torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
    torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
    torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 4, 4),
}


class SparseSemiStructuredTensor(torch.Tensor):
    """This class implements semi-structured sparsity as a Tensor subclass.

    Semi-structured sparsity describes a sparsity pattern where n in every 2n elements are sparse,
    depending on the datatype. It is also referred to as 2:4 sparsity or fine-grained
    structured sparsity.

    Currently, this class supports 2:4 sparsity for int8, float16 and bfloat16 dtypes.
    We also support 1:2 sparsity for float32 dtype.

    This subclass stores the dense tensor in a compressed form by only storing the specified elements and corresponding metadata.

    The subclass supports two backends, either CUTLASS or cuSPARSELt.

    The cuSPARSELt backend expects the specified elements and the metadata to be stored in a single tensor:

    compressed tensor = [ specified elements of original tensor | metadata ]

    For an original tensor of size (m, k) we expect the first m * k // 2 elements to be the kept elements.
    The rest of the tensor is metadata.

    For the CUTLASS backend, elements of the original tensor and metadata are kept in separate tensors.

    When ``_FORCE_CUTLASS`` is True (the class default), this subclass calls into
    ``_sparse_semi_structured_linear`` for matmuls and ``sparse_semi_structured_from_dense_cutlass``
    for conversion to the compressed format.

    Otherwise (cuSPARSELt backend), this subclass calls into ``_cslt_sparse_mm`` for sparse mm and
    ``_cslt_compress`` to convert into the compressed format.
    """

    # When True, ask _cslt_sparse_mm to return the transposed result directly instead of calling .t().
    _FUSE_TRANSPOSE = False
    # Selects the CUTLASS backend for compression and matmul; cuSPARSELt is used when False.
    _FORCE_CUTLASS = True
    # The prototype warning is emitted only once (tracked per class).
    _PROTOTYPE_WARNING_SHOWN = False

    @staticmethod
    def __new__(
        cls,
        original_tensor: Optional[torch.Tensor],
        original_shape: Optional[torch.Size] = None,
        compressed_tensor_cusparselt: Optional[torch.Tensor] = None,
        sparse_tensor_cutlass: Optional[torch.Tensor] = None,
        meta_tensor_cutlass: Optional[torch.Tensor] = None,
        transposed: bool = False,
    ):
        """
        Create a new instance of the class.

        When original_tensor is passed in, we compress it and store the compressed representation.
        We can also create a new instance of the class from the compressed representation without the original tensor.

        Args:
            original_tensor: The original dense tensor, or None, if we have already compressed the tensor.
            original_shape: The shape of the original dense tensor. Ignored (taken from
                original_tensor) when original_tensor is given.
            compressed_tensor_cusparselt: For cuSPARSELt backend, a flattened tensor to store the specified elements and metadata.
            sparse_tensor_cutlass: For CUTLASS backend, tensor to store the specified elements.
            meta_tensor_cutlass: For CUTLASS backend, tensor to store metadata.
            transposed: Whether the tensor is transposed or not.

        Returns:
            torch.Tensor: A torch.Tensor wrapper subclass.

        Raises:
            ValueError: If all of the tensor arguments are None.
        """
        # A tensor uses exactly one backend representation, never both.
        assert compressed_tensor_cusparselt is None or (
            sparse_tensor_cutlass is None and meta_tensor_cutlass is None
        )

        if not cls._PROTOTYPE_WARNING_SHOWN:
            warnings.warn(
                (
                    "The PyTorch API of SparseSemiStructuredTensor is in prototype stage "
                    "and will change in the near future. Please open a Github issue "
                    "for features requests and see our documentation on the torch.sparse "
                    "module for further information about the project."
                ),
                UserWarning,
            )
            cls._PROTOTYPE_WARNING_SHOWN = True

        # Pick a tensor to borrow device / dtype / layout from for the wrapper.
        if original_tensor is not None:
            previous_tensor = original_tensor
            original_shape = original_tensor.shape
        elif compressed_tensor_cusparselt is not None:
            previous_tensor = compressed_tensor_cusparselt
        elif sparse_tensor_cutlass is not None:
            previous_tensor = sparse_tensor_cutlass
        else:
            raise ValueError("All of the tensor arguments are None!")

        kwargs = {}
        kwargs["device"] = previous_tensor.device  # type: ignore[assignment]
        kwargs["dtype"] = previous_tensor.dtype  # type: ignore[assignment]
        kwargs["layout"] = previous_tensor.layout  # type: ignore[assignment]
        kwargs["requires_grad"] = False  # type: ignore[assignment]

        return torch.Tensor._make_wrapper_subclass(cls, original_shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        original_tensor: Optional[torch.Tensor],
        original_shape: Optional[torch.Size] = None,
        compressed_tensor_cusparselt: Optional[torch.Tensor] = None,
        sparse_tensor_cutlass: Optional[torch.Tensor] = None,
        meta_tensor_cutlass: Optional[torch.Tensor] = None,
        transposed: bool = False,
    ) -> None:
        """SparseSemiStructuredTensor constructor.

        Args:
            original_tensor: The original dense tensor, or None, if we have already compressed the tensor.
            original_shape: The shape of the original dense tensor. When original_tensor is given,
                its shape is used instead (matching __new__).
            compressed_tensor_cusparselt: For cuSPARSELt backend, a flattened tensor to store the specified elements and metadata.
            sparse_tensor_cutlass: For CUTLASS backend, tensor to store the specified elements.
            meta_tensor_cutlass: For CUTLASS backend, tensor to store metadata.
            transposed: Whether the tensor is transposed or not.

        Returns:
            None

        Raises:
            RuntimeError: If original_tensor is not a supported dtype, dim, shape, or device.
        """
        # if original tensor is passed in, we need to compress it and store the compressed representation.
        if original_tensor is not None:
            # Keep the stored shape consistent with the wrapper shape chosen in __new__, so that
            # __tensor_flatten__ / __tensor_unflatten__ round-trip even without an explicit original_shape.
            original_shape = original_tensor.shape

            # TODO right now we have unified checks and constraints for cuSPARSELt and CUTLASS, these are not actually the same.
            # We should consolidate similar checks here and leave backend specific checks like shape in the op implementation.

            # check device
            if not original_tensor.is_cuda:
                raise RuntimeError(
                    f"Error original_tensor.device= {original_tensor.device} is not supported! "
                    "Only CUDA tensors are currently supported."
                )

            # check dim
            if original_tensor.dim() != 2:
                raise RuntimeError(
                    f"Error original_tensor.dim = {original_tensor.dim()} is not supported! "
                    "Only 2d tensors are currently supported."
                )

            # check dtype
            if original_tensor.dtype not in _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG:
                raise RuntimeError(
                    f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! "
                    f"dtype must be one of: {list(_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG)}"
                )

            # check shape
            m, n = original_tensor.shape
            config = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[original_tensor.dtype]
            min_rows = config.sparse_min_rows
            min_cols = config.sparse_min_cols
            if m < min_rows or m % min_rows or n < min_cols or n % min_cols:
                # TODO in the future we can add in padding to support sparse dimensions that aren't perfect multiples
                raise RuntimeError(
                    f"Error original_tensor.shape {original_tensor.shape} is not supported! "
                    f"Both dimensions must be larger or equal than and a multiple of ({min_rows}, {min_cols})"
                )

            compressed_tensor_cusparselt = None
            sparse_tensor_cutlass = None
            meta_tensor_cutlass = None
            if self._FORCE_CUTLASS:
                from torch.sparse._semi_structured_conversions import (
                    sparse_semi_structured_from_dense_cutlass,
                )

                sparse_tensor_cutlass, meta_tensor_cutlass = sparse_semi_structured_from_dense_cutlass(original_tensor)
            else:
                # use cuSPARSELt
                compressed_tensor_cusparselt = torch._cslt_compress(original_tensor)

        # set values (the dense tensor itself is never retained)
        self.original_tensor = None
        self.compressed_tensor_cusparselt = compressed_tensor_cusparselt
        self.sparse_tensor_cutlass = sparse_tensor_cutlass
        self.meta_tensor_cutlass = meta_tensor_cutlass
        self.transposed = transposed
        self.original_shape = original_shape

    def __tensor_flatten__(self):
        """Return the names of the inner tensor attributes and the (original_shape, transposed) metadata."""
        if self.compressed_tensor_cusparselt is None:
            return ["sparse_tensor_cutlass", "meta_tensor_cutlass"], (self.original_shape, self.transposed)
        else:
            return ["compressed_tensor_cusparselt"], (self.original_shape, self.transposed)

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        """Rebuild a SparseSemiStructuredTensor from the output of __tensor_flatten__.

        Two inner tensors mean the CUTLASS representation, one means cuSPARSELt.
        outer_size / outer_stride are not used.
        """
        original_shape, transposed = meta

        if len(inner_tensors) == 2:
            sparse_tensor_cutlass = inner_tensors["sparse_tensor_cutlass"]
            meta_tensor_cutlass = inner_tensors["meta_tensor_cutlass"]
            compressed_tensor_cusparselt = None
        elif len(inner_tensors) == 1:
            sparse_tensor_cutlass = None
            meta_tensor_cutlass = None
            compressed_tensor_cusparselt = inner_tensors["compressed_tensor_cusparselt"]
        else:
            raise RuntimeError(f"Expected 1 or 2 inner tensors but got {len(inner_tensors)}")

        return SparseSemiStructuredTensor(
            None,
            original_shape=original_shape,
            compressed_tensor_cusparselt=compressed_tensor_cusparselt,
            sparse_tensor_cutlass=sparse_tensor_cutlass,
            meta_tensor_cutlass=meta_tensor_cutlass,
            transposed=transposed,
        )

    @staticmethod
    def __get_indices_dtype(values_dtype):
        """Return the dtype used to view cuSPARSELt metadata for a given values dtype.

        Raises:
            RuntimeError: If values_dtype is not supported.
        """
        if values_dtype == torch.int8:
            return torch.int32
        elif values_dtype in (torch.float16, torch.bfloat16, torch.float32):
            return torch.int16
        else:
            raise RuntimeError(f"Datatype {values_dtype} is not supported!")

    def __repr__(self) -> str:  # type: ignore[override]
        """Return string representation of SparseSemiStructuredTensor

        Returns:
            str: String representation

        Raises:
            None
        """
        return (
            f"SparseSemiStructuredTensor(shape={self.shape}, "
            f"transposed={self.transposed}, "
            f"values={self.values()}, "
            f"metadata={self.indices()})"
        )

    __torch_function__ = torch._C._disabled_torch_function_impl

    def _pad_tensor_for_matmul(self, original_tensor: torch.Tensor) -> torch.Tensor:
        """
        Calculates padding for dense tensor and pads tensor if necessary.
        If padding is not required, this function returns the original tensor.

        Rows are zero-padded at the bottom up to a multiple of ``dense_min_rows`` and
        columns at the right up to a multiple of ``dense_min_cols`` for the tensor's dtype.
        """
        # only 2d matmul
        assert original_tensor.dim() == 2

        # check shape
        m, n = original_tensor.shape
        config = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[original_tensor.dtype]
        # -x % b is the distance from x up to the next multiple of b (0 when x already is one).
        to_pad_m = -m % config.dense_min_rows
        to_pad_n = -n % config.dense_min_cols
        if to_pad_m or to_pad_n:
            # F.pad takes (left, right, top, bottom) for the last two dims.
            return torch.nn.functional.pad(original_tensor, (0, to_pad_n, 0, to_pad_m))
        else:
            return original_tensor

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
        """Handle the aten ops supported by this subclass.

        Supported ops: detach, t, addmm, mm, linear, values and indices. Matmuls are routed to
        ``torch._sparse_semi_structured_linear`` (CUTLASS representation) or
        ``torch._cslt_sparse_mm`` (cuSPARSELt representation) depending on which compressed
        representation the sparse operand stores. Dense operands are zero-padded with
        ``_pad_tensor_for_matmul`` and the result is sliced back to the unpadded size.

        Args:
            func: The function being dispatched.
            types: The types of the arguments.
            args: The arguments passed to the function.
            kwargs: The keyword arguments passed to the function.

        Returns:
            Any: The result of the dispatched operation.

        Raises:
            NotImplementedError: If the dispatched operation is not implemented.
        """
        # Since this code runs below autograd, a detach corresponds to only returning a new object
        if func is torch.ops.aten.detach.default:
            return SparseSemiStructuredTensor(
                args[0].original_tensor,
                original_shape=args[0].shape,
                compressed_tensor_cusparselt=args[0].compressed_tensor_cusparselt,
                sparse_tensor_cutlass=args[0].sparse_tensor_cutlass,
                meta_tensor_cutlass=args[0].meta_tensor_cutlass,
                transposed=args[0].transposed,
            )

        # Because we cannot go from the compressed representation back to the dense representation currently,
        # we just keep track of how many times we have been transposed. Depending on whether the sparse matrix
        # is the first or second argument, we expect an even / odd number of calls to transpose respectively.
        if func is torch.ops.aten.t.default:
            return SparseSemiStructuredTensor(
                args[0].original_tensor,
                # transpose shape
                original_shape=torch.Size([args[0].shape[1], args[0].shape[0]]),
                compressed_tensor_cusparselt=args[0].compressed_tensor_cusparselt,
                sparse_tensor_cutlass=args[0].sparse_tensor_cutlass,
                meta_tensor_cutlass=args[0].meta_tensor_cutlass,
                transposed=not args[0].transposed,
            )

        # handle addmm
        if func is torch.ops.aten.addmm.default:
            bias, input_A, input_B = args

            # Currently, we only support the first matrix being sparse for addmm/mm in cuSPARSELT and CUTLASS.
            # CUTLASS only supports the first input to be sparse for a given matmul.
            # cuSPARSELt does not have this limitation, although our implementation is only for sparse first.

            # We support second matrix sparse matmul by taking advantage of some transpose properties:
            # This is also why we want an odd number of transposes for second matrix sparse vs an even number
            # of transpose calls for first matrix sparse.
            # F.linear(x) = addmm(bias, input, weight.t()) = b + xW' = (b + xW')''
            #        = (W''x' + b')' = (Wx' + b')' = addmm(bias.T, weight, input).T
            if isinstance(input_B, cls) and input_B.transposed:
                row = input_A.shape[0]
                input_A_padded = input_B._pad_tensor_for_matmul(input_A)

                if input_B.compressed_tensor_cusparselt is None:
                    assert input_B.sparse_tensor_cutlass is not None and input_B.meta_tensor_cutlass is not None
                    res = torch._sparse_semi_structured_linear(
                        input_A_padded, input_B.sparse_tensor_cutlass, input_B.meta_tensor_cutlass, bias=bias
                    )
                else:
                    res = torch._cslt_sparse_mm(
                        input_B.compressed_tensor_cusparselt,
                        input_A_padded.t(),
                        bias=bias,  # type: ignore[arg-type]
                        transpose_result=cls._FUSE_TRANSPOSE,
                    )
                    res = res if cls._FUSE_TRANSPOSE else res.t()
                # drop rows that came from padding
                return res[:row, :]

        # handle mm
        if func is torch.ops.aten.mm.default:
            input_A, input_B = args

            # first element sparse
            if isinstance(input_A, cls) and not input_A.transposed:
                col = input_B.shape[1]
                input_B_padded = input_A._pad_tensor_for_matmul(input_B)

                if input_A.compressed_tensor_cusparselt is None:
                    assert input_A.sparse_tensor_cutlass is not None and input_A.meta_tensor_cutlass is not None
                    res = torch._sparse_semi_structured_linear(
                        input_B_padded.t(), input_A.sparse_tensor_cutlass, input_A.meta_tensor_cutlass
                    ).t()
                else:
                    res = torch._cslt_sparse_mm(
                        input_A.compressed_tensor_cusparselt, input_B_padded, bias=None  # type: ignore[arg-type]
                    )
                # drop columns that came from padding
                return res[:, :col]

            # second element sparse
            elif isinstance(input_B, cls) and input_B.transposed:
                row = input_A.shape[0]
                input_A_padded = input_B._pad_tensor_for_matmul(input_A)

                if input_B.compressed_tensor_cusparselt is None:
                    assert input_B.sparse_tensor_cutlass is not None and input_B.meta_tensor_cutlass is not None
                    res = torch._sparse_semi_structured_linear(
                        input_A_padded,
                        input_B.sparse_tensor_cutlass,
                        input_B.meta_tensor_cutlass,
                    )
                else:
                    res = torch._cslt_sparse_mm(
                        input_B.compressed_tensor_cusparselt,
                        input_A_padded.t(),
                        bias=None,  # type: ignore[arg-type]
                        transpose_result=cls._FUSE_TRANSPOSE,
                    )
                    res = res if cls._FUSE_TRANSPOSE else res.t()

                # drop rows that came from padding
                return res[:row, :]

        # When torch is run with inference mode, pytorch does not decompose torch.ops.aten.linear into a .t() and addmm(),
        # so we must match the aten.linear op. In this case, we need to explicitly handle collapsing to 2d matmul
        # TODO see if there's a way to force pytorch to decompose the op so we don't have to handle this here.
        if func is torch.ops.aten.linear.default:
            # bias is optional in aten.linear; a defaulted trailing argument may not be passed.
            input_tensor, weight = args[0], args[1]
            bias = args[2] if len(args) > 2 else None

            # Only a sparse weight is supported; otherwise fall through to NotImplementedError.
            if isinstance(weight, cls):
                # squash input_tensor to 2d
                shape = input_tensor.shape
                input_tensor_2d = input_tensor.view(-1, shape[-1])
                row = input_tensor_2d.shape[0]
                # this is a noop if already padded
                input_tensor_2d_padded = weight._pad_tensor_for_matmul(input_tensor_2d)

                if weight.compressed_tensor_cusparselt is None:
                    assert weight.sparse_tensor_cutlass is not None and weight.meta_tensor_cutlass is not None
                    res = torch._sparse_semi_structured_linear(
                        input_tensor_2d_padded, weight.sparse_tensor_cutlass, weight.meta_tensor_cutlass, bias=bias
                    )
                else:
                    res = torch._cslt_sparse_mm(
                        weight.compressed_tensor_cusparselt,  # type: ignore[arg-type]
                        input_tensor_2d_padded.t(),
                        bias=bias,
                        transpose_result=cls._FUSE_TRANSPOSE,
                    )
                    res = res if cls._FUSE_TRANSPOSE else res.t()
                # drop padded rows and restore the input's leading dimensions
                return res[:row, :].view(*shape[:-1], -1)

        # handle values
        if func is torch.ops.aten.values.default:
            if args[0].compressed_tensor_cusparselt is None:
                return args[0].sparse_tensor_cutlass.detach()
            else:
                # cuSPARSELt layout: [ m * k // 2 kept elements | metadata ]
                m, k = args[0].shape
                num_kept_elements = m * k // 2
                return args[0].compressed_tensor_cusparselt[:num_kept_elements].view(m, k // 2)

        # handle indices
        if func is torch.ops.aten.indices.default:
            if args[0].compressed_tensor_cusparselt is None:
                return args[0].meta_tensor_cutlass
            else:
                m, k = args[0].shape
                num_kept_elements = m * k // 2
                metadata = args[0].compressed_tensor_cusparselt[num_kept_elements:].view(m, -1)
                indices_dtype = SparseSemiStructuredTensor.__get_indices_dtype(args[0].dtype)
                # reinterpret the metadata bits with the backend's index dtype
                return metadata.view(indices_dtype)

        error_string = "\n".join(
            [f"func {func} with args: "] + [f"arg{i}: {arg}" for i, arg in enumerate(args)]
        )
        raise NotImplementedError(error_string)

    def to_dense(self):
        """Decompress the CUTLASS representation back into a dense tensor.

        Raises:
            RuntimeError: For tensors stored in the cuSPARSELt representation, or for
                float32 tensors in the CUTLASS representation.
        """
        if self.compressed_tensor_cusparselt is not None:
            raise RuntimeError("Converting to dense is not yet supported by cuSPARSELt backend!")
        if self.sparse_tensor_cutlass is not None and self.sparse_tensor_cutlass.dtype == torch.float32:
            raise RuntimeError("Converting to dense for torch.float32 datatype is not yet supported by CUTLASS backend!")

        from torch.sparse._semi_structured_conversions import (
            sparse_semi_structured_to_dense_cutlass,
        )

        return sparse_semi_structured_to_dense_cutlass(
            self.sparse_tensor_cutlass,
            self.meta_tensor_cutlass,
        )


def to_sparse_semi_structured(
    original_tensor: torch.Tensor,
    transposed: bool = False,
) -> SparseSemiStructuredTensor:
    """
    This function converts a dense tensor into a sparse semi-structured tensor.
    It will return a SparseSemiStructuredTensor, a subclass of torch.Tensor.

    This function will check to ensure the dense tensor has the right dtype, size, dims, and device.
    We currently only support semi-structured sparse tensors for 2d CUDA tensors.
    Additionally, each dimension of your tensor must be >= and a multiple of a block size
    (rows, cols) given by the dtype:

        - torch.int8: (32, 128)
        - torch.float16 / torch.bfloat16: (32, 64)
        - torch.float32: (32, 32)

    Args:
        original_tensor (Tensor): the dense tensor to convert
        transposed (bool, optional): whether the dense tensor is transposed

    Returns:
        SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor

    Raises:
        RuntimeError: If original_tensor is not a 2d CUDA tensor of a supported dtype and shape.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
        tensor([[0., 0., 1.,  ..., 0., 1., 1.],
                [0., 0., 1.,  ..., 0., 1., 1.],
                [0., 0., 1.,  ..., 0., 1., 1.],
                ...,
                [0., 0., 1.,  ..., 0., 1., 1.],
                [0., 0., 1.,  ..., 0., 1., 1.],
                [0., 0., 1.,  ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)
        >>> A_sparse = to_sparse_semi_structured(A)
        SparseSemiStructuredTensor(shape=torch.Size([128, 128]), transposed=False, values=tensor([[1., 1., 1.,  ..., 1., 1., 1.],
                [1., 1., 1.,  ..., 1., 1., 1.],
                [1., 1., 1.,  ..., 1., 1., 1.],
                ...,
                [1., 1., 1.,  ..., 1., 1., 1.],
                [1., 1., 1.,  ..., 1., 1., 1.],
                [1., 1., 1.,  ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16),
            metadata=tensor([[-4370, -4370, -4370,  ..., -4370, -4370, -4370],
                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
                ...,
                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
                [-4370, -4370, -4370,  ..., -4370, -4370, -4370]], device='cuda:0',
       dtype=torch.int16))
    """
    return SparseSemiStructuredTensor(
        original_tensor, original_shape=original_tensor.shape, transposed=transposed
    )