pytorch/torch/sparse
Jesse Cai 16369816a2 [sparse] semi-structured sparse refactor (#117302)
Summary:

This PR is a refactor of semi-structured sparsity support.

**deprecation**:

Previously, `torch.sparse.to_sparse_semi_structured` had a kwarg param
`transposed=False`, which has been removed. This kwarg was unused and
now throws a deprecation warning.
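
For illustration, the deprecation behavior described above can be sketched as below. This is a minimal stand-in, not the PyTorch implementation: `to_sparse_stub`, the warning category, and the message text are all assumptions made for the example.

```python
import warnings


def to_sparse_stub(values, transposed=False):
    """Hypothetical stand-in for a converter with a deprecated kwarg.

    Not the real torch.sparse.to_sparse_semi_structured.
    """
    if transposed:
        # The kwarg is accepted but ignored; callers only get a warning.
        # FutureWarning is an assumption chosen for this sketch.
        warnings.warn(
            "`transposed` is deprecated and has no effect.",
            FutureWarning,
            stacklevel=2,
        )
    return list(values)
```

Callers that never passed `transposed` see no change; callers that passed it get the same result plus a warning.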

Namely, I've taken the subclassing implementation that xFormers has
created and brought it over to PyTorch, as part of our plan to upstream
runtime 2:4 sparsity.
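
For readers unfamiliar with the 2:4 pattern: each contiguous group of four values keeps at most two nonzeros. A minimal pure-Python sketch, assuming plain magnitude-based selection (this is illustrative only and is not the fast sparsification routine from xFormers; `prune_2_4` is a hypothetical name):

```python
def prune_2_4(row):
    """Zero the two smallest-magnitude entries in each group of four."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    out = []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        # Indices of the two largest magnitudes; sorted() is stable,
        # so ties are resolved in favor of the earlier index.
        keep = sorted(range(4), key=lambda i: -abs(group[i]))[:2]
        out.extend(v if i in keep else 0 for i, v in enumerate(group))
    return out
```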

I've also copied over all the op support that Daniel implemented that
did not depend on the fast sparsification routines, into
`_semi_structured_ops.py`.

With this subclass, all of our internal tests pass, as well as those in
xFormers.

The main change is that we now define a base subclass,
`SparseSemiStructuredTensor`, from which each of the specific backends
inherits.

We can now also arbitrarily override the sparse dispatch table with
`_load_dispatch_table()`. The idea is that this is still general enough
that users don't need to modify PyTorch source code to get their model
working.
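
A rough sketch of the pattern, assuming a per-class table that maps op names to handlers. All names here (`SparseBase`, `BackendA`, `BackendB`, `load_dispatch_table`, `dispatch`) are hypothetical and only mirror the idea; the real `_load_dispatch_table()` signature is not shown in this message.

```python
class SparseBase:
    """Hypothetical base class; mirrors the pattern, not the real API."""

    @classmethod
    def load_dispatch_table(cls, custom=None):
        # Each class builds its own table, so overrides on one backend
        # do not leak into another.
        if "DISPATCH" not in cls.__dict__:
            cls.DISPATCH = {"mm": cls.mm, "detach": cls.detach}
        if custom is not None:
            cls.DISPATCH.update(custom)

    @classmethod
    def dispatch(cls, op, *args):
        cls.load_dispatch_table()
        if op not in cls.DISPATCH:
            raise NotImplementedError(f"{cls.__name__} does not support {op!r}")
        return cls.DISPATCH[op](*args)

    @staticmethod
    def mm(a, b):
        raise NotImplementedError("backend must provide mm")

    @staticmethod
    def detach(x):
        return x


class BackendA(SparseBase):
    @staticmethod
    def mm(a, b):
        return f"A.mm({a}, {b})"


class BackendB(SparseBase):
    @staticmethod
    def mm(a, b):
        return f"B.mm({a}, {b})"


# A user adds an op to one backend without touching the class source.
BackendB.load_dispatch_table({"relu": lambda x: max(x, 0)})
```

Keeping the table on the class rather than hardcoding it in the dispatch method is what lets a user add or replace an op from outside the library.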

This also adds in padding support and stores alg_id and fuse_transpose
as flags on the tensor, instead of hardcoding them.
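
As a sketch of what padding plus per-tensor flags could look like, assuming the input is padded with zeros up to a multiple of some backend-specific tile size. `PackedStub`, `pad_to_multiple`, and the multiples used are hypothetical; the actual minimum shapes depend on dtype and backend and are not stated here.

```python
from dataclasses import dataclass


@dataclass
class PackedStub:
    """Hypothetical container: settings live on the object, not in code."""
    rows: list
    orig_shape: tuple
    alg_id: int = 0
    fuse_transpose: bool = False


def pad_to_multiple(matrix, row_mult, col_mult, alg_id=0, fuse_transpose=False):
    """Zero-pad a list-of-lists matrix so both dims are multiples."""
    n_rows = len(matrix)
    n_cols = len(matrix[0]) if matrix else 0
    pad_rows = -n_rows % row_mult  # 0 when already aligned
    pad_cols = -n_cols % col_mult
    padded = [list(r) + [0] * pad_cols for r in matrix]
    padded += [[0] * (n_cols + pad_cols) for _ in range(pad_rows)]
    # Remember the original shape so results can be cropped back later.
    return PackedStub(padded, (n_rows, n_cols), alg_id, fuse_transpose)
```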

There still remain two components in xFormers that will need to be
ported over eventually:
- the autograd functions (`Sparsify24`, `Sparsify24_like`)
- fast sparsification routines that they rely on

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117302
Approved by: https://github.com/alexsamardzic, https://github.com/HDCharles
2024-02-14 01:10:40 +00:00
..
__init__.py [sparse] semi-structured sparse refactor (#117302) 2024-02-14 01:10:40 +00:00
_semi_structured_conversions.py Enable possibly-undefined error code (#118533) 2024-01-30 21:07:01 +00:00
_semi_structured_ops.py [sparse] semi-structured sparse refactor (#117302) 2024-02-14 01:10:40 +00:00
_triton_ops_meta.py Add instructions for generating optimal Triton kernel parameters of bsr_dense_addmm (#115504) 2023-12-12 16:44:51 +00:00
_triton_ops.py [SparseCsr] Remove triton sdpa skip after triton pin update (#109601) 2024-02-08 16:40:25 +00:00
semi_structured.py [sparse] semi-structured sparse refactor (#117302) 2024-02-14 01:10:40 +00:00