pytorch/torch/sparse
Jesse Cai c9db59e9e4 [sparse] Add fast semi-structured spasification kernels (#122350)
This PR adds fast semi-structured sparsification kernels to PyTorch, so that semi-structured sparsification can be performed with accelerated kernels.

The kernels have been added as ATen native functions.

In particular, three new functions have been added:

* `torch._sparse_semi_structured_tile`

This function returns the packed representation and metadata for
both X and X' (the transpose of X), as well as the thread masks. Note that
it applies 2:4 sparsity over a 4x4 tile instead of the usual 1x4 strip, so
that both X and X' satisfy the 2:4 pattern.
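Why a 4x4 tile gives 2:4 sparsity in both X and X' can be seen with a small plain-Python sketch. This is a hypothetical reference, not the CUDA kernel: the greedy largest-magnitude selection order, the list-of-lists representation, and the names `tile_prune_mask`, `is_2to4_rowwise` and `transpose` are all assumptions made for illustration. If every row and every column of each 4x4 tile keeps at most 2 values, then every 1x4 strip of X and every 1x4 strip of its transpose is 2:4 sparse.

```python
def tile_prune_mask(x):
    """Hypothetical greedy 2:4 pruning applied independently to each 4x4 tile.

    x: list of rows (lists of numbers); both dimensions must be multiples of 4.
    Returns a 0/1 mask of the same shape in which each row and each column of
    every 4x4 tile keeps at most 2 elements.
    """
    n_rows, n_cols = len(x), len(x[0])
    if n_rows % 4 or n_cols % 4:
        raise ValueError("both dimensions must be multiples of 4")
    mask = [[0] * n_cols for _ in range(n_rows)]
    for r0 in range(0, n_rows, 4):
        for c0 in range(0, n_cols, 4):
            # Visit the 16 positions of this tile from largest to smallest magnitude.
            cells = sorted(
                ((r, c) for r in range(r0, r0 + 4) for c in range(c0, c0 + 4)),
                key=lambda rc: abs(x[rc[0]][rc[1]]),
                reverse=True,
            )
            row_kept = [0] * 4
            col_kept = [0] * 4
            for r, c in cells:
                # Keep a value only if its tile row and tile column still have room.
                if row_kept[r - r0] < 2 and col_kept[c - c0] < 2:
                    mask[r][c] = 1
                    row_kept[r - r0] += 1
                    col_kept[c - c0] += 1
    return mask


def is_2to4_rowwise(mask):
    """True if every 1x4 strip of every row has at most 2 nonzeros."""
    return all(sum(row[c:c + 4]) <= 2 for row in mask for c in range(0, len(row), 4))


def transpose(m):
    """Return the transpose of a list-of-lists matrix."""
    return [list(col) for col in zip(*m)]
```

For example, on a 4x4 tile holding 1..16 in row-major order, the greedy pass keeps the top-left and bottom-right 2x2 blocks, and the resulting mask passes `is_2to4_rowwise` both as-is and transposed.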

* `torch._sparse_semi_structured_apply`

This function takes an input tensor and the thread masks from the above
function, and returns the packed representation and metadata obtained by
applying the thread masks to the input tensor.
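A simplified picture of what a "packed representation and metadata" holds: for each 1x4 group of a row, the packed tensor keeps 2 values and the metadata records which 2 positions they came from. The sketch below assumes the mask is already 2:4 row-wise; the name `apply_mask_packed` and the `idx0 | (idx1 << 2)` metadata encoding are hypothetical and deliberately ignore the swizzled layout that the real kernels produce.

```python
def apply_mask_packed(x, mask):
    """Hypothetical simplified packing of a 2:4-masked matrix.

    For each 1x4 group of a row, emits 2 values plus one metadata code that
    stores the two column offsets (2 bits each) as idx0 | (idx1 << 2).
    This is NOT the swizzled layout used by the real kernels.
    """
    packed, meta = [], []
    for x_row, m_row in zip(x, mask):
        p_row, meta_row = [], []
        for g in range(0, len(x_row), 4):
            kept = [i for i in range(4) if m_row[g + i]]
            if len(kept) > 2:
                raise ValueError(f"group at column {g} keeps more than 2 values")
            # If fewer than 2 values survive, pad with masked-out offsets;
            # their stored value is 0, so decoding still reproduces the mask.
            for i in range(4):
                if len(kept) == 2:
                    break
                if i not in kept:
                    kept.append(i)
            kept.sort()
            p_row.extend(x_row[g + i] if m_row[g + i] else 0 for i in kept)
            meta_row.append(kept[0] | (kept[1] << 2))
        packed.append(p_row)
        meta.append(meta_row)
    return packed, meta
```

For a row `[1, 2, 3, 4, 5, 6, 7, 8]` with mask `[0, 1, 0, 1, 1, 0, 0, 1]`, this returns packed values `[2, 4, 5, 8]` and metadata codes `[13, 12]`.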

* `torch._sparse_semi_structured_apply_dense`

This function does the same thing as above, but instead of returning the
tensor in the sparse representation, it returns it in the dense
representation.
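The dense variant keeps the original shape and simply zeroes the masked-out elements. As an illustration of how the two outputs relate, decoding a packed row back to dense should give the same matrix as applying the mask densely. The names `apply_mask_dense` and `unpack_2to4`, and the simplified `idx0 | (idx1 << 2)` metadata layout being decoded, are hypothetical assumptions rather than the real kernel format.

```python
def apply_mask_dense(x, mask):
    """Zero out every element whose mask bit is 0, keeping the dense shape."""
    return [[v if m else 0 for v, m in zip(x_row, m_row)] for x_row, m_row in zip(x, mask)]


def unpack_2to4(packed, meta, n_cols):
    """Decode a hypothetical (2 values, idx0 | idx1 << 2) per-group layout to dense."""
    dense = []
    for p_row, meta_row in zip(packed, meta):
        row = [0] * n_cols
        for g, code in enumerate(meta_row):
            # Each group of 4 columns stores 2 values; the code gives their offsets.
            row[4 * g + (code & 0b11)] = p_row[2 * g]
            row[4 * g + ((code >> 2) & 0b11)] = p_row[2 * g + 1]
        dense.append(row)
    return dense
```

With `x = [[1, 2, 3, 4, 5, 6, 7, 8]]` and mask `[[0, 1, 0, 1, 1, 0, 0, 1]]`, both `apply_mask_dense(x, mask)` and `unpack_2to4([[2, 4, 5, 8]], [[13, 12]], 8)` give `[[0, 2, 0, 4, 5, 0, 0, 8]]`.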

The subclasses have also been updated with a new
`prune_dense_static_sort`
classmethod that creates sparse tensors in this format. I've also added
documentation on how to calculate the compressed tensors
needed to create a SparseSemiStructuredTensor oneself.

To this end, two new helper functions have been added:

* `sparse_semi_structured_tile`
* `compute_compressed_swizzled_bitmask`
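A bitmask is a compact way to record which elements of each 4x4 tile survive pruning: 16 bits per tile. The sketch below packs one tile mask into two 8-bit integers. The row-major, least-significant-bit-first ordering and the names `compress_tile_bitmask` and `decompress_tile_bitmask` are assumptions for illustration; the swizzling that `compute_compressed_swizzled_bitmask` performs to match the kernel's thread layout is not reproduced.

```python
def compress_tile_bitmask(tile_mask):
    """Pack a 4x4 0/1 tile mask into two 8-bit integers.

    Hypothetical layout: the tile is flattened row-major, and bit k of byte b
    holds element 8 * b + k. No swizzling is applied.
    """
    bits = [b for row in tile_mask for b in row]
    if len(bits) != 16 or any(b not in (0, 1) for b in bits):
        raise ValueError("expected a 4x4 mask of 0/1 values")
    return [sum(bits[8 * byte + k] << k for k in range(8)) for byte in range(2)]


def decompress_tile_bitmask(two_bytes):
    """Inverse of compress_tile_bitmask: rebuild the 4x4 0/1 mask."""
    bits = [(two_bytes[k // 8] >> (k % 8)) & 1 for k in range(16)]
    return [bits[4 * r: 4 * r + 4] for r in range(4)]
```

Under this layout, the mask `[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]` compresses to `[51, 204]` and decompresses back to the same mask.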

Differential Revision: [D56190801](https://our.internmc.facebook.com/intern/diff/D56190801)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122350
Approved by: https://github.com/cpuhrsch
2024-04-19 13:31:58 +00:00
__init__.py Update DimOrDims typing in torch.sparse (#122471) 2024-03-25 16:25:56 +00:00
_semi_structured_conversions.py [sparse] Add fast semi-structured spasification kernels (#122350) 2024-04-19 13:31:58 +00:00
_semi_structured_ops.py [sparse] Add fast semi-structured spasification kernels (#122350) 2024-04-19 13:31:58 +00:00
_triton_ops_meta.py Update bsr_dense_addmm kernel parameters for sizes 3 x 2 ^ N (#122506) 2024-03-23 11:54:33 +00:00
_triton_ops.py Update bsr_dense_addmm kernel parameters for sizes 3 x 2 ^ N (#122506) 2024-03-23 11:54:33 +00:00
semi_structured.py [sparse] Add fast semi-structured spasification kernels (#122350) 2024-04-19 13:31:58 +00:00