From 89f1ad08b43cbbe7d7d0629d899b9e088c30478f Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot
Date: Tue, 17 Jan 2023 22:14:37 +0000
Subject: [PATCH] Revert "Improve `bsr @ strided` performance in `baddmm` for
 `bfloat16/half` with Triton kernels. (#88078)"

This reverts commit 7f256fff77c49729131aa6d092e60e891d0c4948.

Reverted https://github.com/pytorch/pytorch/pull/88078 on behalf of https://github.com/huydhn due to This breaks lint https://hud.pytorch.org/pytorch/pytorch/commit/7f256fff77c49729131aa6d092e60e891d0c4948
---
 aten/src/ATen/native/native_functions.yaml    |   6 -
 .../src/ATen/native/sparse/SparseBlasImpl.cpp |  30 -
 .../native/sparse/SparseCsrTensorMath.cpp     |   7 -
 aten/src/ATen/native/sparse/SparseMatMul.cpp  |   1 +
 mypy.ini                                      |   3 -
 test/test_sparse_csr.py                       |  58 --
 torch/__init__.py                             |   4 -
 torch/sparse/__init__.py                      |  29 -
 torch/sparse/_triton_ops.py                   | 608 ------------------
 9 files changed, 1 insertion(+), 745 deletions(-)
 delete mode 100644 torch/sparse/_triton_ops.py

diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 80edd6421f3..84d71c31754 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6446,12 +6446,6 @@
     SparseCPU: s_addmm_sparse_dense_cpu_
     SparseCUDA: s_addmm_sparse_dense_cuda_
 
-- func: _triton_bsr_dense_mm(Tensor bsr, Tensor dense) -> Tensor
-  variants: function
-  dispatch:
-    CPU: triton_bsr_dense_mm
-  autogen: _triton_bsr_dense_mm.out
-
 - func: _addmm_activation.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False, Tensor(a!) out) -> Tensor(a!)
   structured: True
   dispatch:
diff --git a/aten/src/ATen/native/sparse/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/SparseBlasImpl.cpp
index c147e8c7090..cdeb3e134e5 100644
--- a/aten/src/ATen/native/sparse/SparseBlasImpl.cpp
+++ b/aten/src/ATen/native/sparse/SparseBlasImpl.cpp
@@ -4,10 +4,6 @@
 #include 
 #include 
 
-// Required for checking whether Triton kernels are available
-#include 
-#include 
-
 #ifndef AT_PER_OPERATOR_HEADERS
 #include 
 #include 
@@ -16,7 +12,6 @@
 #include 
 #include 
 #include 
-#include 
 #endif
 
 namespace at {
@@ -75,31 +70,6 @@ Tensor& _compressed_row_strided_mm_out(const Tensor& compressed, const Tensor& s
     blocksize = {values.size(-2), values.size(-1)};
   }
 
-// No stable support for ROCM in Triton yet.
-#ifndef USE_ROCM
-  // Triton works only with blocksizes which are powers of 2.
-  const auto is_power_of_2 = [](int64_t v) -> bool {
-    return !(v & (v - 1));
-  };
-
-  // Dtype and blocksize checks for potential Triton usage.
-  if ((strided.scalar_type() == ScalarType::Half
-       || strided.scalar_type() == ScalarType::BFloat16)
-      && is_power_of_2(blocksize[0]) && is_power_of_2(blocksize[1])
-      && (blocksize[0] >= 16) && (blocksize[1] >= 16)
-      // lhs is retiled to (b0, b1) while rhs is to (b1, b0),
-      // so the result is tiled to (b0, b0) and we need to make
-      // sure that dense.size(-1) is divisible by b0.
-      && n % blocksize[0] == 0) {
-    const auto triton_kernel = c10::Dispatcher::singleton()
-      .findOp(torch::jit::parseName("aten::_triton_bsr_dense_mm"));
-    // Call Triton only if dispatch key was overwritten.
-    if (triton_kernel->hasKernelForDispatchKey(c10::DispatchKey::SparseCsrCUDA)) {
-      return at::_triton_bsr_dense_mm_out(result, compressed, strided);
-    }
-  }
-#endif
-
   // (..., r, c) -> (..., r / b0, c / b1, b0, b1)
   // NOTE: this function ALWAYS creates a view upon successful execution.
   const auto tile_tensor = [compressed_layout](
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
index f407b7bb641..efa692665d4 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
@@ -1292,12 +1292,5 @@ Tensor _sparse_csr_prod_cpu(const Tensor& input, IntArrayRef dims_to_reduce, boo
   return result;
 }
 
-Tensor triton_bsr_dense_mm(
-    const Tensor& bsr,
-    const Tensor& dense) {
-  TORCH_CHECK(false, "_triton_bsr_dense_mm: Triton kernel should be overwritten in Python.");
-  return Tensor {};
-}
-
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/sparse/SparseMatMul.cpp b/aten/src/ATen/native/sparse/SparseMatMul.cpp
index e5f283bd452..548b66ae46d 100644
--- a/aten/src/ATen/native/sparse/SparseMatMul.cpp
+++ b/aten/src/ATen/native/sparse/SparseMatMul.cpp
@@ -274,5 +274,6 @@ Tensor sparse_sparse_matmul_cpu(const Tensor& mat1_, const Tensor& mat2_) {
   return output;
 }
 
+
 } // namespace native
 } // namespace at
diff --git a/mypy.ini b/mypy.ini
index 7108feea21d..4afe7dcf125 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -188,9 +188,6 @@ ignore_errors = True
 
 # Third party dependencies that don't have types.
 #
 
-[mypy-triton.*]
-ignore_missing_imports = True
-
 [mypy-tensorflow.*]
 ignore_missing_imports = True
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index b1bfe598086..30606d15b85 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -20,7 +20,6 @@ from torch.testing._internal.common_dtype import (
     floating_types, all_types_and_complex_and, floating_and_complex_types, floating_types_and,
     all_types_and_complex, floating_and_complex_types_and
 )
-from torch._inductor.utils import has_triton
 from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED
 
 if TEST_SCIPY:
@@ -1465,63 +1464,6 @@ class TestSparseCSR(TestCase):
                 self.assertEqual(actual, out)
                 self.assertEqual(actual, expected)
 
-    @parametrize("block_size", [16, 32, 64])
-    @parametrize("index_dtype", [torch.int32, torch.int64])
-    @unittest.skipIf(not has_triton(), "Triton is not available")
-    @skipCUDAIfRocm
-    @onlyCUDA
-    @dtypes(torch.half, torch.bfloat16)
-    @dtypesIfCUDA(*[torch.half] if SM53OrLater else [],
-                  *[torch.bfloat16] if SM80OrLater else [])
-    def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
-        from functools import partial
-
-        # Note that each value in a non-zero block is in range block_size * [low^2, high^2).
-        tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)
-
-        # NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
-        batches = [(), (2,)]
-        size = [128, 256, 0]
-
-        # Whether to make inputs orthogonal so that the product is zero
-        make_orthogonal = [True, False]
-
-        for bd, bs, m, n, k, is_ortho in itertools.product(batches, batches, size, size, size, make_orthogonal):
-            bsr = tensor(bs + (m, k))
-            # NOTE: do not get confused, it will be transposed
-            dense = tensor(bd + (n, k))
-
-            if is_ortho:
-                bsr = torch.cat((bsr, torch.zeros_like(bsr)), dim=-1)
-                dense = torch.cat((torch.zeros_like(dense), dense), dim=-1)
-
-            bsr = bsr.to_sparse_bsr(block_size)
-
-            if bsr.dim() == 2:
-                # Test against linear to check dispatch.
-                res_tri = torch.nn.functional.linear(dense, bsr)
-                res_dense = torch.nn.functional.linear(dense, bsr.to_dense())
-            else:
-                # Otherwise check correctness against bmm
-                # since nn.linear does not support bsr.dim() > 2.
-                res_tri = torch._triton_bsr_dense_mm(bsr, dense.transpose(-2, -1))
-                res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
-            self.assertEqual(res_tri, res_dense)
-
-            res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
-            # check whether bsr_dense_mm handles different grid sizes
-            # None means max possible grid size which is CUDA-dependent.
-            grid_size = (None, 2, 4)
-            grid_gen = itertools.product(grid_size, repeat=3)
-            for is_sparse_rowspace, grid in itertools.product((True, False), grid_gen):
-                res_tri = torch.sparse._triton_ops.bsr_dense_mm(
-                    bsr,
-                    dense.transpose(-2, -1),
-                    max_grid=grid,
-                    is_sparse_rowspace_mode=is_sparse_rowspace
-                )
-                self.assertEqual(res_tri, res_dense)
-
     # TODO: block_size 1 is broken
     @parametrize("block_size", [2, 3])
     @parametrize("index_dtype", [torch.int32, torch.int64])
diff --git a/torch/__init__.py b/torch/__init__.py
index 90f1972ad19..5810b44cd51 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -1328,7 +1328,3 @@ import torch.fx.experimental.symbolic_shapes
 
 from torch import func as func
 from torch.func import vmap
-
-# dynamic registration of sparse triton kernels
-from torch.sparse import _register_impls
-_register_impls(torch.library.Library("aten", "IMPL"))
diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py
index a7a909e5f7b..3ceaf56fc20 100644
--- a/torch/sparse/__init__.py
+++ b/torch/sparse/__init__.py
@@ -4,8 +4,6 @@ from typing import Optional, Tuple, List, Union
 import torch
 from torch._C import _add_docstr, _sparse  # type: ignore[attr-defined]
 from torch import Tensor
-from torch.cuda import _lazy_call
-from torch._inductor.cuda_properties import get_device_capability
 
 # A workaround to support both TorchScript and MyPy:
 from typing import TYPE_CHECKING
@@ -464,30 +462,3 @@ See :func:`torch.sparse.check_sparse_tensor_invariants.enable` for more informat
             return mth(*args, **kwargs)
 
         return test_mth
-
-# Triton registrations
-def _has_triton():
-    if not torch.cuda.is_available():
-        return False
-    try:
-        import triton
-
-        return triton is not None and get_device_capability() >= (7, 0)
-    except ImportError:
-        return False
-
-
-def _register_impls(lib):
-    """This function is called from torch/__init__.py to do any dynamic registrations. """
-
-
-    def register_sparse_cuda_impls(lib=lib):
-        from ._triton_ops import bsr_dense_mm
-
-        if bsr_dense_mm is not None:
-            lib.impl("aten::_triton_bsr_dense_mm",
-                     lambda *args, **kwargs: bsr_dense_mm(*args, skip_checks=True, **kwargs), "SparseCsrCUDA")
-
-    # This code is evaluated on import torch and therefore cannot force initialization of the cuda rt
-    # We must schedule the registration to occur lazily.
-    _lazy_call(register_sparse_cuda_impls)
diff --git a/torch/sparse/_triton_ops.py b/torch/sparse/_triton_ops.py
deleted file mode 100644
index d7b34f34905..00000000000
--- a/torch/sparse/_triton_ops.py
+++ /dev/null
@@ -1,608 +0,0 @@
-import torch
-from torch._inductor.cuda_properties import get_device_capability
-
-def _has_triton():
-    if not torch.cuda.is_available():
-        return False
-    try:
-        import triton
-
-        return triton is not None and get_device_capability() >= (7, 0)
-    except ImportError:
-        return False
-
-def compressed_indices_to_plain_indices(cidx, pidx):
-    nnz = pidx.shape[-1]
-    cdim = cidx.shape[-1] - 1
-    batch_numel = cidx.shape[0]
-    batch_offset = torch.arange(batch_numel, dtype=cidx.dtype, device=cidx.device)[
-        :, None
-    ]
-
-    cidx_batch_offsetted = cidx[:, :-1] + nnz * batch_offset
-    cidx_linear = torch.empty(
-        (batch_numel * cdim + 1,), dtype=cidx.dtype, device=cidx.device
-    )
-    cidx_linear[:-1] = cidx_batch_offsetted.reshape(-1)
-    cidx_linear[-1] = nnz * batch_numel
-
-    idx_linear = torch._convert_indices_from_csr_to_coo(
-        cidx_linear, pidx.reshape(-1), out_int32=(cidx.dtype == torch.int32)
-    ).select(0, 0)
-
-    return idx_linear.reshape(batch_numel, -1).sub_(cdim * batch_offset)
-
-
-def slicer(dim, slice_range, *tensors):
-    for t in tensors:
-        slices = [slice(None)] * t.dim()
-        slices[dim] = slice_range
-        yield t[slices]
-
-if _has_triton():
-    import triton
-    import triton.language as tl
-    from typing import Optional, Tuple
-
-    @triton.jit
-    def _bsr_strided_dense_rowspace_kernel(
-        BLOCKSIZE_ROW: tl.constexpr,
-        BLOCKSIZE_COL: tl.constexpr,
-        # values prologue
-        values_ptr,
-        values_batch_stride,
-        values_nnz_stride,
-        values_row_block_stride,
-        values_col_block_stride,
-        # values epilogue
-        # crow_indices prologue
-        crow_indices_ptr,
-        crow_indices_batch_stride,
-        crow_indices_stride,
-        # crow_indices epilogue
-        # col_indices prologue
-        col_indices_ptr,
-        col_indices_batch_stride,
-        col_indices_stride,
-        # col_indices epilogue
-        # dense prologue
-        dense_ptr,
-        dense_batch_stride,
-        dense_tiled_row_stride,
-        dense_tiled_col_stride,
-        dense_row_block_stride,
-        dense_col_block_stride,
-        # dense epilogue
-        # output prologue
-        output_ptr,
-        output_batch_stride,
-        output_tiled_row_stride,
-        output_tiled_col_stride,
-        output_row_block_stride,
-        output_col_block_stride,
-        # output epilogue
-        GROUP_SIZE_ROW: tl.constexpr,
-    ):
-        batch_pid = tl.program_id(axis=2)
-        row_block_pid = tl.program_id(axis=0)
-        col_block_pid = tl.program_id(axis=1)
-        n_block_rows = tl.num_programs(axis=0)
-        n_block_cols = tl.num_programs(axis=1)
-
-        row_block_pid, col_block_pid = tl.swizzle2d(
-            row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW
-        )
-
-        crow_indices_offset_ptr = (
-            crow_indices_ptr
-            + crow_indices_batch_stride * batch_pid
-            + crow_indices_stride * row_block_pid
-        )
-        nnz_offset = tl.load(crow_indices_offset_ptr)
-        nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)
-
-        # Compute nnz for the row with number row_block_pid.
-        # If it is zero, skip the row.
-        row_nnz = nnz_offset_next - nnz_offset
-        if row_nnz == 0:
-            return
-
-        row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
-        col_block_arange = tl.arange(0, BLOCKSIZE_COL)
-
-        # Pointers are set to the first block of the current row.
-        values_block_ptrs = (
-            values_ptr
-            + values_batch_stride * batch_pid
-            + values_nnz_stride * nnz_offset
-            + values_row_block_stride * row_block_arange[:, None]
-            + values_col_block_stride * col_block_arange[None, :]
-        )
-
-        # NOTE: dense is advanced into all dimensions but the tiled row one.
-        # That will be advanced in the loop according to values in col_indices.
-        dense_block_ptrs = (
-            dense_ptr
-            + dense_batch_stride * batch_pid
-            + dense_tiled_col_stride * col_block_pid
-            + dense_row_block_stride * col_block_arange[:, None]
-            + dense_col_block_stride * row_block_arange[None, :]
-        )
-
-        # Pointers are set to exact write-to locations
-        output_ptrs = (
-            output_ptr
-            + output_batch_stride * batch_pid
-            + output_tiled_row_stride * row_block_pid
-            + output_tiled_col_stride * col_block_pid
-            + output_row_block_stride * row_block_arange[:, None]
-            + output_col_block_stride * row_block_arange[None, :]
-        )
-
-        # Set pointer to the first nonzero element in the current row
-        col_index_nnz_ptr = (
-            col_indices_ptr
-            + col_indices_batch_stride * batch_pid
-            + col_indices_stride * nnz_offset
-        )
-
-        output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_ROW), tl.float32)
-        for _ in range(row_nnz):
-            values_block = tl.load(values_block_ptrs)
-
-            # find which row of dense needs to get loaded
-            # for multiplication with values_block.
-            dense_row_idx = tl.load(col_index_nnz_ptr)
-            dense_block = tl.load(dense_block_ptrs + dense_tiled_row_stride * dense_row_idx)
-
-            # do block mm
-            output_acc_block += tl.dot(values_block, dense_block)
-
-            # move val/col_index ptrs to the next block in the row
-            values_block_ptrs += values_nnz_stride
-            col_index_nnz_ptr += col_indices_stride
-
-        # write back the result
-        tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))
-
-
-    @triton.jit
-    def _bsr_strided_sparse_rowspace_kernel(
-        BLOCKSIZE_ROW: tl.constexpr,
-        BLOCKSIZE_COL: tl.constexpr,
-        batch_idx_ptr,
-        row_idx_ptr,
-        nnz_per_row_ptr,
-        nnz_per_row_cumsum_ptr,
-        col_indices_ptr,
-        col_indices_stride,
-        # values prologue
-        values_ptr,
-        values_nnz_stride,
-        values_row_block_stride,
-        values_col_block_stride,
-        # values epilogue
-        # dense prologue
-        dense_ptr,
-        dense_batch_stride,
-        dense_tiled_row_stride,
-        dense_tiled_col_stride,
-        dense_row_block_stride,
-        dense_col_block_stride,
-        # dense epilogue
-        # output prologue
-        output_ptr,
-        output_batch_stride,
-        output_tiled_row_stride,
-        output_tiled_col_stride,
-        output_row_block_stride,
-        output_col_block_stride,
-        # output epilogue
-        GROUP_SIZE_ROW: tl.constexpr,
-    ):
-        row_block_pid = tl.program_id(axis=0)
-        col_block_pid = tl.program_id(axis=1)
-        n_block_rows = tl.num_programs(axis=0)
-        n_block_cols = tl.num_programs(axis=1)
-
-        row_block_pid, col_block_pid = tl.swizzle2d(
-            row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW
-        )
-
-        batch_idx = tl.load(batch_idx_ptr + row_block_pid)
-        row_idx = tl.load(row_idx_ptr + row_block_pid)
-        row_idx_nnz = tl.load(nnz_per_row_ptr + row_block_pid)
-        row_idx_nnz_cumsum = tl.load(nnz_per_row_cumsum_ptr + row_block_pid)
-        row_idx_nnz_offset = row_idx_nnz_cumsum - row_idx_nnz
-
-        row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
-        col_block_arange = tl.arange(0, BLOCKSIZE_COL)
-
-        # Pointers are set to the first block of the current row.
-        values_block_ptrs = (
-            values_ptr
-            + values_nnz_stride * row_idx_nnz_offset
-            + values_row_block_stride * row_block_arange[:, None]
-            + values_col_block_stride * col_block_arange[None, :]
-        )
-
-        # NOTE: dense is advanced into all dimensions but the tiled row one.
-        # That will be advanced in the loop according to values in col_indices.
-        dense_block_ptrs = (
-            dense_ptr
-            + dense_batch_stride * batch_idx
-            + dense_tiled_col_stride * col_block_pid
-            + dense_row_block_stride * col_block_arange[:, None]
-            + dense_col_block_stride * row_block_arange[None, :]
-        )
-
-        # Pointers are set to exact write-to locations
-        output_ptrs = (
-            output_ptr
-            + output_batch_stride * batch_idx
-            + output_tiled_row_stride * row_idx
-            + output_tiled_col_stride * col_block_pid
-            + output_row_block_stride * row_block_arange[:, None]
-            + output_col_block_stride * row_block_arange[None, :]
-        )
-
-        output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_ROW), tl.float32)
-        col_index_nnz_ptr = col_indices_ptr + row_idx_nnz_offset * col_indices_stride
-        for _ in range(row_idx_nnz):
-            values_block = tl.load(values_block_ptrs)
-
-            # find which row of dense needs to get loaded
-            # for multiplication with values_block.
-            dense_row_idx = tl.load(col_index_nnz_ptr)
-            dense_block = tl.load(dense_block_ptrs + dense_tiled_row_stride * dense_row_idx)
-
-            # do block mm
-            output_acc_block += tl.dot(values_block, dense_block)
-
-            # move val/col_index ptrs to the next block in the row
-            values_block_ptrs += values_nnz_stride
-            col_index_nnz_ptr += col_indices_stride
-
-        # write back the result
-        tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))
-
-
-    def _run_sparse_rowspace_kernel(
-        blocksize, values, crow_indices, col_indices, dense, output, max_grid
-    ):
-        # Compute a vector of non-zero elements numbers per each row.
-        # We want to ultimately iterate over non-zero rows.
-        nnz_per_row = crow_indices[:, 1:] - crow_indices[:, :-1]
-
-        # Compute indices of non-zero counts.
-        # batch_idx maps to a broadcasted batch index, while
-        # row_idx tracks non-zero rows of the sparse argument
-        # and rows of the output that get modified.
-        batch_idx, row_idx = nnz_per_row.nonzero(as_tuple=True)
-
-        # Compress the vector of counts to hold only non-zero values.
-        nnz_per_row = nnz_per_row[batch_idx, row_idx]
-        # Compute cumulative counts which along with nnz_per_row
-        # are used to compute offsets into nnz values.
-        nnz_per_row_cumsum = nnz_per_row.cumsum(-1)
-
-        n_nnz_block_rows = row_idx.size(-1)
-        n_block_cols = dense.size(-3)
-        max_n_nnz_block_rows, max_n_block_cols = max_grid[:2]
-
-        for c_start in range(0, n_block_cols, max_n_block_cols):
-            c_dense, c_output = slicer(
-                -3, slice(c_start, c_start + max_n_block_cols), dense, output
-            )
-            c_grid = min(n_block_cols - c_start, max_n_block_cols)
-
-            for r_start in range(0, n_nnz_block_rows, max_n_nnz_block_rows):
-                r_batch_idx, r_row_idx, r_nnz_per_row, r_nnz_per_row_cumsum = slicer(
-                    0,
-                    slice(r_start, r_start + max_n_nnz_block_rows),
-                    batch_idx,
-                    row_idx,
-                    nnz_per_row,
-                    nnz_per_row_cumsum,
-                )
-                r_grid = min(n_nnz_block_rows - r_start, max_n_nnz_block_rows)
-
-                _bsr_strided_sparse_rowspace_kernel[(r_grid, c_grid)](
-                    *blocksize,
-                    r_batch_idx,
-                    r_row_idx,
-                    r_nnz_per_row,
-                    r_nnz_per_row_cumsum,
-                    col_indices,
-                    *col_indices.stride(),
-                    values,
-                    *values.stride(),
-                    c_dense,
-                    *c_dense.stride(),
-                    c_output,
-                    *c_output.stride(),
-                    GROUP_SIZE_ROW=4,
-                    num_stages=4,
-                    num_warps=4,
-                )
-
-
-    def _run_dense_rowspace_kernel(
-        blocksize, values, crow_indices, col_indices, dense, output, max_grid
-    ):
-        # Launch kernel
-        n_batches = dense.size(0)
-        n_block_rows = crow_indices.size(-1) - 1
-        n_block_cols = dense.size(-3)
-        max_n_block_rows, max_n_block_cols, max_n_batches = max_grid
-
-        for b_start in range(0, n_batches, max_n_batches):
-            b_v, b_crow, b_col, b_d, b_o = slicer(
-                0,
-                slice(b_start, b_start + max_n_batches),
-                values,
-                crow_indices,
-                col_indices,
-                dense,
-                output,
-            )
-            b_grid = min(n_batches - b_start, max_n_batches)
-
-            for c_start in range(0, n_block_cols, max_n_block_cols):
-                bc_d, bc_o = slicer(
-                    -3, slice(c_start, c_start + max_n_block_cols), b_d, b_o
-                )
-                c_grid = min(n_block_cols - c_start, max_n_block_cols)
-
-                for r_start in range(0, n_block_rows, max_n_block_rows):
-                    r_slice = slice(r_start, r_start + max_n_block_rows)
-                    br_crow = next(slicer(-1, r_slice, b_crow))
-                    brc_o = next(slicer(-4, r_slice, bc_o))
-                    r_grid = min(n_block_rows - r_start, max_n_block_rows)
-
-                    _bsr_strided_dense_rowspace_kernel[(r_grid, c_grid, b_grid)](
-                        *blocksize,
-                        b_v,
-                        *b_v.stride(),
-                        br_crow,
-                        *br_crow.stride(),
-                        b_col,
-                        *b_col.stride(),
-                        bc_d,
-                        *bc_d.stride(),
-                        brc_o,
-                        *brc_o.stride(),
-                        GROUP_SIZE_ROW=4,
-                        num_stages=4,
-                        num_warps=4,
-                    )
-
-
-    def bsr_dense_mm(
-        bsr: torch.Tensor,
-        dense: torch.Tensor,
-        *,
-        skip_checks: bool = False,
-        is_sparse_rowspace_mode: Optional[bool] = None,
-        max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
-        out: Optional[torch.Tensor] = None,
-    ):
-        m, kl = bsr.shape[-2:]
-        kr, n = dense.shape[-2:]
-
-        def check(cond, msg):
-            if not cond:
-                raise ValueError(msg)
-
-        if not skip_checks:
-            check(
-                bsr.layout == torch.sparse_bsr,
-                "bsr_dense_mm(): only BSR sparse format is supported for the sparse argument.",
-            )
-
-            check(
-                bsr.device == dense.device and bsr.device.type == "cuda",
-                "bsr_dense_mm(): all inputs are expected to be on the same GPU device.",
-            )
-
-            check(
-                bsr.dtype == dense.dtype
-                and bsr.dtype in (torch.half, torch.bfloat16, torch.float),
-                "bsr_dense_mm(): all inputs are expected to be of the same dtype "
-                "and one of (half, bfloat16, float32), "
-                f"but got bsr.dtype == {bsr.dtype} and dense.dtype == {dense.dtype}.",
-            )
-
-            check(
-                bsr.dim() >= 2 and dense.dim() >= 2,
-                "bsr_dense_mm(): all inputs are expected to be at least 2D, "
-                f"but got bsr.dim() == {bsr.dim()} and dense.dim() == {dense.dim()}.",
-            )
-
-            check(
-                kl == kr,
-                "bsr_dense_mm(): argument sizes are not compatible for matrix multiplication, "
-                f"got bsr.shape[-1] == {kl} which is not equal to dense.shape[-2] == {kr}.",
-            )
-
-            row_block = bsr.values().shape[-2]
-            check(
-                not n % row_block,
-                f"bsr_dense_mm(): dense.size(-1) == {n} should be divisible by "
-                f"blocksize[0] == {row_block}.",
-            )
-
-        # Required to undo the fake batch dimension insertion.
-        original_batch_dims_broadcasted = torch.broadcast_shapes(
-            bsr.shape[:-2], dense.shape[:-2]
-        )
-
-        if out is not None and not skip_checks:
-            expected_out_shape = original_batch_dims_broadcasted + (m, n)
-            check(
-                out.shape == expected_out_shape,
-                "bsr_dense_mm(): `out` argument has wrong shape, "
-                f"expected {expected_out_shape}, but got {out.shape}.",
-            )
-            check(
-                out.is_contiguous() or out.transpose(-2, -1).is_contiguous(),
-                "bsr_dense_mm(): only row-major/col-major `out` arguments are supported, "
-                "i.e. (out.is_contiguous() or out.transpose(-2, -1).is_contiguous()) "
-                "should be True.",
-            )
-
-        # Short circuit if lhs is zero
-        if bsr._nnz() == 0:
-            return dense.new_zeros(original_batch_dims_broadcasted + (m, n))
-
-        # TODO: insert switch
-        if is_sparse_rowspace_mode is None:
-            is_sparse_rowspace_mode = False
-
-        # Introduce fake batch dimension if not present for convenience.
-        def unsqueeze_batch_dim(t, n_non_batch_dims):
-            if t.dim() > n_non_batch_dims:
-                return t
-            else:
-                return t.unsqueeze(0)
-
-        def make_triton_contiguous(t):
-            # Triton does not distinguish between row- and col-majorness
-            # and will be fast as long as there is a contiguous dimension.
-            if not (t.is_contiguous() or t.transpose(-2, -1).is_contiguous()):
-                return t.contiguous()
-            else:
-                return t
-
-        crow_indices = unsqueeze_batch_dim(bsr.crow_indices(), 1)
-        col_indices = unsqueeze_batch_dim(bsr.col_indices(), 1)
-        values = make_triton_contiguous(unsqueeze_batch_dim(bsr.values(), 3))
-        dense = make_triton_contiguous(unsqueeze_batch_dim(dense, 2))
-        nnz = values.shape[-3]
-        blocksize = values.shape[-2:]
-
-        # Compute broadcasted batch dimension
-        bsr_batch_dims = values.shape[:-3]
-        dense_batch_dims = dense.shape[:-2]
-        batch_dims_broadcasted = torch.broadcast_shapes(bsr_batch_dims, dense_batch_dims)
-
-        # Allocate out
-        if out is None:
-            out = dense.new_zeros(batch_dims_broadcasted + (m, n))
-
-        # Broadcast batch dimensions and squash
-        def batch_broadcast_and_squash(t, batch_dims, invariant_dims):
-            return t.broadcast_to(batch_dims + invariant_dims).flatten(
-                0, len(batch_dims) - 1
-            )
-
-        crow_indices = batch_broadcast_and_squash(
-            crow_indices, batch_dims_broadcasted, (-1,)
-        )
-
-        if is_sparse_rowspace_mode:
-            # Flatten batch dimension with nnz dimension
-            # as required by the sparse rowspace kernel.
-            col_indices = batch_broadcast_and_squash(
-                col_indices, batch_dims_broadcasted + (-1,), ()
-            )
-            values = batch_broadcast_and_squash(
-                values, batch_dims_broadcasted + (values.shape[-3],), values.shape[-2:]
-            )
-        else:
-            col_indices = batch_broadcast_and_squash(
-                col_indices, batch_dims_broadcasted, (-1,)
-            )
-            values = batch_broadcast_and_squash(
-                values, batch_dims_broadcasted, values.shape[-3:]
-            )
-
-        dense = batch_broadcast_and_squash(dense, batch_dims_broadcasted, dense.shape[-2:])
-
-        # NOTE: out is contiguous, so batch_broadcast_and_squash will create a view
-        out = batch_broadcast_and_squash(out, batch_dims_broadcasted, out.shape[-2:])
-
-        # NOTE: this function will ALWAYS create a view
-        def tile_to_blocksize(t, blocksize):
-            *rest, m, n = t.shape
-            new_shape = rest + [
-                m // blocksize[0],
-                blocksize[0],
-                n // blocksize[1],
-                blocksize[1],
-            ]
-            return t.reshape(new_shape).transpose(-3, -2)
-
-        # "Blockify" the row dimension of dense with blocksize[1]
-        # since dense is on the rhs of matmul
-        dense = tile_to_blocksize(dense, blocksize[::-1])
-        # "Blockify" the row dimension of out with blocksize[0]
-        # which is inherited from the bsr input.
-        # NOTE: tile_to_blocksize will create a view.
-        # NOTE: out.blocksize[-1] == dense.blocksize[-1],
-        # so it could be any value in [1, dense.shape[-1]).
-        # We need to probably use the largest possible blocksize
-        # so that it fits into SRAM.
-        out = tile_to_blocksize(out, (blocksize[0], blocksize[0]))
-
-        # Launch kernel
-        if is_sparse_rowspace_mode:
-            kernel = _run_sparse_rowspace_kernel
-        else:
-            kernel = _run_dense_rowspace_kernel
-
-        # cuda_max_grid = (2 ** 31 - 1, 2 ** 16 - 1, 2 ** 16 - 1)
-        cuda_max_grid = (2147483647, 65535, 65535)
-        if max_grid is None:
-            max_grid = cuda_max_grid
-        else:
-
-            def valid_grid_dim(g, mg):
-                if g is None:
-                    return mg
-                else:
-                    # grid must be at least 1 and no greater than mg
-                    return max(1, min(g, mg))
-
-            max_grid = tuple(
-                valid_grid_dim(g, mg) for g, mg in zip(max_grid, cuda_max_grid)
-            )  # type: ignore[assignment]
-
-        kernel(blocksize, values, crow_indices, col_indices, dense, out, max_grid)
-
-        # Block dims need to rejoin with the corresponding block dimensions
-        # prior to reshape so that blocks do not end up being transposed.
-        # NB: type checker is not able to narrow Optional[Tensor] to tensor by this point
-        return out.transpose(-3, -2).reshape(original_batch_dims_broadcasted + (m, n))  # type: ignore[union-attr]
-else:
-    bsr_dense_mm = None  # type: ignore[assignment]
-
-
-if __name__ == "__main__":
-    from torch._inductor.utils import has_triton
-
-    if has_triton():
-        torch.manual_seed(13)
-        dtype = torch.float32
-        p = 0.5
-        mask_size = (8, 8)
-        block_size = (64, 64)
-        size = (mask_size[0] * block_size[0], mask_size[1] * block_size[1])
-
-        n_exp = 512
-        diff = torch.ones(n_exp, device="cuda", dtype=torch.float32)
-        for i in range(n_exp):
-            mask = torch.rand(*mask_size, device="cuda") < p
-            x = torch.rand(*mask_size, *block_size, dtype=dtype, device="cuda") / 10
-            x = (
-                (mask[:, :, None, None] * x)
-                .transpose(-3, -2)
-                .reshape(*size)
-                .to_sparse_bsr(*block_size)
-            )
-            y = torch.rand(5, *size, dtype=dtype, device="cuda") / 10
-            res_dense = x.to_dense() @ y
-            res = bsr_dense_mm(x, y)
-            diff[i] = (res - res_dense).abs().max()
-        print(f"mean: {diff.mean()}, std: {diff.std()}")
-        print(f"max diff: {diff.max()}")
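
For context, a minimal sketch of how the reverted torch.sparse._triton_ops.bsr_dense_mm
entry point was driven, mirroring the deleted test and __main__ block above. It assumes a
pre-revert build with Triton installed and a CUDA device of compute capability >= 7.0; the
shapes, block size, and tolerances below are illustrative assumptions, not part of the patch.

    import torch
    # Pre-revert module; removed by this patch.
    from torch.sparse._triton_ops import bsr_dense_mm

    # The Triton path expects power-of-2 block sizes >= 16 and
    # dense.size(-1) divisible by the row block size.
    x = torch.rand(256, 256, dtype=torch.half, device="cuda").to_sparse_bsr((32, 32))
    y = torch.rand(256, 128, dtype=torch.half, device="cuda")

    res = bsr_dense_mm(x, y)          # Triton-backed block-sparse @ dense
    res_dense = x.to_dense() @ y      # dense reference
    torch.testing.assert_close(res, res_dense, rtol=2e-2, atol=2e-2)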