pytorch/torch/sparse/_semi_structured_ops.py
Jesse Cai c9db59e9e4 [sparse] Add fast semi-structured sparsification kernels (#122350)
This PR adds in fast semi-structured sparsification kernels to PyTorch.

These kernels allow for accelerated semi-structured sparsification
in PyTorch.

The kernels have been added as aten native functions

In particular, three new functions have been added:

* `torch._sparse_semi_structured_tile`

This function will return the packed representation and metadata for
both X and X', as well as the thread masks. Note that this applies 2:4
sparsity in a 4x4 tile instead of a 1x4 strip as usual.

* `torch._sparse_semi_structured_apply`

This function takes in an input tensor and thread masks from the above
function and returns a packed representation and metadata from applying
thread masks to the input tensor.

* `torch._sparse_semi_structured_apply_dense`

This function does the same thing as above but instead of returning the
tensor in the sparse representation it returns it in the dense
representation

The subclasses have also been updated to add a new
`prune_dense_static_sort`
classmethod to create sparse tensors with this format. I've added some
additional documentation on how to calculate the compressed tensors
needed to create a SparseSemiStructuredTensor oneself.

To this end, there are two new helper functions added:
`sparse_semi_structured_tile`
`compute_compressed_swizzled_bitmask`

Differential Revision: [D56190801](https://our.internmc.facebook.com/intern/diff/D56190801)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122350
Approved by: https://github.com/cpuhrsch
2024-04-19 13:31:58 +00:00

167 lines
5.1 KiB
Python

import contextlib
import torch
# Names exported by `from <module> import *`. Every entry is a function
# defined in this module; the `no_dispatch` helper is deliberately not listed.
__all__ = [
    "fallback_dispatcher",
    "semi_sparse_values",
    "semi_sparse_indices",
    "semi_sparse_t",
    "semi_sparse_view",
    "semi_sparse_detach",
    "semi_sparse_mm",
    "semi_sparse_addmm",
    "semi_sparse_linear",
]
@contextlib.contextmanager
def no_dispatch():
    """Context manager that holds a ``torch._C._DisableTorchDispatch`` guard.

    The guard object is created on entry and its reference is dropped on
    exit, whether the ``with`` body finishes normally or raises.
    """
    dispatch_guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        # Drop the only reference so the guard object can be released.
        del dispatch_guard
def fallback_dispatcher(func, types, args, kwargs):
    """Call ``func`` with the given arguments inside ``no_dispatch()``.

    Args:
        func: Callable to invoke.
        types: Unused here; kept so the handler signature matches the other
            handlers in this module.
        args: Positional arguments forwarded to ``func``.
        kwargs: Keyword arguments forwarded to ``func``. ``None`` is treated
            as "no keyword arguments".

    Returns:
        Whatever ``func`` returns.
    """
    with no_dispatch():
        # Forward keyword arguments as well, so options supplied by the
        # caller are not silently discarded.
        return func(*args, **(kwargs or {}))
def semi_sparse_values(func, types, args=(), kwargs=None) -> torch.Tensor:
    """Return the stored values of a ``SparseSemiStructuredTensor``.

    Args:
        func: Unused.
        types: Unused.
        args: One-element sequence holding the sparse tensor.
        kwargs: Unused.

    Returns:
        ``packed.detach()`` when the tensor has a separate ``meta`` tensor.
        Otherwise the first ``m * k // 2`` elements of ``packed``, viewed
        with ``m`` rows, where ``(m, k)`` is the tensor's shape.
    """
    assert len(args) == 1
    sparse = args[0]
    assert isinstance(sparse, torch.sparse.SparseSemiStructuredTensor)
    assert sparse.packed is not None

    if sparse.meta is not None:
        # Values and metadata live in separate tensors.
        return sparse.packed.detach()

    # No separate metadata tensor: the leading half-count of `packed`
    # elements is taken as the values.
    rows, cols = sparse.shape
    kept = rows * cols // 2
    return sparse.packed[:kept].view(rows, -1)
def semi_sparse_indices(func, types, args=(), kwargs=None) -> torch.Tensor:
    """Return the metadata (indices) of a ``SparseSemiStructuredTensor``.

    Args:
        func: Unused.
        types: Unused.
        args: One-element sequence holding the sparse tensor.
        kwargs: Unused.

    Returns:
        ``meta`` when the tensor has one. Otherwise the elements of
        ``packed`` after the first ``m * k // 2``, viewed with ``m`` rows and
        reinterpreted as ``torch.int32`` if the tensor's dtype is
        ``torch.int32``, else as ``torch.int16``.
    """
    assert len(args) == 1
    sparse = args[0]
    assert isinstance(sparse, torch.sparse.SparseSemiStructuredTensor)
    assert sparse.packed is not None

    if sparse.meta is not None:
        return sparse.meta

    # No separate metadata tensor: everything after the leading half-count
    # of `packed` elements is taken as the metadata.
    rows, cols = sparse.shape
    kept = rows * cols // 2
    index_dtype = torch.int32 if sparse.dtype == torch.int32 else torch.int16
    return sparse.packed[kept:].view(rows, -1).view(index_dtype)
def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor:
    """Build the transpose of a 2-D ``SparseSemiStructuredTensor``.

    The result is a new instance of the same class with the two shape
    entries swapped, ``packed``/``packed_t`` and ``meta``/``meta_t``
    exchanged, and ``compressed_swizzled_bitmask`` transposed when present.
    ``fuse_transpose_cusparselt`` and ``alg_id_cusparselt`` are carried over.

    Args:
        func: Unused.
        types: Unused.
        args: One-element sequence holding the sparse tensor.
        kwargs: Unused.

    Returns:
        The new tensor.
    """
    assert len(args) == 1
    tensor = args[0]
    assert isinstance(tensor, torch.sparse.SparseSemiStructuredTensor)
    assert len(tensor.shape) == 2

    # The compressed representation cannot currently be converted back to
    # dense, so a transpose only swaps the stored buffers with their
    # transposed counterparts. Depending on whether the sparse matrix is the
    # first or second matmul operand, an even or odd number of transpose
    # calls is expected, respectively.
    bitmask = tensor.compressed_swizzled_bitmask
    if bitmask is not None:
        bitmask = bitmask.transpose(0, 1)
    swapped_shape = torch.Size([tensor.shape[-1], tensor.shape[0]])

    return tensor.__class__(
        swapped_shape,
        packed=tensor.packed_t,
        meta=tensor.meta_t,
        packed_t=tensor.packed,
        meta_t=tensor.meta,
        compressed_swizzled_bitmask=bitmask,
        fuse_transpose_cusparselt=tensor.fuse_transpose_cusparselt,
        alg_id_cusparselt=tensor.alg_id_cusparselt,
    )
def semi_sparse_view(func, types, args=(), kwargs=None) -> torch.Tensor:
    """Handle ``view`` for the only supported case: a no-op view.

    Args:
        func: Unused.
        types: Unused.
        args: Two-element sequence ``(tensor, shape)``.
        kwargs: Unused.

    Returns:
        The input tensor itself when ``shape`` equals its current shape.

    Raises:
        NotImplementedError: If ``shape`` differs from the current shape.
    """
    assert len(args) == 2
    tensor, shape = args
    if tuple(shape) == tensor.shape:
        # Same shape requested: nothing to do.
        return tensor
    raise NotImplementedError(
        f"`view` is not implemented for SparseSemiStructuredTensor, except for the dummy case (shape={shape})"
    )
def semi_sparse_detach(func, types, args, kwargs) -> torch.Tensor:
    """Return a copy of the sparse tensor wrapper with ``requires_grad=False``.

    A new instance of the input's class is built from the same ``shape``,
    ``packed``, ``meta``, ``packed_t``, ``meta_t`` and
    ``compressed_swizzled_bitmask`` objects.

    Args:
        func: Unused.
        types: Unused.
        args: One-element sequence holding the tensor to detach.
        kwargs: Unused.

    Returns:
        The new instance.
    """
    assert len(args) == 1
    self = args[0]
    return self.__class__(
        shape=self.shape,
        packed=self.packed,
        meta=self.meta,
        packed_t=self.packed_t,
        meta_t=self.meta_t,
        compressed_swizzled_bitmask=self.compressed_swizzled_bitmask,
        # Carry the cuSPARSELt settings over, as `semi_sparse_t` does, so that
        # detaching does not reset them to the constructor defaults.
        fuse_transpose_cusparselt=self.fuse_transpose_cusparselt,
        alg_id_cusparselt=self.alg_id_cusparselt,
        requires_grad=False,
    )
def semi_sparse_mm(func, types, args=(), kwargs=None) -> torch.Tensor:
    """Matrix-multiply two 2-D operands, one of which is semi-structured sparse.

    If the first operand is sparse, the dense second operand is passed
    through ``_pad_dense_input`` and the product is sliced back to the dense
    operand's original column count. Otherwise the transpose of the second
    operand must be sparse: the product is computed as
    ``(B.t() @ A_padded.t()).t()`` and sliced back to the dense operand's
    original row count.

    Args:
        func: Unused.
        types: Unused.
        args: Two-element sequence ``(A, B)``.
        kwargs: Unused.

    Returns:
        The product tensor.

    Raises:
        NotImplementedError: If either operand is not 2-D.
    """
    assert len(args) == 2
    lhs, rhs = args
    if not (lhs.ndim == 2 and rhs.ndim == 2):
        raise NotImplementedError(
            "`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented"
        )

    if isinstance(lhs, torch.sparse.SparseSemiStructuredTensor):
        # sparse @ dense
        _, dense_cols = rhs.shape
        product = lhs._mm(lhs._pad_dense_input(rhs))
        return product[:, :dense_cols]

    # dense @ sparse, evaluated through the transposed problem.
    rhs_t = rhs.t()
    assert isinstance(rhs_t, torch.sparse.SparseSemiStructuredTensor)
    dense_rows, _ = lhs.shape
    lhs_padded = rhs._pad_dense_input(lhs)
    product = rhs_t._mm(lhs_padded.t()).t()
    return product[:dense_rows, :]
def semi_sparse_addmm(func, types, args=(), kwargs=None) -> torch.Tensor:
    """Compute ``A @ B`` with a 1-D bias, where only ``B`` may be sparse.

    The transpose of ``B`` must be a ``SparseSemiStructuredTensor``. The
    result is computed as ``B.t()._mm(A_padded.t(), bias=bias).t()`` and
    sliced back to the original row count of ``A``.

    Args:
        func: Unused.
        types: Unused.
        args: Three-element sequence ``(bias, A, B)``.
        kwargs: Unused.

    Returns:
        The result tensor.

    Raises:
        NotImplementedError: If ``A`` or ``B`` is not 2-D, if ``bias`` is not
            1-D, or if ``A`` is the sparse operand.
    """
    assert len(args) == 3
    bias, dense, sparse = args
    if not (dense.ndim == 2 and sparse.ndim == 2):
        raise NotImplementedError(
            "`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented"
        )
    if bias.ndim != 1:
        raise NotImplementedError(
            f"`SparseSemiStructuredTensor` matmul: only bias dim=1 supported. Shape={bias.shape}"
        )
    if isinstance(dense, torch.sparse.SparseSemiStructuredTensor):
        raise NotImplementedError(
            "`SparseSemiStructuredTensor` matmul: only operand B of `addmm` can be sparse"
        )

    # Evaluate through the transposed problem so the sparse matrix is the
    # operand that `_mm` is called on.
    sparse_t = sparse.t()
    assert isinstance(sparse_t, torch.sparse.SparseSemiStructuredTensor)
    dense_rows, _ = dense.shape
    dense_padded = sparse_t._pad_dense_input(dense)
    out = sparse_t._mm(dense_padded.t(), bias=bias).t()
    return out[:dense_rows, :]
def semi_sparse_linear(func, types, args=(), kwargs=None) -> torch.Tensor:
    """Apply ``A @ B.t()`` (plus an optional bias) over the last dim of ``A``.

    ``A`` is flattened to 2-D with ``view(-1, A.shape[-1])``, multiplied, and
    the result is viewed back to ``A``'s leading dimensions.

    Args:
        func: Unused.
        types: Unused.
        args: ``(A, B)`` or ``(A, B, bias)``; ``bias`` may be ``None``.
        kwargs: Unused.

    Returns:
        Tensor of shape ``(*A.shape[:-1], -1)``.
    """
    assert len(args) in [2, 3]
    inputs, weight = args[:2]
    bias = None
    if len(args) == 3:
        bias = args[2]

    leading_dims = inputs.shape[:-1]
    flat_inputs = inputs.view(-1, inputs.shape[-1])
    weight_t = weight.t()

    if bias is not None:
        # The bias case goes through the addmm handler.
        out = semi_sparse_addmm(
            func=None,
            types=None,
            args=[bias, flat_inputs, weight_t],
        )
    else:
        out = flat_inputs @ weight_t

    return out.view(*leading_dims, -1)