# Owner(s): ["module: sparse"]
import itertools
import random
import unittest

import torch
from torch import nn

from torch.sparse.semi_structured import (
    _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG,
    SparseSemiStructuredTensor,
    to_sparse_semi_structured,
)

from torch.testing import make_tensor

from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
)

from torch.testing._internal.common_dtype import all_types_and_complex

from torch.testing._internal.common_utils import (
    parametrize,
    run_tests,
    subtest,
    TestCase,
    TEST_WITH_ROCM
)

from torch.utils._triton import has_triton


SEMI_STRUCTURED_SUPPORTED_DTYPES = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG.keys()
SEMI_STRUCTURED_SUPPORTED_BACKENDS = []

_IS_SM8X = False
if torch.cuda.is_available():
    _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
    SEMI_STRUCTURED_SUPPORTED_BACKENDS.append("cutlass")

    # For now, check whether cuSPARSELt is available by attempting a compression.
    # TODO: once cuSPARSELt is added as a backend, update this to use torch.cusparselt.is_available()
    try:
        torch._cslt_compress(torch.ones(128, 128).cuda())
        SEMI_STRUCTURED_SUPPORTED_BACKENDS.append("cusparselt")
    except Exception:
        pass
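
# A minimal sketch of the conversion / matmul flow exercised by the tests below,
# assuming a CUDA device with 2:4 sparse support (e.g. SM 8.x):
#
#   A = rand_sparse_semi_structured_mask(128, 128, dtype=torch.float16)  # helper defined below
#   A_sparse = to_sparse_semi_structured(A)       # compressed values + metadata
#   out = torch.mm(A_sparse, torch.rand(128, 128, dtype=torch.float16, device="cuda"))
#
# The sparse tensor keeps the dense tensor's logical shape, dtype, and device.
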

def rand_sparse_semi_structured_mask(
    r, c, dtype=torch.float16, device="cuda", choice=None
):
    """
    This function returns a 1:2 sparse matrix of size (r, c).
    Note that a 1:2 sparse matrix is also 2:4 and 4:8 sparse.
    """
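    # For example, the row [0, 1, 1, 0, 0, 1, 1, 0] has exactly one nonzero in
    # every group of 2 elements, hence two nonzeros per group of 4 and four per
    # group of 8.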
    choices = [[0, 1], [1, 0]]
    mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)]

    return (
        torch.tensor(mask_entries, dtype=dtype, device=device)
        .reshape(r, c)
        .contiguous()
    )


def rand_dense_2by4(r, c, dtype, device, choice=None):
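    # All six ways of keeping exactly two out of four contiguous elements,
    # i.e. every admissible 2:4 group pattern.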
    choices = [
        [1, 1, 0, 0],
        [1, 0, 1, 0],
        [1, 0, 0, 1],
        [0, 1, 1, 0],
        [0, 1, 0, 1],
        [0, 0, 1, 1]
    ]
    mask_entries = [choice or random.choice(choices) for i in range(r * c // 4)]
    mask = torch.tensor(mask_entries, dtype=torch.bool).view(r, c).to(device)
    dense = make_tensor(r, c, dtype=dtype, device=device)
    dense[dense == 0] = 1  # Ensure the only zeros are the ones introduced by the mask below.
    dense = dense.masked_fill(~mask, 0)
    return dense


def rand_dense_2by4_all_patterns(r, c, dtype, device):
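    # Each entry pairs one of the 16 possible 4-element masks (used to build the
    # input) with the 2:4 pattern expected to survive the dense -> sparse -> dense
    # round trip for that input.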
    choices = [
        [[0, 0, 0, 0], [0, 0, 1, 1]],
        [[0, 0, 0, 1], [0, 0, 1, 1]],
        [[0, 0, 1, 0], [0, 0, 1, 1]],
        [[0, 0, 1, 1], [0, 0, 1, 1]],
        [[0, 1, 0, 0], [0, 1, 0, 1]],
        [[0, 1, 0, 1], [0, 1, 0, 1]],
        [[0, 1, 1, 0], [0, 1, 1, 0]],
        [[0, 1, 1, 1], [0, 1, 1, 0]],
        [[1, 0, 0, 0], [1, 0, 0, 1]],
        [[1, 0, 0, 1], [1, 0, 0, 1]],
        [[1, 0, 1, 0], [1, 0, 1, 0]],
        [[1, 0, 1, 1], [1, 0, 1, 0]],
        [[1, 1, 0, 0], [1, 1, 0, 0]],
        [[1, 1, 0, 1], [1, 1, 0, 0]],
        [[1, 1, 1, 0], [1, 0, 1, 0]],
        [[1, 1, 1, 1], [1, 0, 1, 0]],
    ]
    COL_INV, COL_VAL = 0, 1
    mask_rows = [random.randint(0, len(choices) - 1) for i in range(r * c // 4)]
    mask_entries_inv = [choices[i][COL_INV] for i in mask_rows]
    mask_entries_val = [choices[i][COL_VAL] for i in mask_rows]
    mask_inv = torch.tensor(mask_entries_inv, dtype=torch.bool).view(r, c).to(device)
    mask_val = torch.tensor(mask_entries_val, dtype=torch.bool).view(r, c).to(device)
    dense = make_tensor(r, c, dtype=dtype, device=device)
    dense[dense == 0] = 1  # Ensure the only zeros are the ones introduced by the masks below.
    dense_inv = dense.masked_fill(~mask_inv, 0)
    dense_val = dense_inv.masked_fill(~mask_val, 0)
    return dense_inv, dense_val


class TestSparseSemiStructured(TestCase):

    def setUp(self):
        if not _IS_SM8X:
            self.skipTest('Only runs on SM80')

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_to_sparse_semi_structured(self, dtype, backend):
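        # Conversion keeps the logical shape, dtype, and device of the dense
        # tensor while storing its values in compressed form.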
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")

        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        assert A.shape == A_sparse.shape
        assert A.device == A_sparse.device
        assert A.dtype == A_sparse.dtype

        assert isinstance(A, torch.Tensor)
        assert isinstance(A_sparse, SparseSemiStructuredTensor)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_mm_sparse_first_NT(self, dtype, device, backend):
        """
        Ensure torch.mm(A_sparse, B) is correct for float16 and raises an error for int8.
        Ensure torch.mm(A_sparse, B.t()) is correct.
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")

        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand((128, 128), device=A_sparse.device).to(dtype)

        # Int matmul is currently unsupported on GPU, so compute the dense reference on CPU and copy it over.
        if dtype is torch.int8:
            # This should fail
            if backend == "cutlass":
                with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_cutlass_dispatch_layouts"):
                    sparse_result = torch.mm(A_sparse, B)
            else:
                with self.assertRaisesRegex(RuntimeError,
                                            "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
                    sparse_result = torch.mm(A_sparse, B)

            # test transpose
            # NOTE: CUTLASS and cuSPARSELt have slightly different int8 behavior:
            # CUTLASS outputs an int32 tensor while cuSPARSELt outputs an int8 tensor.
            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32 if backend == "cutlass" else torch.int8)
            sparse_result = torch.mm(A_sparse, B.t())
            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
        else:
            dense_result = torch.mm(A, B)
            sparse_result = torch.mm(A_sparse, B)
            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
            # test transpose
            dense_result = torch.mm(A, B.t())
            sparse_result = torch.mm(A_sparse, B.t())
            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_mm_sparse_first_T(self, dtype, device, backend):
        """
        Ensure torch.mm(A_sparse.t(), B) throws an error.
        """
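        # Matmul with the sparse operand transposed is not implemented, so this
        # should raise NotImplementedError; the regex below matches the repr of
        # the transposed SparseSemiStructuredTensor.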
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand((128, 128), device=A_sparse.device).to(dtype)

        with self.assertRaisesRegex(
            NotImplementedError,
            r"arg0: SparseSemiStructuredTensor\(.*transposed=True",
        ):
            torch.mm(A_sparse.t(), B)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_mm_sparse_second_T(self, dtype, device, backend):
        """
        Ensure torch.mm(A, B_sparse.t()) is correct.
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        B = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        B_sparse = to_sparse_semi_structured(B)

        A = torch.rand((128, 128), device=B_sparse.device).to(dtype)

        # Int matmul is currently unsupported on GPU, so compute the dense reference on CPU and copy it over.
        if dtype is torch.int8:
            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32 if backend == "cutlass" else torch.int8)
            sparse_result = torch.mm(A, B_sparse.t())
        else:
            dense_result = torch.mm(A, B.t())
            sparse_result = torch.mm(A, B_sparse.t())

        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    def test_cslt_sparse_mm_int8_in_fp16_out(self, device):
        """
        This test is only needed for cuSPARSELt.
        """
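        # torch._cslt_sparse_mm accepts an out_dtype argument, so an int8 x int8
        # matmul can produce fp16 output directly; the dense reference is computed
        # on CPU and cast to fp16 for comparison.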
        if "cusparselt" in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
            SparseSemiStructuredTensor._FORCE_CUTLASS = False
            A = rand_sparse_semi_structured_mask(128, 128, dtype=torch.int8)
            A_sparse = to_sparse_semi_structured(A)

            B = torch.rand((128, 128), device=A_sparse.device).to(torch.int8)

            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.float16)
            sparse_result = torch._cslt_sparse_mm(A_sparse.compressed_tensor_cusparselt, B.t(), out_dtype=torch.float16)
            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_mm_sparse_second_NT(self, dtype, device, backend):
        """
        Ensure torch.mm(A, B_sparse) throws an error.
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        B = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        B_sparse = to_sparse_semi_structured(B)

        A = torch.rand((128, 128), device=B_sparse.device).to(dtype)

        with self.assertRaisesRegex(
            NotImplementedError,
            r"arg1: SparseSemiStructuredTensor\(.*transposed=False",
        ):
            sparse_result = torch.mm(A, B_sparse)

    @parametrize("inference_mode", [subtest(True), subtest(False)])
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_linear(self, inference_mode, device, backend):
        """
        Test that nn.Linear produces the same numerics with a dense 2:4-masked
        weight and with the corresponding semi-structured sparse weight.
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        input = torch.rand(64, 128, 128, device=device).half()
        model = nn.Linear(128, 128).to(device).half()
        m, n = model.weight.shape
        mask = rand_sparse_semi_structured_mask(m, n, device=device, dtype=torch.bool)
        # set masked weight
        model.weight = nn.Parameter(model.weight * mask)

        dense_result = model(input)

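        # Swap the dense (masked) weight for its semi-structured sparse form; the
        # module's forward pass should still work and match the dense result.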
        model.weight = nn.Parameter(to_sparse_semi_structured(model.weight))

        if inference_mode:
            with torch.inference_mode():
                sparse_result = model(input)
        else:
            sparse_result = model(input)

        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_values(self, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(128, 128)
        A_sparse = to_sparse_semi_structured(A)
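        # 2:4 compression keeps exactly half of the elements, so a 128x128 matrix
        # compresses to a 128x64 values tensor; every kept value of the 0/1 mask is 1.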
        assert A_sparse.values().shape == (128, 64)
        assert (A_sparse.values() == 1).all()

    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_indices(self, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(128, 128)
        A_sparse = to_sparse_semi_structured(A)
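        # The positions of the kept elements are stored as packed metadata
        # (roughly 2 bits per kept element, packed into 16-bit words), which for
        # this 128x128 float16 input works out to a (128, 8) tensor.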
        assert A_sparse.indices().shape == (128, 8)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_unsupported_shape(self, dtype, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
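        # A 4x4 matrix is below the minimum supported size for every dtype, so
        # conversion is expected to reject its shape.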
        A = rand_sparse_semi_structured_mask(4, 4, dtype=dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, "Error original_tensor.shape"):
            A_sparse = to_sparse_semi_structured(A)

    @dtypes(*all_types_and_complex())
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_unsupported_dtype(self, dtype, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype, device=device)

        if dtype not in SEMI_STRUCTURED_SUPPORTED_DTYPES:
            with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dtype"):
                A_sparse = to_sparse_semi_structured(A)
        else:
            A_sparse = to_sparse_semi_structured(A)

    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_unsupported_dim(self, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = torch.rand(128, 128, 128, device=device, dtype=torch.float16)

        with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dim"):
            A_sparse = to_sparse_semi_structured(A)

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
    @parametrize("backend", ["cutlass"])
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_linear_cutlass(self, device, dtype, backend):
        if dtype is not torch.float32:
            SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")

            def run_test(batch_shape, m, n, k, device, dtype, dtype_out, add_bias, activation, rtol, atol):
                weight = rand_dense_2by4(m, k, dtype, device)
                input = make_tensor((*batch_shape, n, k), dtype=dtype, device=device)
                bias = make_tensor((m,), dtype=dtype_out, device=device) if add_bias else None

                dtype_dense = torch.float
                input_dense = input.to(dtype_dense)
                weight_dense = weight.to(dtype_dense)
                bias_dense = bias.to(dtype_dense) if add_bias else None
                output0 = torch.nn.functional.linear(input_dense, weight_dense, bias=bias_dense)
                if activation == "relu":
                    relu = torch.nn.ReLU()
                    output0 = relu(output0)
                elif activation == "silu":
                    silu = torch.nn.SiLU()
                    output0 = silu(output0)

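                # Pack the kept (nonzero) values of the 2:4 weight into the compressed
                # (m, k // 2) layout consumed by torch._sparse_semi_structured_linear.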
                weight_sparse = weight.masked_select(weight != 0).view(m, k // 2)

                meta = to_sparse_semi_structured(weight).indices()

                output1 = torch._sparse_semi_structured_linear(input, weight_sparse, meta, bias=bias, activation=activation)
                torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol)

            batch_shapes = [[], [3], [3, 1]]
            dtype_out = {torch.int8: torch.int32, torch.half: torch.half, torch.bfloat16: torch.bfloat16}
            activations = [None, "relu", "silu"]
            rtol, atol = 1e-3, 1e-3
            if dtype == torch.bfloat16:
                rtol, atol = 5e-3, 5e-3
            for batch_shape, m, n, k, add_bias, activation in \
                    itertools.product(batch_shapes, range(3), range(3), range(3), (False, True), activations):
                if activation == "silu" and dtype == torch.int8:
                    continue  # SiLU not supported for integer inputs

                m = 2 ** m * 32
                n = 2 ** n * 32
                k = 2 ** k * 128
                run_test(batch_shape, m, n, k, device, dtype, dtype_out[dtype], add_bias, activation, rtol, atol)

    @unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch")
    @parametrize("backend", ["cutlass"])
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_conversions(self, device, dtype, backend):
        if dtype is not torch.float32:
            SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")

            def run_test(r, c, device, dtype):
                dense_ref = rand_dense_2by4(r, c, dtype, device)

                compressed = to_sparse_semi_structured(dense_ref)

                # The torch.ops.aten._to_sparse_semi_structured operator uses CUTLASS
                # to convert the given dense matrix into the corresponding pair of
                # sparse and metadata matrices; the latter is used here as a
                # reference against which to compare the metadata matrix produced
                # by the SparseSemiStructuredTensor constructor.
                _, meta_ref = torch.ops.aten._to_sparse_semi_structured(dense_ref)
                meta = compressed.indices()
                torch.testing.assert_close(meta, meta_ref, rtol=0, atol=0)

                dense = compressed.to_dense()
                torch.testing.assert_close(dense, dense_ref, rtol=0, atol=0)

            shapes = [[32, 128], [32, 256], [64, 128], [64, 256]]
            for r, c in shapes:
                run_test(r, c, device, dtype)

    @unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch")
    @parametrize("backend", ["cutlass"])
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_conversions_all_patterns(self, device, dtype, backend):
        if dtype is not torch.float32:
            SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
            r, c = 32, 128

            dense_inv, dense_val = rand_dense_2by4_all_patterns(r, c, dtype, device)

            compressed = to_sparse_semi_structured(dense_inv)
            dense = compressed.to_dense()

            torch.testing.assert_close(dense, dense_val, rtol=0, atol=0)


instantiate_device_type_tests(TestSparseSemiStructured, globals(), only_for="cuda")


if __name__ == "__main__":
    run_tests()