pytorch/torch/sparse/_semi_structured_conversions.py
Jesse Cai a8e354a9a0 [sparse][semi-structured] enable fp32 support, separate sparse and dense constraints (#115550)
Summary:

Both cuSPARSELt and CUTLASS support 1:2 semi-structured sparsity for
fp32, which this PR enables (thanks @alexsamardzic).

Furthermore, this PR updates the sparse_config to take into account
the different shape constraints for sparse and dense matrices.

Technically, cuSPARSELt supports smaller shape constraints for sparse
matrices, as it seems to pad to the CUTLASS constraints under the hood.
In practice, however, small sparse matrices are not commonly used, and we
care more about the dense constraints for LLM inference.

For now, we keep the CUTLASS constraints in place for both cuSPARSELt
and CUTLASS tensors.

This PR also reconnects the _FUSE_TRANSPOSE flag for cuSPARSELt tensors.

Test Plan:
```
python test/test_sparse_semi_structured.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115550
Approved by: https://github.com/cpuhrsch
2023-12-15 02:28:17 +00:00


import torch


def _sparse_semi_structured_from_dense_cutlass(dense):
if dense.dim() != 2:
raise RuntimeError(
f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"
)
m, k = dense.shape
device = dense.device
meta_dtype = torch.int8
if dense.dtype == torch.int8:
meta_dtype = torch.int32
elif dense.dtype in [torch.half, torch.bfloat16, torch.float]:
meta_dtype = torch.int16
else:
raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
if quadbits_per_meta_elem not in (4, 8):
raise RuntimeError("Invalid number of elements per meta element calculated")
if meta_dtype == torch.int32:
if m % 16 != 0:
raise RuntimeError(
f"Number of rows of dense matrix {m} must be divisible by 16"
)
else:
if m % 32 != 0:
raise RuntimeError(
f"Number of rows of dense matrix {m} must be divisible by 32"
)
if k % (4 * quadbits_per_meta_elem) != 0:
raise RuntimeError(
f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"
)
if dense.dtype != torch.float32:
ksparse = 4
dense_4 = dense.view(-1, k // ksparse, ksparse)
m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
else:
ksparse = 2
dense_2 = dense.view(-1, k // ksparse, ksparse)
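        # For fp32 (1:2 sparsity), each group of two elements is mapped
        # onto the four-flag encoding below by duplicating each non-zero
        # flag, so a pair (a, b) behaves like the quadruple (a, a, b, b).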
m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
meta_ncols = k // (ksparse * quadbits_per_meta_elem)
# Encoding quadruples of True/False values as follows:
# [True, True, False, False] -> 0b0100
# [True, False, True, False] -> 0b1000
# [False, True, True, False] -> 0b1001
# [True, False, False, True ] -> 0b1100
# [False, True, False, True ] -> 0b1101
# [False, False, True, True ] -> 0b1110
    # Thus, the lower two bits in the encoding are the index of the True
    # value at the lowest index in the quadruple, and the higher two bits
    # in the encoding are the index of the other True value in the
    # quadruple. If there are fewer than two True values, then the False
    # value or values at some index or indices are considered True for
    # the encoding. If there are more than two True values, then the
    # excess True value(s) at some indices are considered False for the
    # encoding. The exact encodings used for these cases are as follows:
# [False, False, False, False] -> 0b1110
# [False, False, False, True ] -> 0b1110
# [False, False, True, False] -> 0b1110
# [False, True, False, False] -> 0b1101
# [False, True, True, True ] -> 0b1001
# [True, False, False, False] -> 0b1100
# [True, False, True, True ] -> 0b1000
# [True, True, False, True ] -> 0b0100
# [True, True, True, False] -> 0b1000
# [True, True, True, True ] -> 0b1000
    # These particular encodings were chosen, with the help of the
    # Espresso logic minimizer, to minimize the corresponding Boolean
    # functions that translate the non-zero flags into encoding bits.
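    # As a worked example of the formulas below, take the quadruple
    # [True, False, False, True] (m0=1, m1=0, m2=0, m3=1):
    #   bit0 = ~m0 & m1        = 0
    #   bit1 = ~m0 & ~m1       = 0   ->  idxs0 = 0b00 = 0 (first kept index)
    #   bit2 = bit1 | ~m2      = 1
    #   bit3 = bit0 | ~m1 | m2 = 1   ->  idxs1 = 0b11 = 3 (second kept index)
    # so the encoded nibble is idxs0 | (idxs1 << 2) = 0b1100, matching the
    # table above.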
bit0 = ~m0 & m1
bit1 = ~m0 & ~m1
bit2 = bit1 | ~m2
bit3 = bit0 | ~m1 | m2
idxs0 = bit0 | (bit1.to(torch.int64) << 1)
idxs1 = bit2 | (bit3.to(torch.int64) << 1)
if dense.dtype != torch.float32:
sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1))
sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
else:
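        # The four flags are duplicated pairs here (see above), so the
        # first kept index idxs0 is always 0 or 2, and idxs0 // 2 selects
        # which of the two fp32 values in each group is kept.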
sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2)
meta_4 = idxs0 | (idxs1 << 2)
meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
if quadbits_per_meta_elem == 4:
meta = (
meta_n[:, :, 0]
| (meta_n[:, :, 1] << 4)
| (meta_n[:, :, 2] << 8)
| (meta_n[:, :, 3] << 12)
)
elif quadbits_per_meta_elem == 8:
meta = (
meta_n[:, :, 0]
| (meta_n[:, :, 1] << 4)
| (meta_n[:, :, 2] << 8)
| (meta_n[:, :, 3] << 12)
| (meta_n[:, :, 4] << 16)
| (meta_n[:, :, 5] << 20)
| (meta_n[:, :, 6] << 24)
| (meta_n[:, :, 7] << 28)
)
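    # As a packing example: with int16 metadata, the four nibbles 0b0100,
    # 0b1110, 0b1000 and 0b1001 (lowest column index first) pack into
    # 0b1001_1000_1110_0100 = 0x98E4.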
    # Metadata values are now to be reshuffled in the way given by the
    # reorder_meta() function, in the
    # tools/util/include/cutlass/util/host_reorder.h file of the CUTLASS
    # source tree. Furthermore, the CUTLASS template for sparse GEMM
    # decides upon the layout of this matrix, and at the moment, for
    # sparse GEMM executed on tensor cores, this is the layout described
    # by the ColumnMajorInterleaved<2> data structure, in
    # include/cutlass/layout/matrix.h of the CUTLASS source tree. The
    # reordering of the meta matrix into the meta_reordered matrix,
    # calculated according to these segments of CUTLASS code, is given
    # below. However, this calculation produces offsets for scatter
    # access from the metadata matrix into the reordered metadata matrix,
    # and a gather pattern is more efficient. For this reason, the
    # scatter offsets are inverted and printed, by enabling the commented
    # block at the end of the following code. The resulting gather
    # offsets were then analyzed for several (m, k) value pairs (in
    # particular: (32, 128), (32, 256), (64, 128) and (64, 256)), and the
    # code that follows this comment was written to reproduce these
    # gather offsets.
#
# dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
# dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
#
# # Reorder the rows, then swizzle the 2x2 blocks.
# group = 32 if meta_dtype.itemsize == 2 else 16
# interweave = 4 if meta_dtype.itemsize == 2 else 2
# dst_rows = (
# dst_rows // group * group
# + (dst_rows % 8) * interweave
# + (dst_rows % group) // 8
# )
#
# topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
# bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
# dst_rows += topright - bottomleft
# dst_cols -= topright - bottomleft
#
# # Assumed that meta tensor is to be stored in CUTLASS
# # InterleavedColumnMajor layout, and reverse engineered
# # corresponding code to store values into this tensor.
# interleave = 2
# cols_maj = dst_cols // interleave
# cols_min = dst_cols % interleave
# meta_reordered_offsets = (
# cols_maj * m * interleave + dst_rows * interleave + cols_min
# )
#
# meta_reordered = torch.empty((m, meta_ncols), dtype=meta_dtype, device=device)
# meta_reordered.view(-1)[meta_reordered_offsets.view(-1)] = meta.view(-1)
#
# # Uncomment to have gather pattern for meta_reordered printed
# #
# #offsets = torch.empty(
# # (m, meta_ncols), dtype=meta_reordered_offsets.dtype, device=device
# #)
# #offsets.view(-1)[meta_reordered_offsets.view(-1)] = torch.arange(
# # 0, m * meta_ncols, dtype=meta_reordered_offsets.dtype, device=device
# #)
# #torch.set_printoptions(threshold=1000000)
# #print("------------------------------------------------------------")
# #print("dtype =", dtype, ", m =", m, ", k =", k, ", meta_ncols =", meta_ncols)
# #print(offsets.view(-1))
#
    # There is no point trying to understand this code: as mentioned in
    # the comment above, it is written to reproduce the gather offsets
    # that CUTLASS would calculate, and to do so efficiently, but it
    # contains several magic values and magic calculations that make it
    # rather hard to read, let alone understand.
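    # As a concrete instance of the magic values below (derived from the
    # formulas in the int16 branch): for int16 metadata with k = 128 and
    # ksparse = 4, magic0 = 8, magic1 = 64, magic2 = 32, magic3 = 256 and
    # magic4 = [0, 64, 1, 65, 128, 192, 129, 193].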
if meta_dtype == torch.int32:
magic0 = 4
magic1 = 32
magic2 = 16
magic3 = 2 * k // ksparse
magic4 = [0, k // ksparse, 1, k // ksparse + 1]
elif meta_dtype == torch.int16:
magic0 = 8
magic1 = 64
magic2 = 32
magic3 = 4 * 2 * k // ksparse
magic4 = [0, 2 * k // ksparse, 1, 2 * k // ksparse + 1, 2 * 2 * k // ksparse,
3 * 2 * k // ksparse, 2 * 2 * k // ksparse + 1, 3 * 2 * k // ksparse + 1]
tmp0 = torch.zeros(m * meta_ncols, dtype=torch.int64, device=device)
tmp1 = (
tmp0.view(meta_ncols // 2, -1)
+ torch.arange(0, meta_ncols, 2, device=device).view(meta_ncols // 2, 1)
).view(-1, magic1)
tmp2 = (
(
torch.arange(0, 8, device=device).view(-1, 1)
* torch.ones((magic0,), dtype=torch.int64, device=device)
* meta_ncols
)
.view(-1)
.repeat(m * meta_ncols // magic1)
.view(-1, magic1)
)
tmp3 = (torch.arange(0, m // magic2, device=device).view(-1, 1) * magic3).repeat(
meta_ncols // 2, magic1
)
tmp4 = torch.tensor(magic4, device=device).repeat(tmp3.shape[0], 8)
meta_offsets = tmp1 + tmp2 + tmp3 + tmp4
meta_reordered = torch.gather(meta.view(-1), 0, meta_offsets.view(-1)).view(
m, meta_ncols
)
return (sparse, meta_reordered)
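

# As a shape illustration (not exercised anywhere in this file): for an
# fp16 dense matrix of shape (64, 128), the sparse matrix returned above
# has shape (64, 64) and the int16 meta_reordered matrix has shape
# (64, 8), since meta_ncols = k // (ksparse * quadbits_per_meta_elem)
# = 128 // (4 * 4).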


def _sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
if sparse.dim() != 2:
raise RuntimeError(
f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor"
)
m, k = sparse.shape
device = sparse.device
if meta_reordered.dim() != 2:
raise RuntimeError(
f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor"
)
if meta_reordered.device != device:
raise RuntimeError(
f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device"
)
meta_dtype = meta_reordered.dtype
if meta_dtype not in (torch.int16, torch.int32):
raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
meta_nrows, meta_ncols = meta_reordered.shape
if meta_nrows != m:
raise RuntimeError(
f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}"
)
    if meta_ncols * 4 * quadbits_per_meta_elem != 2 * k:
        raise RuntimeError(
            f"Number of columns of sparse matrix {k} different from the "
            f"{meta_ncols * 4 * quadbits_per_meta_elem // 2} expected "
            "from the number of columns of meta matrix"
        )
if meta_dtype == torch.int32:
magic0 = 4
magic1 = [0, 1, 32, 33]
elif meta_dtype == torch.int16:
magic0 = 8
magic1 = [0, 1, 4, 5]
tmp1 = torch.tensor([0, 2], dtype=torch.int64, device=device).repeat(
meta_nrows, meta_ncols // 2
)
tmp2 = (
(torch.arange(0, meta_ncols // 2, device=device) * 2 * meta_nrows)
.view(-1, 1)
.repeat(1, 2)
.view(-1)
.repeat(m, 1)
)
tmp3 = (
(torch.arange(0, 8, device=device) * magic0)
.view(-1, 1)
.repeat(m // 8, meta_ncols)
)
tmp4 = (
torch.tensor(magic1, device=device)
.view(-1, 1)
.repeat(1, 8 * meta_ncols)
.repeat(meta_nrows // 32, 1)
.view(meta_nrows, meta_ncols)
)
tmp5 = (
(torch.arange(0, meta_nrows // 32, device=device) * 64)
.view(-1, 1)
.repeat(1, 32 * meta_ncols)
.view(meta_nrows, meta_ncols)
)
meta_offsets = tmp1 + tmp2 + tmp3 + tmp4 + tmp5
meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets.view(-1)).view(
m, meta_ncols
)
meta_2 = torch.empty(
(m, meta_ncols, 2 * quadbits_per_meta_elem), dtype=meta_dtype, device=device
)
if quadbits_per_meta_elem == 4:
meta_2[:, :, 0] = meta & 0b11
meta_2[:, :, 1] = (meta >> 2) & 0b11
meta_2[:, :, 2] = (meta >> 4) & 0b11
meta_2[:, :, 3] = (meta >> 6) & 0b11
meta_2[:, :, 4] = (meta >> 8) & 0b11
meta_2[:, :, 5] = (meta >> 10) & 0b11
meta_2[:, :, 6] = (meta >> 12) & 0b11
meta_2[:, :, 7] = (meta >> 14) & 0b11
elif quadbits_per_meta_elem == 8:
meta_2[:, :, 0] = meta & 0b11
meta_2[:, :, 1] = (meta >> 2) & 0b11
meta_2[:, :, 2] = (meta >> 4) & 0b11
meta_2[:, :, 3] = (meta >> 6) & 0b11
meta_2[:, :, 4] = (meta >> 8) & 0b11
meta_2[:, :, 5] = (meta >> 10) & 0b11
meta_2[:, :, 6] = (meta >> 12) & 0b11
meta_2[:, :, 7] = (meta >> 14) & 0b11
meta_2[:, :, 8] = (meta >> 16) & 0b11
meta_2[:, :, 9] = (meta >> 18) & 0b11
meta_2[:, :, 10] = (meta >> 20) & 0b11
meta_2[:, :, 11] = (meta >> 22) & 0b11
meta_2[:, :, 12] = (meta >> 24) & 0b11
meta_2[:, :, 13] = (meta >> 26) & 0b11
meta_2[:, :, 14] = (meta >> 28) & 0b11
meta_2[:, :, 15] = (meta >> 30) & 0b11
dense_offsets = meta_2.view(-1) + (
torch.arange(0, m * k // 2, device=device) * 4
).view(-1, 1).repeat(1, 2).view(-1)
dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
dense.scatter_(0, dense_offsets, sparse.view(-1))
return dense.view(m, 2 * k)
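
# Continuing the packing example from the function above: unpacking the
# int16 value 0x98E4 with the shifts in the quadbits_per_meta_elem == 4
# branch recovers the two-bit index pairs (0, 1), (2, 3), (0, 2) and
# (1, 2), i.e. the nibbles 0b0100, 0b1110, 0b1000 and 0b1001.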


# This function converts a dense matrix into the sparse semi-structured
# representation, producing a "compressed" matrix in the layout used by
# the CUTLASS backend, and the corresponding metadata matrix.
def sparse_semi_structured_from_dense_cutlass(dense, compile=False):
if compile:
from torch._dynamo.utils import is_compile_supported
if is_compile_supported(dense.device.type):
kernel = torch.compile(_sparse_semi_structured_from_dense_cutlass)
return kernel(dense)
return _sparse_semi_structured_from_dense_cutlass(dense)


# This function performs the reverse of the function above: it
# reconstructs the dense matrix from a "compressed" matrix, given in the
# layout used by the CUTLASS backend, and the accompanying metadata
# matrix.
def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered, compile=False):
if compile:
from torch._dynamo.utils import is_compile_supported
if is_compile_supported(sparse.device.type):
kernel = torch.compile(_sparse_semi_structured_to_dense_cutlass)
return kernel(sparse, meta_reordered)
return _sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered)
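

# A minimal round-trip sketch of the two public functions above. This demo
# block is an addition for illustration, not part of the original module;
# the shapes are chosen to satisfy the constraints checked in
# _sparse_semi_structured_from_dense_cutlass (and to match one of the
# (m, k) pairs the gather offsets were analyzed on).
if __name__ == "__main__":
    m, k = 32, 128
    # Build an fp16 matrix with an exact 2:4 pattern: elements 0 and 2 of
    # every group of four are kept, and all kept values are non-zero.
    mask = torch.tensor([True, False, True, False]).repeat(m, k // 4)
    dense = (torch.rand(m, k) + 1).half() * mask
    sparse, meta = sparse_semi_structured_from_dense_cutlass(dense)
    assert sparse.shape == (m, k // 2)
    assert meta.shape == (m, k // 16) and meta.dtype == torch.int16
    dense_back = sparse_semi_structured_to_dense_cutlass(sparse, meta)
    # The round trip is exact because the input already had a 2:4 pattern.
    assert torch.equal(dense_back, dense)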