Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 00:20:18 +01:00)
Improve bsr @ strided performance in baddmm for bfloat16/half with Triton kernels. (#88078)
As per title. We additionally introduce support for:
- Rectangular block sizes which are powers of 2 and at least 16 (Triton's `dot` limitation).
- Batch support with broadcasting for either of the arguments.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88078
Approved by: https://github.com/cpuhrsch
This commit is contained in:
parent befe815466
commit 7f256fff77
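For orientation, here is a minimal sketch of the fast path this commit enables (assuming a CUDA build with Triton installed; half needs sm53+ and bfloat16 sm80+, matching the test gating below). A BSR weight whose block size is a power of 2 and at least 16, multiplying a strided tensor, is routed to the Triton kernel registered in torch/sparse/__init__.py:

    import torch

    # Hypothetical sizes; blocksize (32, 32) passes the power-of-2 and >= 16 checks.
    weight = torch.randn(256, 512, dtype=torch.half, device="cuda").to_sparse_bsr(32)
    x = torch.randn(64, 512, dtype=torch.half, device="cuda")
    # linear(x, weight) computes x @ weight.t() through the BSR addmm path.
    y = torch.nn.functional.linear(x, weight)
    assert y.shape == (64, 256)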
aten/src/ATen/native/native_functions.yaml
@@ -6446,6 +6446,12 @@
     SparseCPU: s_addmm_sparse_dense_cpu_
     SparseCUDA: s_addmm_sparse_dense_cuda_
 
+- func: _triton_bsr_dense_mm(Tensor bsr, Tensor dense) -> Tensor
+  variants: function
+  dispatch:
+    CPU: triton_bsr_dense_mm
+  autogen: _triton_bsr_dense_mm.out
+
 - func: _addmm_activation.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False, Tensor(a!) out) -> Tensor(a!)
   structured: True
   dispatch:
aten/src/ATen/native/sparse/SparseBlasImpl.cpp
@@ -4,6 +4,10 @@
 #include <ATen/native/sparse/SparseBlasImpl.h>
 #include <ATen/SparseCsrTensorUtils.h>
 
+// Required for checking whether Triton kernels are available
+#include <torch/csrc/jit/frontend/function_schema_parser.h>
+#include <ATen/core/dispatch/Dispatcher.h>
+
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
@@ -12,6 +16,7 @@
 #include <ATen/ops/_convert_indices_from_csr_to_coo.h>
 #include <ATen/ops/empty_like.h>
 #include <ATen/ops/zeros.h>
+#include <ATen/ops/_triton_bsr_dense_mm.h>
 #endif
 
 namespace at {
@@ -70,6 +75,31 @@ Tensor& _compressed_row_strided_mm_out(const Tensor& compressed, const Tensor& s
     blocksize = {values.size(-2), values.size(-1)};
   }
 
+  // No stable support for ROCM in Triton yet.
+#ifndef USE_ROCM
+  // Triton works only with blocksizes which are powers of 2.
+  const auto is_power_of_2 = [](int64_t v) -> bool {
+    return !(v & (v - 1));
+  };
+
+  // Dtype and blocksize checks for potential Triton usage.
+  if ((strided.scalar_type() == ScalarType::Half
+          || strided.scalar_type() == ScalarType::BFloat16)
+      && is_power_of_2(blocksize[0]) && is_power_of_2(blocksize[1])
+      && (blocksize[0] >= 16) && (blocksize[1] >= 16)
+      // lhs is retiled to (b0, b1) while rhs is to (b1, b0),
+      // so the result is tiled to (b0, b0) and we need to make
+      // sure that dense.size(-1) is divisible by b0.
+      && n % blocksize[0] == 0) {
+    const auto triton_kernel = c10::Dispatcher::singleton()
+        .findOp(torch::jit::parseName("aten::_triton_bsr_dense_mm"));
+    // Call Triton only if dispatch key was overwritten.
+    if (triton_kernel->hasKernelForDispatchKey(c10::DispatchKey::SparseCsrCUDA)) {
+      return at::_triton_bsr_dense_mm_out(result, compressed, strided);
+    }
+  }
+#endif
+
   // (..., r, c) -> (..., r / b0, c / b1, b0, b1)
   // NOTE: this function ALWAYS creates a view upon successful execution.
   const auto tile_tensor = [compressed_layout](
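The `is_power_of_2` lambda uses the standard bit trick: a power of two has exactly one set bit, so `v & (v - 1)` clears it to zero. Note that it also accepts `v == 0`, which is harmless here because the surrounding condition additionally requires `blocksize >= 16`. A quick Python check of the same predicate:

    def is_power_of_2(v: int) -> bool:
        return not (v & (v - 1))

    assert [v for v in range(1, 65) if is_power_of_2(v)] == [1, 2, 4, 8, 16, 32, 64]
    assert is_power_of_2(0)  # edge case, filtered out by the >= 16 check above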
aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
@@ -1292,5 +1292,12 @@ Tensor _sparse_csr_prod_cpu(const Tensor& input, IntArrayRef dims_to_reduce, boo
   return result;
 }
 
+Tensor triton_bsr_dense_mm(
+    const Tensor& bsr,
+    const Tensor& dense) {
+  TORCH_CHECK(false, "_triton_bsr_dense_mm: Triton kernel should be overwritten in Python.");
+  return Tensor {};
+}
+
 } // namespace native
 } // namespace at
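The native function above is only a stub that raises; the usable kernel is registered from Python under the SparseCsrCUDA key (see the torch/sparse/__init__.py hunk further down), which is exactly what the `hasKernelForDispatchKey` gate detects. A minimal sketch of that override pattern with `torch.library` (`my_bsr_dense_mm` is a hypothetical stand-in, not the PR's kernel):

    import torch

    lib = torch.library.Library("aten", "IMPL")

    def my_bsr_dense_mm(bsr, dense):
        # stand-in implementation; the PR registers the Triton-based bsr_dense_mm
        return bsr.to_dense() @ dense

    # Registering for SparseCsrCUDA flips the C++ gate to the fast path.
    lib.impl("aten::_triton_bsr_dense_mm", my_bsr_dense_mm, "SparseCsrCUDA")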
aten/src/ATen/native/sparse/SparseMatMul.cpp
@@ -274,6 +274,5 @@ Tensor sparse_sparse_matmul_cpu(const Tensor& mat1_, const Tensor& mat2_) {
   return output;
 }
 
-
 } // namespace native
 } // namespace at
mypy.ini
@@ -188,6 +188,9 @@ ignore_errors = True
 # Third party dependencies that don't have types.
 #
 
+[mypy-triton.*]
+ignore_missing_imports = True
+
 [mypy-tensorflow.*]
 ignore_missing_imports = True
 
test/test_sparse_csr.py
@@ -20,6 +20,7 @@ from torch.testing._internal.common_dtype import (
     floating_types, all_types_and_complex_and, floating_and_complex_types, floating_types_and,
     all_types_and_complex, floating_and_complex_types_and
 )
+from torch._inductor.utils import has_triton
 from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED
 
 if TEST_SCIPY:
@@ -1464,6 +1465,63 @@ class TestSparseCSR(TestCase):
         self.assertEqual(actual, out)
         self.assertEqual(actual, expected)
 
+    @parametrize("block_size", [16, 32, 64])
+    @parametrize("index_dtype", [torch.int32, torch.int64])
+    @unittest.skipIf(not has_triton(), "Triton is not available")
+    @skipCUDAIfRocm
+    @onlyCUDA
+    @dtypes(torch.half, torch.bfloat16)
+    @dtypesIfCUDA(*[torch.half] if SM53OrLater else [],
+                  *[torch.bfloat16] if SM80OrLater else [])
+    def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
+        from functools import partial
+
+        # Note that each value in a non-zero block is in range block_size * [low^2, high^2).
+        tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)
+
+        # NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
+        batches = [(), (2,)]
+        size = [128, 256, 0]
+
+        # Whether to make inputs orthogonal so that the product is zero
+        make_orthogonal = [True, False]
+
+        for bd, bs, m, n, k, is_ortho in itertools.product(batches, batches, size, size, size, make_orthogonal):
+            bsr = tensor(bs + (m, k))
+            # NOTE: do not get confused, it will be transposed
+            dense = tensor(bd + (n, k))
+
+            if is_ortho:
+                bsr = torch.cat((bsr, torch.zeros_like(bsr)), dim=-1)
+                dense = torch.cat((torch.zeros_like(dense), dense), dim=-1)
+
+            bsr = bsr.to_sparse_bsr(block_size)
+
+            if bsr.dim() == 2:
+                # Test against linear to check dispatch.
+                res_tri = torch.nn.functional.linear(dense, bsr)
+                res_dense = torch.nn.functional.linear(dense, bsr.to_dense())
+            else:
+                # Otherwise check correctness against bmm
+                # since nn.linear does not support bsr.dim() > 2.
+                res_tri = torch._triton_bsr_dense_mm(bsr, dense.transpose(-2, -1))
+                res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
+            self.assertEqual(res_tri, res_dense)
+
+            res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
+            # check whether bsr_dense_mm handles different grid sizes
+            # None means max possible grid size which is CUDA-dependent.
+            grid_size = (None, 2, 4)
+            grid_gen = itertools.product(grid_size, repeat=3)
+            for is_sparse_rowspace, grid in itertools.product((True, False), grid_gen):
+                res_tri = torch.sparse._triton_ops.bsr_dense_mm(
+                    bsr,
+                    dense.transpose(-2, -1),
+                    max_grid=grid,
+                    is_sparse_rowspace_mode=is_sparse_rowspace
+                )
+                self.assertEqual(res_tri, res_dense)
+
     # TODO: block_size 1 is broken
     @parametrize("block_size", [2, 3])
     @parametrize("index_dtype", [torch.int32, torch.int64])
torch/__init__.py
@@ -1328,3 +1328,7 @@ import torch.fx.experimental.symbolic_shapes
 
 from torch import func as func
 from torch.func import vmap
+
+# dynamic registration of sparse triton kernels
+from torch.sparse import _register_impls
+_register_impls(torch.library.Library("aten", "IMPL"))
|
|
@ -4,6 +4,8 @@ from typing import Optional, Tuple, List, Union
|
||||||
import torch
|
import torch
|
||||||
from torch._C import _add_docstr, _sparse # type: ignore[attr-defined]
|
from torch._C import _add_docstr, _sparse # type: ignore[attr-defined]
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
from torch.cuda import _lazy_call
|
||||||
|
from torch._inductor.cuda_properties import get_device_capability
|
||||||
|
|
||||||
# A workaround to support both TorchScript and MyPy:
|
# A workaround to support both TorchScript and MyPy:
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
@ -462,3 +464,30 @@ See :func:`torch.sparse.check_sparse_tensor_invariants.enable` for more informat
|
||||||
return mth(*args, **kwargs)
|
return mth(*args, **kwargs)
|
||||||
|
|
||||||
return test_mth
|
return test_mth
|
||||||
|
|
||||||
|
# Triton registrations
|
||||||
|
def _has_triton():
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
import triton
|
||||||
|
|
||||||
|
return triton is not None and get_device_capability() >= (7, 0)
|
||||||
|
except ImportError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _register_impls(lib):
|
||||||
|
"""This function is called from torch/__init__.py to do any dynamic registrations. """
|
||||||
|
|
||||||
|
|
||||||
|
def register_sparse_cuda_impls(lib=lib):
|
||||||
|
from ._triton_ops import bsr_dense_mm
|
||||||
|
|
||||||
|
if bsr_dense_mm is not None:
|
||||||
|
lib.impl("aten::_triton_bsr_dense_mm",
|
||||||
|
lambda *args, **kwargs: bsr_dense_mm(*args, skip_checks=True, **kwargs), "SparseCsrCUDA")
|
||||||
|
|
||||||
|
# This code is evaluated on import torch and therefore cannot force initialization of the cuda rt
|
||||||
|
# We must schedule the registration to occur lazily.
|
||||||
|
_lazy_call(register_sparse_cuda_impls)
|
||||||
|
|
|
||||||
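`torch.cuda._lazy_call` is an internal helper that queues a callable to run when CUDA is first initialized; registering eagerly at `import torch` time would otherwise force CUDA runtime initialization, as the comment above notes. A hedged sketch of the pattern (internal API, subject to change):

    import torch

    def _register():
        print("runs once CUDA is initialized")

    # The callable is queued and executed during torch.cuda's lazy initialization.
    torch.cuda._lazy_call(_register)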
torch/sparse/_triton_ops.py (new file, 608 lines)
@@ -0,0 +1,608 @@
import torch
from torch._inductor.cuda_properties import get_device_capability


def _has_triton():
    if not torch.cuda.is_available():
        return False
    try:
        import triton

        return triton is not None and get_device_capability() >= (7, 0)
    except ImportError:
        return False


def compressed_indices_to_plain_indices(cidx, pidx):
    nnz = pidx.shape[-1]
    cdim = cidx.shape[-1] - 1
    batch_numel = cidx.shape[0]
    batch_offset = torch.arange(batch_numel, dtype=cidx.dtype, device=cidx.device)[
        :, None
    ]

    cidx_batch_offsetted = cidx[:, :-1] + nnz * batch_offset
    cidx_linear = torch.empty(
        (batch_numel * cdim + 1,), dtype=cidx.dtype, device=cidx.device
    )
    cidx_linear[:-1] = cidx_batch_offsetted.reshape(-1)
    cidx_linear[-1] = nnz * batch_numel

    idx_linear = torch._convert_indices_from_csr_to_coo(
        cidx_linear, pidx.reshape(-1), out_int32=(cidx.dtype == torch.int32)
    ).select(0, 0)

    return idx_linear.reshape(batch_numel, -1).sub_(cdim * batch_offset)
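For intuition, `compressed_indices_to_plain_indices` is the batched analogue of CSR-to-COO conversion: per batch entry it expands `crow_indices` into one row index per nonzero block. A small sketch of the expected mapping (hypothetical indices):

    import torch

    crow = torch.tensor([[0, 2, 2, 4],   # batch 0: 2, 0, 2 nonzeros per row
                         [0, 1, 3, 4]])  # batch 1: 1, 2, 1 nonzeros per row
    rows = torch.stack([
        torch.repeat_interleave(torch.arange(c.numel() - 1), c.diff()) for c in crow
    ])
    print(rows)  # tensor([[0, 0, 2, 2], [0, 1, 1, 2]])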
def slicer(dim, slice_range, *tensors):
    for t in tensors:
        slices = [slice(None)] * t.dim()
        slices[dim] = slice_range
        yield t[slices]


if _has_triton():
    import triton
    import triton.language as tl
    from typing import Optional, Tuple

    @triton.jit
    def _bsr_strided_dense_rowspace_kernel(
        BLOCKSIZE_ROW: tl.constexpr,
        BLOCKSIZE_COL: tl.constexpr,
        # values prologue
        values_ptr,
        values_batch_stride,
        values_nnz_stride,
        values_row_block_stride,
        values_col_block_stride,
        # values epilogue
        # crow_indices prologue
        crow_indices_ptr,
        crow_indices_batch_stride,
        crow_indices_stride,
        # crow_indices epilogue
        # col_indices prologue
        col_indices_ptr,
        col_indices_batch_stride,
        col_indices_stride,
        # col_indices epilogue
        # dense prologue
        dense_ptr,
        dense_batch_stride,
        dense_tiled_row_stride,
        dense_tiled_col_stride,
        dense_row_block_stride,
        dense_col_block_stride,
        # dense epilogue
        # output prologue
        output_ptr,
        output_batch_stride,
        output_tiled_row_stride,
        output_tiled_col_stride,
        output_row_block_stride,
        output_col_block_stride,
        # output epilogue
        GROUP_SIZE_ROW: tl.constexpr,
    ):
        batch_pid = tl.program_id(axis=2)
        row_block_pid = tl.program_id(axis=0)
        col_block_pid = tl.program_id(axis=1)
        n_block_rows = tl.num_programs(axis=0)
        n_block_cols = tl.num_programs(axis=1)

        row_block_pid, col_block_pid = tl.swizzle2d(
            row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW
        )

        crow_indices_offset_ptr = (
            crow_indices_ptr
            + crow_indices_batch_stride * batch_pid
            + crow_indices_stride * row_block_pid
        )
        nnz_offset = tl.load(crow_indices_offset_ptr)
        nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)

        # Compute nnz for the row with number row_block_pid.
        # If it is zero, skip the row.
        row_nnz = nnz_offset_next - nnz_offset
        if row_nnz == 0:
            return

        row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
        col_block_arange = tl.arange(0, BLOCKSIZE_COL)

        # Pointers are set to the first block of the current row.
        values_block_ptrs = (
            values_ptr
            + values_batch_stride * batch_pid
            + values_nnz_stride * nnz_offset
            + values_row_block_stride * row_block_arange[:, None]
            + values_col_block_stride * col_block_arange[None, :]
        )

        # NOTE: dense is advanced into all dimensions but the tiled row one.
        # That will be advanced in the loop according to values in col_indices.
        dense_block_ptrs = (
            dense_ptr
            + dense_batch_stride * batch_pid
            + dense_tiled_col_stride * col_block_pid
            + dense_row_block_stride * col_block_arange[:, None]
            + dense_col_block_stride * row_block_arange[None, :]
        )

        # Pointers are set to exact write-to locations
        output_ptrs = (
            output_ptr
            + output_batch_stride * batch_pid
            + output_tiled_row_stride * row_block_pid
            + output_tiled_col_stride * col_block_pid
            + output_row_block_stride * row_block_arange[:, None]
            + output_col_block_stride * row_block_arange[None, :]
        )

        # Set pointer to the first nonzero element in the current row
        col_index_nnz_ptr = (
            col_indices_ptr
            + col_indices_batch_stride * batch_pid
            + col_indices_stride * nnz_offset
        )

        output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_ROW), tl.float32)
        for _ in range(row_nnz):
            values_block = tl.load(values_block_ptrs)

            # find which row of dense needs to get loaded
            # for multiplication with values_block.
            dense_row_idx = tl.load(col_index_nnz_ptr)
            dense_block = tl.load(dense_block_ptrs + dense_tiled_row_stride * dense_row_idx)

            # do block mm
            output_acc_block += tl.dot(values_block, dense_block)

            # move val/col_index ptrs to the next block in the row
            values_block_ptrs += values_nnz_stride
            col_index_nnz_ptr += col_indices_stride

        # write back the result
        tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))
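To make the kernel's math concrete: each program instance (batch, row block, column block) accumulates, over the nonzero blocks of its BSR row, `values[nnz] @ dense[col_indices[nnz], col_block]` in float32 and writes one (BLOCKSIZE_ROW, BLOCKSIZE_ROW) tile. A hedged pure-PyTorch sketch of that per-instance computation (names are illustrative, not part of the file):

    import torch

    def _reference_instance(values, crow_indices, col_indices, dense_tiled, row_block, col_block):
        # dense_tiled: dense after tile_to_blocksize with the transposed blocksize,
        # shape (n_tiled_rows, n_tiled_cols, BLOCKSIZE_COL, BLOCKSIZE_ROW)
        b0 = values.shape[-2]
        acc = torch.zeros((b0, b0), dtype=torch.float32)
        for nnz in range(int(crow_indices[row_block]), int(crow_indices[row_block + 1])):
            acc += values[nnz].float() @ dense_tiled[col_indices[nnz], col_block].float()
        return acc  # cast back to the output dtype on store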
    @triton.jit
    def _bsr_strided_sparse_rowspace_kernel(
        BLOCKSIZE_ROW: tl.constexpr,
        BLOCKSIZE_COL: tl.constexpr,
        batch_idx_ptr,
        row_idx_ptr,
        nnz_per_row_ptr,
        nnz_per_row_cumsum_ptr,
        col_indices_ptr,
        col_indices_stride,
        # values prologue
        values_ptr,
        values_nnz_stride,
        values_row_block_stride,
        values_col_block_stride,
        # values epilogue
        # dense prologue
        dense_ptr,
        dense_batch_stride,
        dense_tiled_row_stride,
        dense_tiled_col_stride,
        dense_row_block_stride,
        dense_col_block_stride,
        # dense epilogue
        # output prologue
        output_ptr,
        output_batch_stride,
        output_tiled_row_stride,
        output_tiled_col_stride,
        output_row_block_stride,
        output_col_block_stride,
        # output epilogue
        GROUP_SIZE_ROW: tl.constexpr,
    ):
        row_block_pid = tl.program_id(axis=0)
        col_block_pid = tl.program_id(axis=1)
        n_block_rows = tl.num_programs(axis=0)
        n_block_cols = tl.num_programs(axis=1)

        row_block_pid, col_block_pid = tl.swizzle2d(
            row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW
        )

        batch_idx = tl.load(batch_idx_ptr + row_block_pid)
        row_idx = tl.load(row_idx_ptr + row_block_pid)
        row_idx_nnz = tl.load(nnz_per_row_ptr + row_block_pid)
        row_idx_nnz_cumsum = tl.load(nnz_per_row_cumsum_ptr + row_block_pid)
        row_idx_nnz_offset = row_idx_nnz_cumsum - row_idx_nnz

        row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
        col_block_arange = tl.arange(0, BLOCKSIZE_COL)

        # Pointers are set to the first block of the current row.
        values_block_ptrs = (
            values_ptr
            + values_nnz_stride * row_idx_nnz_offset
            + values_row_block_stride * row_block_arange[:, None]
            + values_col_block_stride * col_block_arange[None, :]
        )

        # NOTE: dense is advanced into all dimensions but the tiled row one.
        # That will be advanced in the loop according to values in col_indices.
        dense_block_ptrs = (
            dense_ptr
            + dense_batch_stride * batch_idx
            + dense_tiled_col_stride * col_block_pid
            + dense_row_block_stride * col_block_arange[:, None]
            + dense_col_block_stride * row_block_arange[None, :]
        )

        # Pointers are set to exact write-to locations
        output_ptrs = (
            output_ptr
            + output_batch_stride * batch_idx
            + output_tiled_row_stride * row_idx
            + output_tiled_col_stride * col_block_pid
            + output_row_block_stride * row_block_arange[:, None]
            + output_col_block_stride * row_block_arange[None, :]
        )

        output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_ROW), tl.float32)
        col_index_nnz_ptr = col_indices_ptr + row_idx_nnz_offset * col_indices_stride
        for _ in range(row_idx_nnz):
            values_block = tl.load(values_block_ptrs)

            # find which row of dense needs to get loaded
            # for multiplication with values_block.
            dense_row_idx = tl.load(col_index_nnz_ptr)
            dense_block = tl.load(dense_block_ptrs + dense_tiled_row_stride * dense_row_idx)

            # do block mm
            output_acc_block += tl.dot(values_block, dense_block)

            # move val/col_index ptrs to the next block in the row
            values_block_ptrs += values_nnz_stride
            col_index_nnz_ptr += col_indices_stride

        # write back the result
        tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))
    def _run_sparse_rowspace_kernel(
        blocksize, values, crow_indices, col_indices, dense, output, max_grid
    ):
        # Compute a vector of non-zero elements numbers per each row.
        # We want to ultimately iterate over non-zero rows.
        nnz_per_row = crow_indices[:, 1:] - crow_indices[:, :-1]

        # Compute indices of non-zero counts.
        # batch_idx maps to a broadcasted batch index, while
        # row_idx tracks non-zero rows of the sparse argument
        # and rows of the output that get modified.
        batch_idx, row_idx = nnz_per_row.nonzero(as_tuple=True)

        # Compress the vector of counts to hold only non-zero values.
        nnz_per_row = nnz_per_row[batch_idx, row_idx]
        # Compute cumulative counts which along with nnz_per_row
        # are used to compute offsets into nnz values.
        nnz_per_row_cumsum = nnz_per_row.cumsum(-1)

        n_nnz_block_rows = row_idx.size(-1)
        n_block_cols = dense.size(-3)
        max_n_nnz_block_rows, max_n_block_cols = max_grid[:2]

        for c_start in range(0, n_block_cols, max_n_block_cols):
            c_dense, c_output = slicer(
                -3, slice(c_start, c_start + max_n_block_cols), dense, output
            )
            c_grid = min(n_block_cols - c_start, max_n_block_cols)

            for r_start in range(0, n_nnz_block_rows, max_n_nnz_block_rows):
                r_batch_idx, r_row_idx, r_nnz_per_row, r_nnz_per_row_cumsum = slicer(
                    0,
                    slice(r_start, r_start + max_n_nnz_block_rows),
                    batch_idx,
                    row_idx,
                    nnz_per_row,
                    nnz_per_row_cumsum,
                )
                r_grid = min(n_nnz_block_rows - r_start, max_n_nnz_block_rows)

                _bsr_strided_sparse_rowspace_kernel[(r_grid, c_grid)](
                    *blocksize,
                    r_batch_idx,
                    r_row_idx,
                    r_nnz_per_row,
                    r_nnz_per_row_cumsum,
                    col_indices,
                    *col_indices.stride(),
                    values,
                    *values.stride(),
                    c_dense,
                    *c_dense.stride(),
                    c_output,
                    *c_output.stride(),
                    GROUP_SIZE_ROW=4,
                    num_stages=4,
                    num_warps=4,
                )

    def _run_dense_rowspace_kernel(
        blocksize, values, crow_indices, col_indices, dense, output, max_grid
    ):
        # Launch kernel
        n_batches = dense.size(0)
        n_block_rows = crow_indices.size(-1) - 1
        n_block_cols = dense.size(-3)
        max_n_block_rows, max_n_block_cols, max_n_batches = max_grid

        for b_start in range(0, n_batches, max_n_batches):
            b_v, b_crow, b_col, b_d, b_o = slicer(
                0,
                slice(b_start, b_start + max_n_batches),
                values,
                crow_indices,
                col_indices,
                dense,
                output,
            )
            b_grid = min(n_batches - b_start, max_n_batches)

            for c_start in range(0, n_block_cols, max_n_block_cols):
                bc_d, bc_o = slicer(
                    -3, slice(c_start, c_start + max_n_block_cols), b_d, b_o
                )
                c_grid = min(n_block_cols - c_start, max_n_block_cols)

                for r_start in range(0, n_block_rows, max_n_block_rows):
                    r_slice = slice(r_start, r_start + max_n_block_rows)
                    br_crow = next(slicer(-1, r_slice, b_crow))
                    brc_o = next(slicer(-4, r_slice, bc_o))
                    r_grid = min(n_block_rows - r_start, max_n_block_rows)

                    _bsr_strided_dense_rowspace_kernel[(r_grid, c_grid, b_grid)](
                        *blocksize,
                        b_v,
                        *b_v.stride(),
                        br_crow,
                        *br_crow.stride(),
                        b_col,
                        *b_col.stride(),
                        bc_d,
                        *bc_d.stride(),
                        brc_o,
                        *brc_o.stride(),
                        GROUP_SIZE_ROW=4,
                        num_stages=4,
                        num_warps=4,
                    )
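The `_run_*_kernel` wrappers above exist because CUDA caps each grid dimension (see `cuda_max_grid` in `bsr_dense_mm` below: 2**31 - 1 on axis 0, 2**16 - 1 on axes 1 and 2), so an oversized logical grid is covered by repeated launches over slices. The chunking pattern, reduced to a sketch:

    def _chunks(n, max_n):
        # yield (start, grid) pairs covering range(n) in pieces of at most max_n
        for start in range(0, n, max_n):
            yield start, min(n - start, max_n)

    assert list(_chunks(10, 4)) == [(0, 4), (4, 4), (8, 2)]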
    def bsr_dense_mm(
        bsr: torch.Tensor,
        dense: torch.Tensor,
        *,
        skip_checks: bool = False,
        is_sparse_rowspace_mode: Optional[bool] = None,
        max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
        out: Optional[torch.Tensor] = None,
    ):
        m, kl = bsr.shape[-2:]
        kr, n = dense.shape[-2:]

        def check(cond, msg):
            if not cond:
                raise ValueError(msg)

        if not skip_checks:
            check(
                bsr.layout == torch.sparse_bsr,
                "bsr_dense_mm(): only BSR sparse format is supported for the sparse argument.",
            )

            check(
                bsr.device == dense.device and bsr.device.type == "cuda",
                "bsr_dense_mm(): all inputs are expected to be on the same GPU device.",
            )

            check(
                bsr.dtype == dense.dtype
                and bsr.dtype in (torch.half, torch.bfloat16, torch.float),
                "bsr_dense_mm(): all inputs are expected to be of the same dtype "
                "and one of (half, bfloat16, float32), "
                f"but got bsr.dtype == {bsr.dtype} and dense.dtype == {dense.dtype}.",
            )

            check(
                bsr.dim() >= 2 and dense.dim() >= 2,
                "bsr_dense_mm(): all inputs are expected to be at least 2D, "
                f"but got bsr.dim() == {bsr.dim()} and dense.dim() == {dense.dim()}.",
            )

            check(
                kl == kr,
                "bsr_dense_mm(): argument sizes are not compatible for matrix multiplication, "
                f"got bsr.shape[-1] == {kl} which is not equal to dense.shape[-2] == {kr}.",
            )

            row_block = bsr.values().shape[-2]
            check(
                not n % row_block,
                f"bsr_dense_mm(): dense.size(-1) == {n} should be divisible by "
                f"blocksize[0] == {row_block}.",
            )

        # Required to undo the fake batch dimension insertion.
        original_batch_dims_broadcasted = torch.broadcast_shapes(
            bsr.shape[:-2], dense.shape[:-2]
        )

        if out is not None and not skip_checks:
            expected_out_shape = original_batch_dims_broadcasted + (m, n)
            check(
                out.shape == expected_out_shape,
                "bsr_dense_mm(): `out` argument has wrong shape, "
                f"expected {expected_out_shape}, but got {out.shape}.",
            )
            check(
                out.is_contiguous() or out.transpose(-2, -1).is_contiguous(),
                "bsr_dense_mm(): only row-major/col-major `out` arguments are supported, "
                "i.e. (out.is_contiguous() or out.transpose(-2, -1).is_contiguous()) "
                "should be True.",
            )

        # Short circuit if lhs is zero
        if bsr._nnz() == 0:
            return dense.new_zeros(original_batch_dims_broadcasted + (m, n))

        # TODO: insert switch
        if is_sparse_rowspace_mode is None:
            is_sparse_rowspace_mode = False

        # Introduce fake batch dimension if not present for convenience.
        def unsqueeze_batch_dim(t, n_non_batch_dims):
            if t.dim() > n_non_batch_dims:
                return t
            else:
                return t.unsqueeze(0)

        def make_triton_contiguous(t):
            # Triton does not distinguish between row- and col-majorness
            # and will be fast as long as there is a contiguous dimension.
            if not (t.is_contiguous() or t.transpose(-2, -1).is_contiguous()):
                return t.contiguous()
            else:
                return t

        crow_indices = unsqueeze_batch_dim(bsr.crow_indices(), 1)
        col_indices = unsqueeze_batch_dim(bsr.col_indices(), 1)
        values = make_triton_contiguous(unsqueeze_batch_dim(bsr.values(), 3))
        dense = make_triton_contiguous(unsqueeze_batch_dim(dense, 2))
        nnz = values.shape[-3]
        blocksize = values.shape[-2:]

        # Compute broadcasted batch dimension
        bsr_batch_dims = values.shape[:-3]
        dense_batch_dims = dense.shape[:-2]
        batch_dims_broadcasted = torch.broadcast_shapes(bsr_batch_dims, dense_batch_dims)

        # Allocate out
        if out is None:
            out = dense.new_zeros(batch_dims_broadcasted + (m, n))

        # Broadcast batch dimensions and squash
        def batch_broadcast_and_squash(t, batch_dims, invariant_dims):
            return t.broadcast_to(batch_dims + invariant_dims).flatten(
                0, len(batch_dims) - 1
            )

        crow_indices = batch_broadcast_and_squash(
            crow_indices, batch_dims_broadcasted, (-1,)
        )

        if is_sparse_rowspace_mode:
            # Flatten batch dimension with nnz dimension
            # as required by the sparse rowspace kernel.
            col_indices = batch_broadcast_and_squash(
                col_indices, batch_dims_broadcasted + (-1,), ()
            )
            values = batch_broadcast_and_squash(
                values, batch_dims_broadcasted + (values.shape[-3],), values.shape[-2:]
            )
        else:
            col_indices = batch_broadcast_and_squash(
                col_indices, batch_dims_broadcasted, (-1,)
            )
            values = batch_broadcast_and_squash(
                values, batch_dims_broadcasted, values.shape[-3:]
            )

        dense = batch_broadcast_and_squash(dense, batch_dims_broadcasted, dense.shape[-2:])

        # NOTE: out is contiguous, so batch_broadcast_and_squash will create a view
        out = batch_broadcast_and_squash(out, batch_dims_broadcasted, out.shape[-2:])

        # NOTE: this function will ALWAYS create a view
        def tile_to_blocksize(t, blocksize):
            *rest, m, n = t.shape
            new_shape = rest + [
                m // blocksize[0],
                blocksize[0],
                n // blocksize[1],
                blocksize[1],
            ]
            return t.reshape(new_shape).transpose(-3, -2)

        # "Blockify" the row dimension of dense with blocksize[1]
        # since dense is on the rhs of matmul
        dense = tile_to_blocksize(dense, blocksize[::-1])
        # "Blockify" the row dimension of out with blocksize[0]
        # which is inherited from the bsr input.
        # NOTE: tile_to_blocksize will create a view.
        # NOTE: out.blocksize[-1] == dense.blocksize[-1],
        # so it could be any value in [1, dense.shape[-1]).
        # We need to probably use the largest possible blocksize
        # so that it fits into SRAM.
        out = tile_to_blocksize(out, (blocksize[0], blocksize[0]))

        # Launch kernel
        if is_sparse_rowspace_mode:
            kernel = _run_sparse_rowspace_kernel
        else:
            kernel = _run_dense_rowspace_kernel

        # cuda_max_grid = (2 ** 31 - 1, 2 ** 16 - 1, 2 ** 16 - 1)
        cuda_max_grid = (2147483647, 65535, 65535)
        if max_grid is None:
            max_grid = cuda_max_grid
        else:

            def valid_grid_dim(g, mg):
                if g is None:
                    return mg
                else:
                    # grid must be at least 1 and no greater than mg
                    return max(1, min(g, mg))

            max_grid = tuple(
                valid_grid_dim(g, mg) for g, mg in zip(max_grid, cuda_max_grid)
            )  # type: ignore[assignment]

        kernel(blocksize, values, crow_indices, col_indices, dense, out, max_grid)

        # Block dims need to rejoin with the corresponding block dimensions
        # prior to reshape so that blocks do not end up being transposed.
        # NB: type checker is not able to narrow Optional[Tensor] to tensor by this point
        return out.transpose(-3, -2).reshape(original_batch_dims_broadcasted + (m, n))  # type: ignore[union-attr]
else:
    bsr_dense_mm = None  # type: ignore[assignment]
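`tile_to_blocksize` above reshapes (..., r, c) into (..., r / b0, c / b1, b0, b1) without copying, so each (b0, b1) block is addressable through strides alone. A quick illustration of the equivalence on a toy tensor:

    import torch

    t = torch.arange(24).reshape(4, 6)
    tiled = t.reshape(2, 2, 2, 3).transpose(-3, -2)  # blocksize (2, 3)
    # tiled[i, j] is the (i, j) block of t:
    assert torch.equal(tiled[1, 0], t[2:4, 0:3])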
if __name__ == "__main__":
    from torch._inductor.utils import has_triton

    if has_triton():
        torch.manual_seed(13)
        dtype = torch.float32
        p = 0.5
        mask_size = (8, 8)
        block_size = (64, 64)
        size = (mask_size[0] * block_size[0], mask_size[1] * block_size[1])

        n_exp = 512
        diff = torch.ones(n_exp, device="cuda", dtype=torch.float32)
        for i in range(n_exp):
            mask = torch.rand(*mask_size, device="cuda") < p
            x = torch.rand(*mask_size, *block_size, dtype=dtype, device="cuda") / 10
            x = (
                (mask[:, :, None, None] * x)
                .transpose(-3, -2)
                .reshape(*size)
                .to_sparse_bsr(*block_size)
            )
            y = torch.rand(5, *size, dtype=dtype, device="cuda") / 10
            res_dense = x.to_dense() @ y
            res = bsr_dense_mm(x, y)
            diff[i] = (res - res_dense).abs().max()
        print(f"mean: {diff.mean()}, std: {diff.std()}")
        print(f"max diff: {diff.max()}")