Add scatter_mm and bsr_scatter_mm operations. (#110396)

This PR introduces the `scatter_mm` operation (which computes `mm` over arbitrary pairs of matrices taken from batches of tensors). It is used to implement `bsr_scatter_mm`, an equivalent of `bsr_dense_mm` (the `mm` operation on BSR and strided tensors). The implementation is provided both in Triton (when tensor dimensions are multiples of 16) and in PyTorch (otherwise).
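
A minimal usage sketch, based on the tests and benchmarks added in this PR (the helpers live in the private `torch.sparse._triton_ops` module, so the exact API may change):

```python
import torch
from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data

x = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
y = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
bsr = x.to_sparse_bsr((32, 32))

# Precompute the indices data once; it can be reused across calls with the
# same sparsity pattern and dense operand shape.
indices_data = bsr_scatter_mm_indices_data(
    bsr, y, indices_format="bsr_strided_mm_compressed")
result = bsr_scatter_mm(bsr, y, indices_data=indices_data)

torch.testing.assert_close(result, bsr.to_dense() @ y, atol=1e-2, rtol=1e-2)
```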

The figures below illustrate the performance differences between `bsr_scatter_mm` and `bsr_dense_mm` (GPU: `NVIDIA GeForce RTX 2060 SUPER`). The first figure shows the performance equilibrium points, i.e. the BSR tensor sparsity at which `bsr_scatter_mm` or `bsr_dense_mm` reaches the same performance as `torch.matmul`. The second figure shows the speedup of `bsr_scatter_mm` over `bsr_dense_mm` at these equilibrium points.

<img src="https://github.com/pytorch/pytorch/assets/402156/526d182e-937f-4812-a6c4-904f52d6d5ab" width="48%"> <img src="https://github.com/pytorch/pytorch/assets/402156/ccb606ab-1f3f-4133-887c-b56285f4f168" width="48%">

The same figures for the `NVIDIA A100-SXM4-80GB` GPU:

<img src="https://github.com/pytorch/pytorch/assets/402156/25466f1d-df34-4d1c-a975-afb478e4d9f0" width="48%"> <img src="https://github.com/pytorch/pytorch/assets/402156/6ada91f0-a20f-4f0d-8a48-1f4ccc60d08e" width="48%">

In summary:
- `bsr_scatter_mm` is about 2x faster than `bsr_dense_mm` for small block sizes (16 and 32) and large tensors [GPU: `NVIDIA GeForce RTX 2060 SUPER`].
- `bsr_scatter_mm` is up to 2x faster than `bsr_dense_mm` for a small block size of 16 and large tensors [GPU: `NVIDIA A100-SXM4-80GB`].
- `bsr_dense_mm` is up to 20% faster than `bsr_scatter_mm` for block sizes of 64 or larger [GPU: `NVIDIA GeForce RTX 2060 SUPER`].
- However, `bsr_dense_mm` fails with an `OutOfResources` exception for block sizes of 256 or larger, whereas `bsr_scatter_mm` succeeds (see the fallback sketch below).
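
As a rough sketch (not part of this PR), a caller could therefore use `bsr_scatter_mm` as a fallback when `bsr_dense_mm` runs out of resources; it assumes the private `torch.sparse._triton_ops` helpers shown in the diff below:

```python
import triton
from torch.sparse._triton_ops import (
    bsr_dense_mm,
    bsr_scatter_mm,
    bsr_scatter_mm_indices_data,
)

def bsr_mm_with_fallback(bsr, dense):
    # Prefer the existing kernel; fall back to the scatter-based kernel
    # when the block size is too large for bsr_dense_mm (e.g. 256x256).
    try:
        return bsr_dense_mm(bsr, dense)
    except triton.compiler.OutOfResources:
        indices_data = bsr_scatter_mm_indices_data(
            bsr, dense, indices_format="bsr_strided_mm_compressed")
        return bsr_scatter_mm(bsr, dense, indices_data=indices_data)
```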

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110396
Approved by: https://github.com/cpuhrsch
Author: Pearu Peterson
Date: 2023-10-22 10:34:32 +00:00
Committed by: PyTorch MergeBot
Parent: 3b9246ba18
Commit: d4708a6da7
3 changed files with 1179 additions and 1 deletions


@@ -0,0 +1,388 @@
import torch
def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device):
assert (
sparsity <= 1.0 and sparsity >= 0.0
), "sparsity should be a value between 0 and 1"
assert M % blocksize[0] == 0
assert N % blocksize[1] == 0
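# Drop the batch dimension from the block-mask shape when B == 0.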
shape = (B, M // blocksize[0], N // blocksize[1])[int(B == 0) :]
A = torch.bernoulli(torch.full(shape, 1 - sparsity, dtype=dtype, device=device))
expected_nnz = int((1 - sparsity) * M * N / (blocksize[0] * blocksize[1]))
nonzero_indices = A.flatten().nonzero()
actual_nnz = nonzero_indices.shape[0]
if actual_nnz > expected_nnz:
selected_nonzeros = torch.randperm(actual_nnz)[: actual_nnz - expected_nnz]
A.flatten()[nonzero_indices[selected_nonzeros]] = 0
elif actual_nnz < expected_nnz:
zero_indices = (A == 0).flatten().nonzero()
selected_zeros = torch.randperm(zero_indices.shape[0])[
: expected_nnz - actual_nnz
]
A.flatten()[zero_indices[selected_zeros]] = 1
A = torch.repeat_interleave(A, blocksize[0], dim=-2)
A = torch.repeat_interleave(A, blocksize[1], dim=-1)
return A
def _test_worker(test_func):
import triton
ms, ms_min, ms_max = triton.testing.do_bench(
test_func, warmup=500, rep=100, fast_flush=False
)
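# FLOP count of a dense (m, k) @ (k, n) matmul; m, k, n are module-level
# globals assigned in the benchmark loop under __main__ below.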
tflops = 2 * m * k * n * 1e-12 / (ms * 1e-3)
return ms, tflops
def test_dense_dense_mm(x, y, **meta):
def test_func(x=x.to_dense(), y=y):
return torch.matmul(x, y)
return _test_worker(test_func)
def test_torch_matmul(x, y, **meta):
def test_func(x=x, y=y):
return torch.matmul(x, y)
return _test_worker(test_func)
def test_bsr_dense_mm(x, y, **meta):
from torch.sparse._triton_ops import bsr_dense_mm
def test_func(x=x, y=y):
return bsr_dense_mm(x, y)
return _test_worker(test_func)
def test_bsr_scatter_mm2(x, y, **meta):
from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data
indices_data = bsr_scatter_mm_indices_data(
x, y, indices_format="scatter_mm", **meta
)
def test_func(x=x, y=y):
return bsr_scatter_mm(x, y, indices_data=indices_data)
return _test_worker(test_func)
def test_bsr_scatter_mm6(x, y, **meta):
from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data
indices_data = bsr_scatter_mm_indices_data(
x, y, indices_format="bsr_strided_mm_compressed", **meta
)
def test_func(x=x, y=y):
return bsr_scatter_mm(x, y, indices_data=indices_data)
return _test_worker(test_func)
if __name__ == "__main__":
import argparse
import atexit
import itertools
import sys
import triton
from torch.testing import make_tensor
torch.manual_seed(0)
def integer_list(a):
return list(map(int, a.split(",")))
def float_list(a):
return list(map(float, a.split(",")))
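# Parse comma-separated values; "start:end" and "start:end:step" expand to
# integer ranges, values containing "." parse as floats, others as ints.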
def integer_or_float_list(a):
lst = []
for n in a.split(","):
if n.count(":") == 1:
start, end = map(int, n.split(":"))
lst.extend(range(start, end))
elif n.count(":") == 2:
start, end, step = map(int, n.split(":"))
lst.extend(range(start, end, step))
elif "." in n:
lst.append(float(n))
else:
lst.append(int(n))
return lst
parser = argparse.ArgumentParser(description="SpTritonOps")
parser.add_argument(
"--ops",
default="dense_dense_mm,bsr_dense_mm,bsr_scatter_mm6",
type=str,
)
parser.add_argument("--b", default="0", type=int)
parser.add_argument("--m", default="1024", type=integer_list)
parser.add_argument("--k", default=None, type=integer_list)
parser.add_argument("--n", default=None, type=integer_list)
parser.add_argument("--bm", default="16", type=integer_list)
parser.add_argument("--bk", default=None, type=integer_list)
parser.add_argument("--tile_m", default=None, type=integer_list)
parser.add_argument("--tile_n", default=None, type=integer_list)
parser.add_argument("--split_n", default=None, type=integer_list)
parser.add_argument("--group_size", default=None, type=integer_list)
parser.add_argument("--num_warps", default=None, type=integer_list)
parser.add_argument("--num_stages", default=None, type=integer_list)
parser.add_argument("--sparsity", default="0.5", type=integer_or_float_list)
parser.add_argument("--dtype", default="float16", type=str)
parser.add_argument("--device", default="cuda", type=str)
parser.add_argument("--repeat", default="1", type=int)
parser.add_argument("--outfile", default="stdout", type=str)
parser.add_argument("--star", default=False, action="store_true")
args = parser.parse_args()
if args.outfile == "stdout":
outfile = sys.stdout
elif args.outfile == "stderr":
outfile = sys.stderr
else:
outfile = open(args.outfile, "a")
ops = args.ops.split(",")
b = args.b
m_list = args.m or [1024]
n_list = args.n or [None]
k_list = args.k or [None]
bm_list = args.bm or [16]
bk_list = args.bk or [None]
split_n_list = args.split_n or [None]
tile_m_list = args.tile_m or [None]
tile_n_list = args.tile_n or [None]
group_size_list = args.group_size or [None]
num_warps_list = args.num_warps or [None]
num_stages_list = args.num_stages or [None]
sparsity_list = args.sparsity or [0.5]
dtype = getattr(torch, args.dtype)
if args.star > 0:
import torch.sparse._triton_ops
assert {len(m_list), len(n_list), len(k_list), len(bm_list), len(bk_list)} == {
1
}
m = m_list[0]
n = n_list[0] or m
k = k_list[0] or m
bm = bm_list[0]
bk = bk_list[0] or bm
meta = torch.sparse._triton_ops.scatter_mm_meta(m, k, n, bm, bk)
assert {
split_n_list[0],
tile_m_list[0],
tile_n_list[0],
group_size_list[0],
num_warps_list[0],
num_stages_list[0],
} <= {0, None}  # in --star mode these must be left unset (None) or set to 0
if split_n_list[0] is None:
split_n_list = [meta["SPLIT_N"] // 2, meta["SPLIT_N"], meta["SPLIT_N"] * 2][
int(meta["SPLIT_N"] == 1) :
]
elif split_n_list[0] == 0:
split_n_list = [meta["SPLIT_N"]]
if tile_m_list[0] is None:
tile_m_list = [meta["TILE_M"] // 2, meta["TILE_M"], meta["TILE_M"] * 2][
int(meta["TILE_M"] == 16) :
]
elif tile_m_list[0] == 0:
tile_m_list = [meta["TILE_M"]]
if tile_n_list[0] is None:
tile_n_list = [meta["TILE_N"] // 2, meta["TILE_N"], meta["TILE_N"] * 2][
int(meta["TILE_N"] == 16) :
]
elif tile_n_list[0] == 0:
tile_n_list = [meta["TILE_N"]]
if group_size_list[0] is None:
group_size_list = [
meta["GROUP_SIZE"] - 1,
meta["GROUP_SIZE"],
meta["GROUP_SIZE"] + 1,
][int(meta["GROUP_SIZE"] == 1) :]
elif group_size_list[0] == 0:
group_size_list = [meta["GROUP_SIZE"]]
if num_warps_list[0] is None:
num_warps_list = [
meta["num_warps"] // 2,
meta["num_warps"],
meta["num_warps"] * 2,
][int(meta["num_warps"] == 1) :]
elif num_warps_list[0] == 0:
num_warps_list = [meta["num_warps"]]
if num_stages_list[0] is None:
num_stages_list = [
meta["num_stages"] - 1,
meta["num_stages"],
meta["num_stages"] + 1,
][int(meta["num_stages"] == 1) :]
elif num_stages_list[0] == 0:
num_stages_list = [meta["num_stages"]]
device = args.device
dense_dense_mm_sizes = set()
target_performance = None
performance_rtol = 1e-2
best_messages = []
@atexit.register
def show_best_messages(best_messages=best_messages):
print("TOP 10:")
for m in best_messages[-10:]:
print(m)
sys.stdout.flush()
for m, k, n, bm, bk, sparsity in itertools.product(
m_list, n_list, k_list, bm_list, bk_list, sparsity_list
):
k = k or m
n = n or m
bk = bk or bm
if bm > m or bk > k:
continue
blocksize = (bm, bk)
if isinstance(sparsity, int):
# integer sparsity value corresponds to desired nnz value
sparsity = 1 - bk * bm * sparsity / (m * k)
if sparsity > 1 or sparsity < 0:
continue
x = create_blocked_tensor(
b, m, k, blocksize, sparsity, dtype, device
).to_sparse_bsr(blocksize)
# recompute sparsity
sparsity = 1 - bk * bm * x._nnz() / (m * k)
y = make_tensor(k, n, dtype=dtype, device=device)
bsr_size = f"{b}x{m}x{k}" if b > 0 else f"{m}x{k}"
for op in ops:
if op == "dense_dense_mm":
if (m, k, n) in dense_dense_mm_sizes:
continue
dense_dense_mm_sizes.add((m, k, n))
best_tflops = 0
for (
split_n,
num_warps,
num_stages,
tile_m,
tile_n,
group_size,
) in itertools.product(
split_n_list,
num_warps_list,
num_stages_list,
tile_m_list,
tile_n_list,
group_size_list,
):
if (
(tile_m or 0) > bm
or (tile_n or 0) > n // (split_n or 1)
or n % (split_n or 1) != 0
or (split_n or 0) > n
):
continue
test_func = globals()["test_" + op]
meta = (
dict(
SPLIT_N=split_n,
TILE_M=tile_m,
TILE_N=tile_n,
GROUP_SIZE=group_size,
num_stages=num_stages,
num_warps=num_warps,
)
if op == "bsr_scatter_mm6"
else dict()
)
meta_str = ";".join(
f"{k}={v}" for k, v in meta.items() if v is not None
)
time_ms_lst = []
performance_tflops_lst = []
for r in range(args.repeat):
try:
time_ms, performance_tflops = test_func(x, y, **meta)
except triton.compiler.OutOfResources as msg:
print(
f"op={op}[{meta_str}]({bsr_size},{k}x{n}) dtype={args.dtype} {sparsity=}(nnz={x._nnz()})"
f" blocksize={bm}x{bk} OutOfResources",
file=outfile,
)
continue
except Exception as msg:
msg = str(msg).split("\n", 1)[0]
print(
f"op={op}[{meta_str}]({bsr_size},{k}x{n}) dtype={args.dtype} {sparsity=}(nnz={x._nnz()})"
f" blocksize={bm}x{bk} {msg}",
file=outfile,
)
continue
time_ms_lst.append(time_ms)
performance_tflops_lst.append(performance_tflops)
mark = ""
if op == "dense_dense_mm":
if target_performance is None:
target_performance = performance_tflops
elif target_performance is not None:
if (
abs(1 - performance_tflops / target_performance)
< performance_rtol
):
mark += " @@@"
if best_tflops < performance_tflops:
best_tflops = performance_tflops
best_message = (
f"op={op}[{meta_str}]({bsr_size},x{n}) dtype={args.dtype} {sparsity=:.4f}(nnz={x._nnz()})"
f" blocksize={bm}x{bk} time={time_ms:.3f} ms performance={performance_tflops:.3f} TFLOPS"
)
if best_message not in best_messages:
best_messages.append(best_message)
mark += " !!!"
print(
f"op={op}[{meta_str}]({bsr_size},x{n}) dtype={args.dtype} {sparsity=:.4f}(nnz={x._nnz()})"
f" blocksize={bm}x{bk}"
f" time={time_ms:.3f} ms performance={performance_tflops:.3f} TFLOPS{mark}",
file=outfile,
)
outfile.flush()
if args.repeat > 1:
avg_time_ms = sum(time_ms_lst) / len(time_ms_lst)
avg_performance_tflops = sum(performance_tflops_lst) / len(
performance_tflops_lst
)
print(
f"op={op}[{meta_str}]({bsr_size},{k}x{n}) dtype={args.dtype} {sparsity=}(nnz={x._nnz()})"
f" blocksize={bm}x{bk}"
f" time={time_ms:.3f} ms performance={performance_tflops:.3f} TFLOPS [AVERAGE]",
file=outfile,
)
outfile.flush()
if op not in {"bsr_scatter_mm6"}:
break


@@ -3697,6 +3697,103 @@ class TestSparseCompressedTritonKernels(TestCase):
res_tri_grid = sampled_addmm(bsr, mat1, mat2, alpha=alpha, beta=beta, max_grid=grid)
self.assertEqual(res_tri, res_tri_grid)
@onlyCUDA
@skipIfRocm
@dtypes(torch.half, torch.bfloat16, torch.float)
@dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
def test_triton_scatter_mm(self, device, dtype):
from torch.sparse._triton_ops import scatter_mm
from functools import partial
tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)
sizes = [8, 16]
for m, k, n in itertools.product(sizes, sizes, sizes):
blocks = torch.stack([tensor(m, k), tensor(m, k)])
others = torch.stack([tensor(k, n), tensor(k, n)])
expected = torch.stack([blocks[0] @ others[0] + blocks[1] @ others[0],
blocks[0] @ others[1],
blocks[1] @ others[1]])
indices_data = (
'scatter_mm',
torch.tensor([0, 2, 3, 4], dtype=torch.int32, device=device),
torch.tensor([[0, 0], [1, 0], [0, 1], [1, 1]], dtype=torch.int32, device=device))
result = scatter_mm(blocks, others, indices_data=indices_data)
self.assertEqual(result, expected)
other = tensor(2 * k, 2 * n)
expected = torch.cat([
torch.cat([blocks[1], blocks[0]], dim=1),
torch.cat([torch.zeros_like(blocks[0]), blocks[1]], dim=1)], dim=0) @ other
indices_data = (
'bsr_strided_mm',
torch.tensor([0, 2, 4, 5, 6], dtype=torch.int32, device=device),
torch.tensor([0, n, 2 * n * m, 2 * n * m + n], dtype=torch.int32, device=device),
torch.tensor([1, 0, 1, 0, 1, 1], dtype=torch.int32, device=device),
torch.tensor([0, 2 * k * n, n, 2 * k * n + n, 2 * k * n, 2 * k * n + n],
dtype=torch.int32, device=device),
dict(SPLIT_N=2, is_compressed=False, TILE_M=m, TILE_N=n, GROUP_SIZE=1)
)
result = scatter_mm(blocks, other, indices_data=indices_data)
self.assertEqual(result, expected)
@parametrize("blocksize", [2, '2x3', 16, '16x32', 32, 64])
@onlyCUDA
@skipIfRocm
@dtypes(torch.half, torch.bfloat16, torch.float)
@dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
def test_triton_bsr_scatter_mm(self, device, dtype, blocksize):
import triton
from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data
from functools import partial
if isinstance(blocksize, str):
blocksize = tuple(map(int, blocksize.split('x')))
else:
blocksize = (blocksize,) * 2
# Note that each value in a non-zero block is in range blocksize * [low^2, high^2).
tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)
# NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
batches = [()]
sizes = [blocksize[0], 2 * blocksize[0], 4 * blocksize[0]]
sizes_K = [blocksize[1], 2 * blocksize[1]]
for bd, bs, M, K, N, has_zero_row_block in itertools.product(batches, batches, sizes, sizes_K, sizes, (False, True)):
bsr_dense = tensor(bs + (M, K))
if has_zero_row_block:
if M > blocksize[0]:
bsr_dense[:blocksize[0]].zero_()
else:
continue
bsr = bsr_dense.to_sparse_bsr(blocksize)
dense = tensor(bd + (K, N))
expected = bsr.to_dense() @ dense
for indices_format in ('bsr_strided_mm', 'bsr_strided_mm_compressed', 'scatter_mm'):
if indices_format in {'bsr_strided_mm', 'bsr_strided_mm_compressed'}:
SPLIT_N_list = [N]
while SPLIT_N_list[-1] > 1:
SPLIT_N_list.append(max(1, SPLIT_N_list[-1] // 2))
else:
SPLIT_N_list = [1]
for SPLIT_N in SPLIT_N_list:
indices_data = bsr_scatter_mm_indices_data(
bsr, dense, indices_format=indices_format, SPLIT_N=SPLIT_N)
try:
result = bsr_scatter_mm(bsr, dense, indices_data=indices_data)
except triton.compiler.OutOfResources:
# ensure that there was at least one successful test:
assert SPLIT_N < SPLIT_N_list[0]
break
self.assertEqual(result, expected)
# e.g., TestSparseCSRCPU and TestSparseCSRCUDA
instantiate_device_type_tests(TestSparseCSR, globals())


@@ -1,6 +1,5 @@
import math
import torch
from torch.utils._triton import has_triton
@@ -209,6 +208,491 @@ def tile_to_blocksize(t, blocksize):
return t.view(new_shape).transpose(-3, -2)
def scatter_mm(blocks, others, indices_data, *, accumulators=None):
"""Scattered matrix multiplication of tensors.
A scattered matrix multiplication is defined as a series of matrix
multiplications applied to input tensors according to the input
and output mappings specified by indices data.
The following indices data formats are supported for defining a
scattered matrix multiplication operation (:attr:`indices_data[0]`
holds the name of the indices data format as specified below):
- ``"scatter_mm"`` - matrix multiplications scattered in batches
of tensors.
If :attr:`blocks` is a :math:`(* \times M \times K)` tensor,
:attr:`others` is a :math:`(* \times K \times N)` tensor,
:attr:`accumulators` is a :math:`(* \times M \times N)` tensor,
and ``indices_data[1:]`` provides the ``c_offsets`` and ``pq``
index tensors, then the operation is equivalent to the
following code::
c_offsets, pq = indices_data[1:]
for r in range(len(c_offsets) - 1):
for g in range(c_offsets[r], c_offsets[r + 1]):
p, q = pq[g]
accumulators[r] += blocks[p] @ others[q]
- ``"bsr_strided_mm"`` - matrix multiplications scattered in
batches of tensors and a tensor.
If :attr:`blocks` is a :math:`(* \times Ms \times Ks)` tensor,
:attr:`others` is a :math:`(K \times N)` tensor,
:attr:`accumulators` is a :math:`(M \times N)` tensor, then
the operation is equivalent to the following code::
c_indices, r_offsets, p_offsets, q_offsets, meta = indices_data[1:]
for i, r in enumerate(r_offsets):
r0, r1 = divmod(r, N)
for g in range(c_indices[i], c_indices[i+1]):
p = p_offsets[g]
q0, q1 = divmod(q_offsets[g], N)
accumulators[r0:r0 + Ms, r1:r1 + Ns] += blocks[p] @ others[q0:q0 + Ks, q1:q1 + Ns]
where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are
integer multiples of ``Ms`` and ``Ks``, respectively.
- ``"bsr_strided_mm_compressed"`` - matrix multiplications
scattered in batches of tensors and a tensor. A memory and
processor efficient version of ``"bsr_strided_mm"`` format.
If :attr:`blocks` is a :math:`(* \times Ms \times Ks)` tensor,
:attr:`others` is a :math:`(K \times N)` tensor,
:attr:`accumulators` is a :math:`(M \times N)` tensor, then
the operation is equivalent to the following code::
c_indices, r_offsets, q_offsets, meta = indices_data[1:]
for r in r_offsets:
m = (r // N) // Ms
n = (r % N) // Ns
r0, r1 = divmod(r, N)
c0, c1 = c_indices[m], c_indices[m + 1]
for i, p in enumerate(range(c0, c1)):
q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i]
q0, q1 = divmod(q, N)
accumulators[r0:r0 + Ms, r1:r1 + Ns] += blocks[p] @ others[q0:q0 + Ks, q1:q1 + Ns]
where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are
integer multiples of ``Ms`` and ``Ks``, respectively.
Notice that the order of ``r_offsets`` items can be arbitrary;
this property enables defining swizzle operators via
rearrangements of ``r_offsets`` items.
Auxiliary functions are provided for pre-computing
:attr:`indices_data`. For example,
:func:`bsr_scatter_mm_indices_data` is used to define indices data
for matrix multiplication of BSR and strided tensors.
Parameters
----------
blocks (Tensor): a 3-D tensor of first matrices to be multiplied
others (Tensor): a tensor of second matrices to be multiplied. If
``indices_data[0]=="scatter_mm"``, the tensor is a 1-D batch
tensor of second input matrices to be multiplied. Otherwise, the
second input matrices are slices of the :attr:`others` tensor.
indices_data (tuple): format data that defines the inputs and
outputs of the scattered matrix multiplications.
Keyword arguments
-----------------
accumulators (Tensor, optional): a tensor of matrix product
accumulators. If ``indices_data[0]=="scatter_mm"``, the tensor
is a 1-D batch tensor of output matrices. Otherwise, output
matrices are slices of the :attr:`accumulators` tensor.
"""
indices_format = indices_data[0]
assert blocks.ndim == 3
P, Ms, Ks = blocks.shape
if indices_format == 'scatter_mm':
c_offsets, pq = indices_data[1:]
assert others.ndim == 3
Q, Ks_, Ns = others.shape
assert Ks == Ks_
if accumulators is None:
R = c_offsets.shape[0] - 1
accumulators = torch.zeros((R, Ms, Ns), dtype=blocks.dtype, device=blocks.device)
else:
R, Ms_, Ns_ = accumulators.shape
assert Ms_ == Ms
assert Ns_ == Ns
if Ms % 16 or Ks % 16 or Ns % 16 or _scatter_mm2 is None:
for r in range(c_offsets.shape[0] - 1):
g0 = c_offsets[r]
g1 = c_offsets[r + 1]
for g in range(g0, g1):
p, q = pq[g]
accumulators[r] += blocks[p] @ others[q]
else:
_scatter_mm2(blocks, others, c_offsets, pq, accumulators)
return accumulators
elif indices_format == 'bsr_strided_mm':
assert others.ndim == 2
K, N = others.shape
assert K % Ks == 0
c_indices, r_offsets, p_offsets, q_offsets, meta = indices_data[1:]
SPLIT_N = meta['SPLIT_N']
if accumulators is None:
M = Ms + (r_offsets.max().item() + 1) // N
accumulators = torch.zeros((M, N), dtype=blocks.dtype, device=blocks.device)
else:
M, N_ = accumulators.shape
assert N_ == N
Ns = N // SPLIT_N
if Ms % 16 or Ks % 16 or Ns % 16 or _scatter_mm6 is None:
accumulators.zero_()
for r in range(r_offsets.shape[0]):
r_ = r_offsets[r].item()
g0 = c_indices[r].item()
g1 = c_indices[r + 1].item()
r0, r1 = divmod(r_, N)
acc = accumulators[r0:r0 + Ms, r1:r1 + Ns]
for g in range(g0, g1):
p, q = p_offsets[g], q_offsets[g]
q0, q1 = divmod(q.item(), N)
acc += blocks[p] @ others[q0:q0 + Ks, q1:q1 + Ns]
else:
_scatter_mm6(blocks, others, c_indices, r_offsets, p_offsets, q_offsets, meta, accumulators)
return accumulators
elif indices_format == 'bsr_strided_mm_compressed':
assert others.ndim == 2
K, N = others.shape
assert K % Ks == 0
c_indices, r_offsets, q_offsets, meta = indices_data[1:]
SPLIT_N = meta['SPLIT_N']
if accumulators is None:
M = Ms + (r_offsets.max().item() + 1) // N
accumulators = torch.zeros((M, N), dtype=blocks.dtype, device=blocks.device)
else:
M, N_ = accumulators.shape
assert N_ == N
Ns = N // SPLIT_N
if Ms % 16 or Ks % 16 or Ns % 16 or _scatter_mm6 is None:
for j in range(len(r_offsets)):
r0, r1 = divmod(r_offsets[j].item(), N)
m = r0 // Ms
n = r1 // Ns
c0 = c_indices[m].item()
c1 = c_indices[m + 1].item()
acc = accumulators[r0:r0 + Ms, r1:r1 + Ns]
for i, p in enumerate(range(c0, c1)):
q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i].item()
q0, q1 = divmod(q, N)
acc += blocks[p] @ others[q0:q0 + Ks, q1:q1 + Ns]
else:
p_offsets = torch.empty((0, ), dtype=q_offsets.dtype, device=q_offsets.device)
_scatter_mm6(blocks, others, c_indices, r_offsets, p_offsets, q_offsets, meta, accumulators)
return accumulators
else:
raise NotImplementedError(indices_format)
def scatter_mm_meta(M, K, N, Ms, Ks,
GROUP_SIZE=None, TILE_M=None, TILE_N=None, SPLIT_N=None, num_warps=None, num_stages=None, **extra):
if {TILE_M, TILE_N, SPLIT_N, num_warps, num_stages, GROUP_SIZE} == {None}:
# The following parameters are optimized for the performance
# equilibrium points of bsr-dense and dense-dense matrix
# multiplications when using GPU cards NVIDIA A100 and NVIDIA
# GeForce RTX 2060 SUPER. For points far from the performance
# equilibrium points as well as for other GPU cards, the
# optimal parameters are likely different from what specified
# below.
device_name = torch.cuda.get_device_name()
is_A100 = 'A100' in device_name
if (M, K, N) == (256,) * 3:
if (Ms, Ks) == (16, 16):
SPLIT_N=1;TILE_M=16;TILE_N=16;GROUP_SIZE=4;num_stages=1;num_warps=4 # noqa: E225,E231,E702
elif (Ms, Ks) == (32, 32):
SPLIT_N=2;TILE_M=32;TILE_N=16;GROUP_SIZE=4;num_stages=1;num_warps=4 # noqa: E225,E231,E702
elif (Ms, Ks) == (64, 64):
SPLIT_N=1;TILE_M=32;TILE_N=32;GROUP_SIZE=4;num_stages=1;num_warps=4 # noqa: E225,E231,E702
elif (Ms, Ks) == (128, 128):
SPLIT_N=1;TILE_M=32;TILE_N=32;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702
elif (M, K, N) == (512,) * 3:
if (Ms, Ks) == (16, 16):
SPLIT_N=8;TILE_M=16;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=2 # noqa: E225,E231,E702
if is_A100:
SPLIT_N=1;TILE_M=16;TILE_N=32;GROUP_SIZE=2;num_stages=1;num_warps=1 # noqa: E225,E231,E702
elif (Ms, Ks) == (32, 32):
SPLIT_N=8;TILE_M=32;TILE_N=64;GROUP_SIZE=4;num_stages=1;num_warps=2 # noqa: E225,E231,E702
if is_A100:
SPLIT_N=4;TILE_M=16;TILE_N=32;GROUP_SIZE=2;num_stages=1;num_warps=1 # noqa: E225,E231,E702
elif (Ms, Ks) == (64, 64):
SPLIT_N=4;TILE_M=32;TILE_N=128;GROUP_SIZE=4;num_stages=1;num_warps=4 # noqa: E225,E231,E702
if is_A100:
SPLIT_N=1;TILE_M=16;TILE_N=32;GROUP_SIZE=2;num_stages=1;num_warps=1 # noqa: E225,E231,E702
elif (Ms, Ks) == (128, 128):
SPLIT_N=8;TILE_M=64;TILE_N=64;GROUP_SIZE=4;num_stages=1;num_warps=4 # noqa: E225,E231,E702
elif (M, K, N) == (1024,) * 3:
if (Ms, Ks) == (16, 16):
SPLIT_N=4;TILE_M=16;TILE_N=128;GROUP_SIZE=2;num_stages=1;num_warps=1 # noqa: E225,E231,E702
if is_A100:
SPLIT_N=1;TILE_M=16;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=2 # noqa: E225,E231,E702
elif (Ms, Ks) == (32, 32):
SPLIT_N=8;TILE_M=32;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=1 # noqa: E225,E231,E702
if is_A100:
SPLIT_N=2;TILE_M=32;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=2 # noqa: E225,E231,E702
elif (Ms, Ks) == (64, 64):
SPLIT_N=16;TILE_M=64;TILE_N=64;GROUP_SIZE=4;num_stages=1;num_warps=2 # noqa: E225,E231,E702
if is_A100:
SPLIT_N=2;TILE_M=32;TILE_N=128;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702
elif (Ms, Ks) == (128, 128):
SPLIT_N=16;TILE_M=64;TILE_N=64;GROUP_SIZE=4;num_stages=1;num_warps=4 # noqa: E225,E231,E702
if is_A100:
SPLIT_N=8;TILE_M=64;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702
elif (Ms, Ks) == (256, 256):
SPLIT_N=16;TILE_M=64;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702
elif (M, K, N) == (2048,) * 3:
if (Ms, Ks) == (16, 16):
SPLIT_N=4;TILE_M=16;TILE_N=128;GROUP_SIZE=8;num_stages=1;num_warps=1 # noqa: E225,E231,E702
if is_A100:
SPLIT_N=8;TILE_M=16;TILE_N=64;GROUP_SIZE=1;num_stages=1;num_warps=2 # noqa: E225,E231,E702
elif (Ms, Ks) == (32, 32):
SPLIT_N=4;TILE_M=32;TILE_N=64;GROUP_SIZE=4;num_stages=1;num_warps=1 # noqa: E225,E231,E702
if is_A100:
SPLIT_N=16;TILE_M=32;TILE_N=64;GROUP_SIZE=1;num_stages=1;num_warps=2 # noqa: E225,E231,E702
elif (Ms, Ks) == (64, 64):
SPLIT_N=4;TILE_M=64;TILE_N=128;GROUP_SIZE=4;num_stages=1;num_warps=4 # noqa: E225,E231,E702
if is_A100:
SPLIT_N=8;TILE_M=64;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702
elif (Ms, Ks) == (128, 128):
SPLIT_N=8;TILE_M=64;TILE_N=64;GROUP_SIZE=4;num_stages=1;num_warps=4 # noqa: E225,E231,E702
if is_A100:
SPLIT_N=32;TILE_M=64;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702
elif (Ms, Ks) == (256, 256):
SPLIT_N=4;TILE_M=64;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702
elif (M, K, N) == (4096,) * 3:
if (Ms, Ks) == (16, 16):
SPLIT_N=2;TILE_M=16;TILE_N=256;GROUP_SIZE=2;num_stages=1;num_warps=2 # noqa: E225,E231,E702
if is_A100:
SPLIT_N=4;TILE_M=16;TILE_N=128;GROUP_SIZE=2;num_stages=1;num_warps=2 # noqa: E225,E231,E702
elif (Ms, Ks) == (32, 32):
SPLIT_N=2;TILE_M=32;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=1 # noqa: E225,E231,E702
if is_A100:
SPLIT_N=4;TILE_M=32;TILE_N=64;GROUP_SIZE=4;num_stages=3;num_warps=2 # noqa: E225,E231,E702
elif (Ms, Ks) == (64, 64):
SPLIT_N=2;TILE_M=64;TILE_N=128;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702
if is_A100:
SPLIT_N=4;TILE_M=64;TILE_N=64;GROUP_SIZE=2;num_stages=3;num_warps=2 # noqa: E225,E231,E702
elif (Ms, Ks) == (128, 128):
if is_A100:
SPLIT_N=2;TILE_M=128;TILE_N=128;GROUP_SIZE=1;num_stages=1;num_warps=8 # noqa: E225,E231,E702
elif (M, K, N) == (8192,) * 3:
if (Ms, Ks) == (16, 16):
if is_A100:
SPLIT_N=1;TILE_M=16;TILE_N=128;GROUP_SIZE=2;num_stages=1;num_warps=2 # noqa: E225,E231,E702
elif (Ms, Ks) == (32, 32):
if is_A100:
SPLIT_N=1;TILE_M=32;TILE_N=128;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702
elif (Ms, Ks) == (64, 64):
if is_A100:
SPLIT_N=4;TILE_M=64;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702
elif (Ms, Ks) == (128, 128):
if is_A100:
SPLIT_N=4;TILE_M=128;TILE_N=128;GROUP_SIZE=2;num_stages=3;num_warps=8 # noqa: E225,E231,E702
elif (Ms, Ks) == (256, 256):
if is_A100:
SPLIT_N=8;TILE_M=256;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=16 # noqa: E225,E231,E702
elif (Ms, Ks) == (512, 512):
if is_A100:
SPLIT_N=1;TILE_M=128;TILE_N=32;GROUP_SIZE=2;num_stages=1;num_warps=8 # noqa: E225,E231,E702
elif (M, K, N) == (16384,) * 3:
if (Ms, Ks) == (16, 16):
if is_A100:
SPLIT_N=1;TILE_M=16;TILE_N=256;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702
elif (Ms, Ks) == (32, 32):
if is_A100:
SPLIT_N=2;TILE_M=32;TILE_N=128;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702
if SPLIT_N is None:
# Assume NVIDIA GeForce RTX 2060 SUPER:
# With a probability of 92% (99.9% when N > 512), the
# performance will be within 2% of the performance obtained
# with an optimal SPLIT_N value. Otherwise, when N <= 512,
# the following heuristic may give up to 15% lower performance.
SPLIT_N = {16: 1, 32: 2, 64: 4, 128: 8, 256: 16, 512: 8, 1024: 16, 4096: 32, 8192: 64}.get(N, 16)
if Ms >= 512 and N >= 2048:
SPLIT_N = 1
Ns = N // SPLIT_N
if TILE_M is None:
TILE_M = min(64 if Ns < 512 else 32, Ms)
if TILE_N is None:
TILE_N = min(64 if Ns < 512 else 32, Ns)
num_stages = num_stages or 1
if num_warps is None:
if min(M, N) > 1024:
num_warps = {16: 1, 32: 1, 64: 2}.get(Ms, 4)
elif min(M, N) == 1024:
num_warps = {16: 1, 32: 1, 64: 2}.get(Ms, 4)
elif min(M, N) == 256:
num_warps = {16: 1, 32: 4}.get(Ms, 4)
else:
num_warps = {16: 1, 32: 2}.get(Ms, 4)
GROUP_SIZE = GROUP_SIZE or 4
assert TILE_M <= Ms, dict(TILE_M=TILE_M, Ms=Ms)
assert TILE_N <= Ns, dict(TILE_N=TILE_N, Ns=Ns)
assert Ms <= M, dict(M=M, Ms=Ms)
assert Ns <= N, dict(N=N, Ns=Ns)
assert Ks <= K, dict(K=K, Ks=Ks)
return dict(TILE_M=TILE_M, TILE_N=TILE_N, GROUP_SIZE=GROUP_SIZE,
num_stages=num_stages, num_warps=num_warps, SPLIT_N=SPLIT_N, **extra)
def bsr_scatter_mm_indices_data(bsr, other, indices_format='bsr_strided_mm_compressed', **meta_input):
"""Computes indices data for :func:`scatter_mm` used in BSR and
strided tensor matrix multiplication.
"""
assert bsr.dense_dim() == 0
assert bsr.ndim == 2 # no batch dims
crow_indices = bsr.crow_indices()
col_indices = bsr.col_indices()
blocksize = bsr.values().shape[-2:]
M, K = bsr.shape
Ms, Ks = blocksize
K_, N = other.shape
assert K_ == K
meta = scatter_mm_meta(M, K, N, Ms, Ks, **meta_input)
if 'allow_tf32' not in meta_input:
meta.update(allow_tf32=bsr.dtype in {torch.float16, torch.bfloat16})
if indices_format == 'bsr_strided_mm_compressed':
meta.update(is_compressed=True)
SPLIT_N = meta['SPLIT_N']
Ns = N // SPLIT_N
q_offsets_lst = []
b = torch.arange(SPLIT_N, dtype=torch.int32, device=bsr.device) * Ns
for m in range(M // Ms):
r0 = crow_indices[m].item()
r1 = crow_indices[m + 1].item()
if r1 == r0:
continue
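# Each q offset encodes a flattened (row, column) start position in `other`:
# col_index * Ks * N selects the Ks-row band, and the term from `b` selects
# the Ns-wide column slice of each of the SPLIT_N parts.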
q_offsets_lst.append((col_indices[r0:r1] * (Ks * N)).repeat(SPLIT_N) + b.repeat_interleave(r1 - r0))
q_offsets = torch.cat(q_offsets_lst)
crow_indices_diff = crow_indices.diff()
non_zero_row_indices = crow_indices_diff.nonzero()
a = non_zero_row_indices * (Ms * N)
r_offsets = (a + b).view(-1)
c_indices = crow_indices
# swizzle operation: mm elements with longer sums are computed first:
nnz_per_row = crow_indices_diff[non_zero_row_indices].repeat_interleave(SPLIT_N)
nnz_per_row, indices = nnz_per_row.sort(descending=True, stable=True)
r_offsets = r_offsets[indices]
return (indices_format, c_indices, r_offsets, q_offsets, meta)
elif indices_format == 'bsr_strided_mm':
meta.update(is_compressed=False)
SPLIT_N = meta['SPLIT_N']
Ns = N // SPLIT_N
p_offsets_lst = []
q_offsets_lst = []
b = torch.arange(SPLIT_N, dtype=torch.int32, device=bsr.device) * Ns
for m in range(M // Ms):
r0 = crow_indices[m].item()
r1 = crow_indices[m + 1].item()
if r1 == r0:
continue
p_offsets_lst.append(torch.arange(r0, r1, dtype=torch.int32, device=bsr.device).repeat(SPLIT_N))
q_offsets_lst.append((col_indices[r0:r1] * (Ks * N)).repeat(SPLIT_N) + b.repeat_interleave(r1 - r0))
q_offsets = torch.cat(q_offsets_lst)
crow_indices_diff = crow_indices.diff()
non_zero_row_indices = crow_indices_diff.nonzero()
a = non_zero_row_indices * (Ms * N)
r_offsets = (a + b).view(-1)
c_indices = torch.cat((crow_indices[:1],
torch.cumsum(crow_indices_diff[non_zero_row_indices].repeat_interleave(SPLIT_N), 0)))
p_offsets = torch.cat(p_offsets_lst)
return (indices_format, c_indices, r_offsets, p_offsets, q_offsets, meta)
elif indices_format == 'scatter_mm':
Ns = Ms
c_indices = [0]
pq_offsets = []
# todo: eliminate inner for-loops for efficiency
for m in range(M // Ms):
r0 = crow_indices[m].item()
r1 = crow_indices[m + 1].item()
for n in range(N // Ns):
c_indices.append(c_indices[-1] + r1 - r0)
for t in range(r1 - r0):
p = r0 + t
q = col_indices[p].item() * (N // Ns) + n
pq_offsets.append([p, q])
return (indices_format,
torch.tensor(c_indices, dtype=torch.int32, device=crow_indices.device),
torch.tensor(pq_offsets, dtype=torch.int32, device=crow_indices.device))
raise NotImplementedError(indices_format)
def bsr_scatter_mm(bsr, other, indices_data=None):
"""BSR @ strided -> strided
"""
assert bsr.ndim == 2
assert other.ndim == 2
Ms, Ks, Ns = bsr.shape[-2], bsr.shape[-1], other.shape[1]
blocksize = bsr.values().shape[-2:]
if indices_data is None:
indices_data = bsr_scatter_mm_indices_data(bsr, other, indices_format='bsr_strided_mm_compressed')
indices_format = indices_data[0]
if bsr._nnz() == 0:
result = torch.zeros((Ms, Ns), dtype=bsr.dtype, device=bsr.device)
elif indices_format in {'bsr_strided_mm_compressed', 'bsr_strided_mm'}:
result = torch.zeros((Ms, Ns), dtype=bsr.dtype, device=bsr.device)
scatter_mm(bsr.values(), other, indices_data, accumulators=result)
elif indices_format == 'scatter_mm':
accumulators = torch.zeros((Ms // blocksize[0] * Ns // blocksize[0], blocksize[0], blocksize[0]),
dtype=bsr.dtype, device=bsr.device)
others = (other.transpose(0, 1)
.view(Ns // blocksize[0], blocksize[0], Ks // blocksize[1], blocksize[1])
.movedim((2, 0, 3, 1), (0, 1, 2, 3)) # equivalent to .transpose(1, 2).transpose(2, 3).transpose(0, 1)
.flatten(0, 1)
)
scatter_mm(bsr.values(), others, indices_data, accumulators=accumulators)
result = (accumulators
.unflatten(0, (Ms // blocksize[0], Ns // blocksize[0]))
.movedim((0, 1, 2, 3), (2, 0, 3, 1)) # equivalent to .transpose(0, 1).transpose(2, 3).transpose(1, 2)
.reshape(Ns, Ms)
.transpose(0, 1))
else:
raise NotImplementedError(indices_format)
return result
if has_triton():
import triton
import triton.language as tl
@@ -887,8 +1371,217 @@ if has_triton():
torch.nn.functional.dropout(sdpa.values(), p=dropout_p, inplace=True)
sdpa = bsr_dense_mm(sdpa, value)
return sdpa
@triton.jit
def _scatter_mm2_kernel(
M: tl.constexpr, K: tl.constexpr, N: tl.constexpr,
blocks_ptr, blocks_stride_P, blocks_stride_M, blocks_stride_K,
others_ptr, others_stride_Q, others_stride_K, others_stride_N,
accumulators_ptr, accumulators_stride_R, accumulators_stride_M, accumulators_stride_N,
pq_offsets_ptr, pq_offsets_stride,
pq_ptr, pq_stride_T, pq_stride_1,
dot_out_dtype: tl.constexpr,
TILE_M: tl.constexpr,
TILE_N: tl.constexpr,
allow_tf32: tl.constexpr):
Ms = M // TILE_M
Ns = N // TILE_N
pid_t = tl.program_id(axis=0)
pid = tl.program_id(axis=1)
pid_m = pid // Ms
pid_n = pid % Ms
rm = (pid_m * TILE_M + tl.arange(0, TILE_M))
rn = (pid_n * TILE_N + tl.arange(0, TILE_N))
rk = tl.arange(0, K)
A_ptr = blocks_ptr + (rm[:, None] * blocks_stride_M + rk[None, :] * blocks_stride_K)
B_ptr = others_ptr + (rk[:, None] * others_stride_K + rn[None, :] * others_stride_N)
g0 = tl.load(pq_offsets_ptr + pid_t * pq_offsets_stride)
g1 = tl.load(pq_offsets_ptr + (pid_t + 1) * pq_offsets_stride)
if g0 == g1:
return
acc_block = tl.zeros((TILE_M, TILE_N), dtype=dot_out_dtype)
for i in range(g0, g1):
p = tl.load(pq_ptr + i * pq_stride_T)
q = tl.load(pq_ptr + i * pq_stride_T + pq_stride_1)
A = tl.load(A_ptr + p * blocks_stride_P)
B = tl.load(B_ptr + q * others_stride_Q)
acc_block += tl.dot(A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
C_ptr = accumulators_ptr + pid_t * accumulators_stride_R + (
rm[:, None] * accumulators_stride_M + rn[None, :] * accumulators_stride_N)
tl.store(C_ptr, acc_block.to(accumulators_ptr.dtype.element_ty))
def _scatter_mm2(
blocks: torch.Tensor,
others: torch.Tensor,
pq_offsets: torch.Tensor,
pq_indices: torch.Tensor,
accumulators: torch.Tensor
):
P, M, K = blocks.shape
Q, _, N = others.shape
R, _, _ = accumulators.shape
meta = dict(TILE_M=max(16, M // 4), TILE_N=max(16, N // 4), num_stages=1, num_warps=2)
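# Launch one program per accumulator (axis 0) and one per
# TILE_M x TILE_N tile of the output (axis 1).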
def grid(META):
return (pq_offsets.shape[0] - 1, triton.cdiv(M, META['TILE_M']) * triton.cdiv(N, META['TILE_N']), 1)
dot_out_dtype = {torch.float16: tl.float32,
torch.bfloat16: tl.float32,
torch.float32: tl.float64,
torch.float64: tl.float64}[accumulators.dtype]
if 'allow_tf32' not in meta:
meta.update(allow_tf32=dot_out_dtype == tl.float32)
_scatter_mm2_kernel[grid](
M, K, N,
blocks, blocks.stride(0), blocks.stride(1), blocks.stride(2),
others, others.stride(0), others.stride(1), others.stride(2),
accumulators, accumulators.stride(0), accumulators.stride(1), accumulators.stride(2),
pq_offsets, pq_offsets.stride(0),
pq_indices, pq_indices.stride(0), pq_indices.stride(1),
dot_out_dtype=dot_out_dtype,
**meta
)
@triton.jit
def _scatter_mm6_kernel(
Ms: tl.constexpr, Ks: tl.constexpr, N: tl.constexpr,
blocks_ptr, blocks_stride_P, blocks_stride_M, blocks_stride_K,
others_ptr, others_stride_K, others_stride_N,
accumulators_ptr, accumulators_stride_M, accumulators_stride_N,
c_indices_ptr, r_offsets_ptr,
p_offsets_ptr, q_offsets_ptr,
is_compressed: tl.constexpr,
dot_out_dtype: tl.constexpr,
SPLIT_N: tl.constexpr,
TILE_M: tl.constexpr,
TILE_N: tl.constexpr,
GROUP_SIZE: tl.constexpr,
allow_tf32: tl.constexpr):
Ns = N // SPLIT_N
BLOCKS_M = Ms // TILE_M
BLOCKS_N = Ns // TILE_N
pid_t = tl.program_id(axis=0)
pid = tl.program_id(axis=1)
num_pid_in_group = GROUP_SIZE * BLOCKS_N
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE
group_size_m = min(BLOCKS_M - first_pid_m, GROUP_SIZE)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
rm = (pid_m * TILE_M + tl.arange(0, TILE_M))
rn = (pid_n * TILE_N + tl.arange(0, TILE_N))
rk = tl.arange(0, Ks)
A_ptr = blocks_ptr + (rm[:, None] * blocks_stride_M + rk[None, :] * blocks_stride_K)
B_ptr = others_ptr + (rk[:, None] * others_stride_K + rn[None, :] * others_stride_N)
# When is_compressed is True, r is the only variable that
# depends on pid_t. This property allows sorting r values
# before calling the kernel. The sorting of r is equivalent to
# defining a swizzle operator outside of the kernel.
r = tl.load(r_offsets_ptr + pid_t)
if is_compressed:
m = (r // N) // Ms
n = (r % N) // Ns
r0 = tl.load(c_indices_ptr + m)
r1 = tl.load(c_indices_ptr + m + 1)
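# In the compressed format, q_offsets stores SPLIT_N groups of nnz
# entries per non-empty row block; group n of row block m starts at
# SPLIT_N * r0 + n * (r1 - r0) == n * r1 + (SPLIT_N - n) * r0.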
g0 = n * r1 + (SPLIT_N - n) * r0
nnz = r1 - r0
else:
g0 = tl.load(c_indices_ptr + pid_t)
g1 = tl.load(c_indices_ptr + pid_t + 1)
nnz = g1 - g0
q_ptr = q_offsets_ptr + g0
acc_block = tl.zeros((TILE_M, TILE_N), dtype=dot_out_dtype)
if is_compressed:
A_ptr += r0 * blocks_stride_P
for _ in range(nnz):
q = tl.load(q_ptr)
B = tl.load(B_ptr + q)
A = tl.load(A_ptr)
acc_block += tl.dot(A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
A_ptr += blocks_stride_P
q_ptr += 1
else:
p_ptr = p_offsets_ptr + g0
for _ in range(nnz):
q = tl.load(q_ptr)
B = tl.load(B_ptr + q)
p = tl.load(p_ptr)
A = tl.load(A_ptr + p * blocks_stride_P)
p_ptr += 1
q_ptr += 1
acc_block += tl.dot(A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
C_ptr = accumulators_ptr + r + (
rm[:, None] * accumulators_stride_M + rn[None, :] * accumulators_stride_N)
tl.store(C_ptr, acc_block.to(accumulators_ptr.dtype.element_ty))
def _scatter_mm6(
blocks: torch.Tensor,
others: torch.Tensor,
c_indices: torch.Tensor,
r_offsets: torch.Tensor,
p_offsets: torch.Tensor,
q_offsets: torch.Tensor,
meta: dict,
accumulators: torch.Tensor
):
SPLIT_N = meta['SPLIT_N']
P, Ms, Ks = blocks.shape
K_, N = others.shape
M, N_ = accumulators.shape
assert N_ == N
Ns = N // SPLIT_N
def grid(META):
return (r_offsets.shape[0], triton.cdiv(Ms, META['TILE_M']) * triton.cdiv(Ns, META['TILE_N']))
dot_out_dtype = {torch.float16: tl.float32,
torch.bfloat16: tl.float32,
torch.float32: tl.float64,
torch.float64: tl.float64}[accumulators.dtype]
if 'allow_tf32' not in meta:
meta.update(allow_tf32=dot_out_dtype == tl.float32)
assert c_indices.stride(0) == 1
assert r_offsets.stride(0) == 1
assert p_offsets.stride(0) == 1
assert q_offsets.stride(0) == 1
_scatter_mm6_kernel[grid](
Ms, Ks, N,
blocks, blocks.stride(0), blocks.stride(1), blocks.stride(2),
others, others.stride(0), others.stride(1),
accumulators, accumulators.stride(0), accumulators.stride(1),
c_indices,
r_offsets,
p_offsets,
q_offsets,
dot_out_dtype=dot_out_dtype,
**meta
)
else:
bsr_softmax = None # type: ignore[assignment]
bsr_dense_mm = None # type: ignore[assignment]
sampled_addmm = None # type: ignore[assignment]
_scaled_dot_product_attention = None # type: ignore[assignment]
_scatter_mm2 = None # type: ignore[assignment]
_scatter_mm6 = None # type: ignore[assignment]