import argparse
import random

import pandas as pd
from tqdm import tqdm

import torch
import torch.utils.benchmark as benchmark
from torch import nn
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured

torch.set_printoptions(
    precision=2,
    threshold=None,
    edgeitems=16,
    linewidth=480,
    profile=None,
    sci_mode=False,
)


# helper model definition for pruner
class Model(nn.Module):
    def __init__(self, m, k, dtype=None):
        super().__init__()
        # nn.Linear stores its weight as (out_features, in_features), so the
        # (m, k) weight is constructed with the dimensions reversed
        self.linear = nn.Linear(k, m)

    def forward(self, x):
        return self.linear(x)


def rand_sparse_semi_structured_mask(
    r, c, dtype=torch.float16, device="cuda", choice=None
):
    """
    This function returns a 1:2 sparse matrix of size (r, c).
    Note that this means this matrix will also be 2:4 and 4:8 sparse as well.
    """
    choices = [[0, 1], [1, 0]]
    mask_entries = [choice or random.choice(choices) for _ in range(r * c // 2)]

    return (
        torch.tensor(mask_entries, dtype=dtype, device=device)
        .reshape(r, c)
        .contiguous()
    )


def test_linear(m, k, n, dtype, contiguous, backend):
    SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
    mask = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
    sparse_weight = torch.rand(m, k).to(dtype).cuda() * mask
    input_tensor = torch.zeros(n, k).to(dtype).cuda()
    model = Model(m, k).to(dtype).cuda().eval()

    dense_measurement = benchmark.Timer(
        stmt="model(input_tensor)",
        globals=locals(),
    ).blocked_autorange()

    dense_output = model(input_tensor)
    print(dense_output.shape)

    # sparsify weights
    model.linear.weight = nn.Parameter(to_sparse_semi_structured(sparse_weight))

    sparse_output = model(input_tensor)
    print(sparse_output.shape)

    sparse_measurement = benchmark.Timer(
        stmt="model(input_tensor)",
        globals=locals(),
    ).blocked_autorange()

    correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)

    return {
        "test_function": "linear",
        "m": m,
        "k": k,
        "n": n,
        "dtype": str(dtype),
        "backend": backend,
        "sparse_latency (ms)": sparse_measurement.median * 1000,
        "dense_latency (ms)": dense_measurement.median * 1000,
        "speedup (d/s)": dense_measurement.median / sparse_measurement.median,
        "correct": correct,
        "contiguous": sparse_output.is_contiguous(),
    }


def test_tensor(m, k, n, dtype, contiguous, backend):
    A = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
    B = torch.zeros(k, n).to(dtype).cuda()
    bias = torch.rand(n).to(dtype).cuda()

    sA = to_sparse_semi_structured(A)

    # torch.mm calculation
    if dtype is not torch.int8:
        dense_output = torch.mm(A, B)

        dense_measurement = benchmark.Timer(
            stmt="torch.mm(A, B)",
            globals=locals(),
        ).blocked_autorange()
    else:
        print("int8 baseline not supported")
        # dense int8 matmul is unsupported, so the sparse result doubles as
        # the reference; the correctness check is trivially true in this case
        dense_output = torch.mm(sA, B)

        dense_measurement = benchmark.Timer(
            stmt="torch.mm(sA, B)",
            globals=locals(),
        ).blocked_autorange()

    sparse_output = torch.mm(sA, B)
    sparse_measurement = benchmark.Timer(
        stmt="torch.mm(sA, B)",
        globals=locals(),
    ).blocked_autorange()

    correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)

    return {
        "test_function": "tensor",
        "m": m,
        "k": k,
        "n": n,
        "dtype": str(dtype),
        "backend": backend,
        "sparse_latency (ms)": sparse_measurement.median * 1000,
        "dense_latency (ms)": dense_measurement.median * 1000,
        "speedup (d/s)": dense_measurement.median / sparse_measurement.median,
        "correct": correct,
        "contiguous": sparse_output.is_contiguous(),
    }
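
# Quick sanity-check sketch (comment only, not run by the benchmark). It
# mirrors the correctness check in test_tensor on a small fp16 matrix and
# assumes a CUDA GPU with 2:4 sparsity support (compute capability 8.0+);
# the (128, 256) shape is an illustrative choice, not a requirement:
#
#   A = rand_sparse_semi_structured_mask(128, 256)
#   A_sparse = to_sparse_semi_structured(A)
#   B = torch.rand(256, 128, dtype=torch.float16, device="cuda")
#   torch.allclose(torch.mm(A, B), torch.mm(A_sparse, B), rtol=1e-3, atol=1e-3)  # -> True
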
if __name__ == "__main__":
    dtype_lookup = {
        "int8": torch.int8,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }

    parser = argparse.ArgumentParser(description="Semi-Structured Sparsity Benchmarks")
    parser.add_argument(
        "--mode",
        type=str,
        choices=[
            "nvidia-bert",
            "nvidia-fixed-k",
            "nvidia-fixed-mn",
        ],
        # without a mode, no shapes are generated and `results` below would
        # be undefined, so make it mandatory
        required=True,
    )
    parser.add_argument(
        "--dtype",
        type=str,
        choices=dtype_lookup.keys(),
        default="fp16",
    )
    parser.add_argument(
        "--backend", type=str, choices=["cutlass", "cusparselt"], default="cusparselt"
    )
    parser.add_argument("-contiguous", action="store_true")
    parser.add_argument("-e2e", action="store_true")
    parser.add_argument("-save", action="store_true")
    args = parser.parse_args()

    if args.e2e:
        eval_fn = test_linear
    else:
        eval_fn = test_tensor

    print(f"Started benchmark: {args.mode} | dtype: {args.dtype}")
    dtype = dtype_lookup[args.dtype]

    if args.mode == "nvidia-bert":
        bert_shapes = [
            (3072, 1024, 16384),
            (4096, 1024, 16384),
            (1024, 1024, 16384),
            (1024, 4096, 16384),
        ]
        results = (
            eval_fn(m, k, n, dtype, args.contiguous, args.backend)
            for (m, k, n) in tqdm(bert_shapes)
        )
    elif args.mode == "nvidia-fixed-k":
        mn_vals = [
            3072, 4096, 5120, 6144, 7168, 8192, 9216, 10240, 11264,
            12288, 13312, 14336, 15360, 16384, 17408, 18432, 19456, 20480,
        ]
        results = (
            eval_fn(mn, 10240, mn, dtype, args.contiguous, args.backend)
            for mn in tqdm(mn_vals)
        )
    elif args.mode == "nvidia-fixed-mn":
        k_vals = [
            2560, 3840, 5120, 6400, 7680, 8960, 10240, 11520,
            12800, 14080, 15360, 16640, 17920, 19200, 20480,
        ]
        results = (
            eval_fn(10240, k, 10240, dtype, args.contiguous, args.backend)
            for k in tqdm(k_vals)
        )

    df = pd.DataFrame.from_records(results)
    if args.save:
        save_file = f"{args.mode}_{args.dtype}_{args.backend}.csv"
        df.to_csv(save_file)
        print(f"Finished benchmark: {args.mode} saved results to {save_file}")
    print(df)
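
# Example invocations (the file name is hypothetical; adjust to wherever this
# script lives). Note the boolean flags use a single dash, matching the
# argparse definitions above:
#
#   python benchmark_semi_structured_sparsity.py --mode nvidia-bert --dtype fp16 --backend cusparselt
#   python benchmark_semi_structured_sparsity.py --mode nvidia-fixed-k --dtype bf16 --backend cutlass -save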