mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add scatter_mm and bsr_scatter_mm operations. (#110396)
This PR introduces `scatter_mm` operation (compute `mm` of arbitrary pairs of tensors given in batches of tensors) that is used to implement `bsr_scatter_mm` that is equivalent to `bsr_dense_mm` (the `mm` operation on bsr and strided tensors). The implementation is provided both in Triton (when tensor dimensions are multiples of 16) and in PyTorch (otherwise). The figures below illustrate the performance differences of `bsr_scatter_mm` and `bsr_dense_mm` (GPU: `NVIDIA GeForce RTX 2060 SUPER`). The first figure represents the performance equilibrium point in BSR tensor sparsity at which value `bsr_scatter_mm` or `bsr_dense_mm` have the same performance characteristics as `torch.matmul`. The second figure represents speedups from using `bsr_scatter_mm` at its performance equilibrium points with respect to `bsr_dense_mm`. <img src="https://github.com/pytorch/pytorch/assets/402156/526d182e-937f-4812-a6c4-904f52d6d5ab" width="48%"> <img src="https://github.com/pytorch/pytorch/assets/402156/ccb606ab-1f3f-4133-887c-b56285f4f168" width="48%"> The same figures for GPU card `NVIDIA A100-SXM4-80GB`: <img src="https://github.com/pytorch/pytorch/assets/402156/25466f1d-df34-4d1c-a975-afb478e4d9f0" width="48%"> <img src="https://github.com/pytorch/pytorch/assets/402156/6ada91f0-a20f-4f0d-8a48-1f4ccc60d08e" width="48%"> In sum: - `bsr_scatter_mm` is about 2x faster than `bsr_dense_mm` for small block sizes of 16 and 32 and large tensors [GPU: `NVIDIA GeForce RTX 2060 SUPER`]. - `bsr_scatter_mm` is up to 2x faster than `bsr_dense_mm` for small block sizes of 16 and large tensors [GPU: `NVIDIA A100-SXM4-80GB`]. - `bsr_dense_mm` is up to 20 % faster than `bsr_scatter_mm` for block sizes of 64 or larger [GPU: `NVIDIA GeForce RTX 2060 SUPER`]. - However, `bsr_dense_mm` fails with `OutOfResources` exception for block sizes of 256 or larger whereas `bsr_scatter_mm` succeeds. Pull Request resolved: https://github.com/pytorch/pytorch/pull/110396 Approved by: https://github.com/cpuhrsch
This commit is contained in:
parent
3b9246ba18
commit
d4708a6da7
388
benchmarks/sparse/triton_ops.py
Normal file
388
benchmarks/sparse/triton_ops.py
Normal file
|
|
@ -0,0 +1,388 @@
|
|||
import torch
|
||||
|
||||
|
||||
def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device):
|
||||
assert (
|
||||
sparsity <= 1.0 and sparsity >= 0.0
|
||||
), "sparsity should be a value between 0 and 1"
|
||||
assert M % blocksize[0] == 0
|
||||
assert N % blocksize[1] == 0
|
||||
shape = (B, M // blocksize[0], N // blocksize[1])[int(B == 0) :]
|
||||
A = torch.bernoulli(torch.full(shape, 1 - sparsity, dtype=dtype, device=device))
|
||||
expected_nnz = int((1 - sparsity) * M * N / (blocksize[0] * blocksize[1]))
|
||||
nonzero_indices = A.flatten().nonzero()
|
||||
actual_nnz = nonzero_indices.shape[0]
|
||||
if actual_nnz > expected_nnz:
|
||||
selected_nonzeros = torch.randperm(actual_nnz)[: actual_nnz - expected_nnz]
|
||||
A.flatten()[nonzero_indices[selected_nonzeros]] = 0
|
||||
elif actual_nnz < expected_nnz:
|
||||
zero_indices = (A == 0).flatten().nonzero()
|
||||
selected_zeros = torch.randperm(zero_indices.shape[0])[
|
||||
: expected_nnz - actual_nnz
|
||||
]
|
||||
A.flatten()[zero_indices[selected_zeros]] = 1
|
||||
A = torch.repeat_interleave(A, blocksize[0], dim=-2)
|
||||
A = torch.repeat_interleave(A, blocksize[1], dim=-1)
|
||||
return A
|
||||
|
||||
|
||||
def _test_worker(test_func):
|
||||
import triton
|
||||
|
||||
ms, ms_min, ms_max = triton.testing.do_bench(
|
||||
test_func, warmup=500, rep=100, fast_flush=False
|
||||
)
|
||||
|
||||
tflops = 2 * m * k * n * 1e-12 / (ms * 1e-3)
|
||||
return ms, tflops
|
||||
|
||||
|
||||
def test_dense_dense_mm(x, y, **meta):
|
||||
def test_func(x=x.to_dense(), y=y):
|
||||
return torch.matmul(x, y)
|
||||
|
||||
return _test_worker(test_func)
|
||||
|
||||
|
||||
def test_torch_matmul(x, y, **meta):
|
||||
def test_func(x=x, y=y):
|
||||
return torch.matmul(x, y)
|
||||
|
||||
return _test_worker(test_func)
|
||||
|
||||
|
||||
def test_bsr_dense_mm(x, y, **meta):
|
||||
from torch.sparse._triton_ops import bsr_dense_mm
|
||||
|
||||
def test_func(x=x, y=y):
|
||||
return bsr_dense_mm(x, y)
|
||||
|
||||
return _test_worker(test_func)
|
||||
|
||||
|
||||
def test_bsr_scatter_mm2(x, y, **meta):
|
||||
from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data
|
||||
|
||||
indices_data = bsr_scatter_mm_indices_data(
|
||||
x, y, indices_format="scatter_mm", **meta
|
||||
)
|
||||
|
||||
def test_func(x=x, y=y):
|
||||
return bsr_scatter_mm(x, y, indices_data=indices_data)
|
||||
|
||||
return _test_worker(test_func)
|
||||
|
||||
|
||||
def test_bsr_scatter_mm6(x, y, **meta):
|
||||
from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data
|
||||
|
||||
indices_data = bsr_scatter_mm_indices_data(
|
||||
x, y, indices_format="bsr_strided_mm_compressed", **meta
|
||||
)
|
||||
|
||||
def test_func(x=x, y=y):
|
||||
return bsr_scatter_mm(x, y, indices_data=indices_data)
|
||||
|
||||
return _test_worker(test_func)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
import atexit
|
||||
import itertools
|
||||
import sys
|
||||
|
||||
import triton
|
||||
from torch.testing import make_tensor
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
def integer_list(a):
|
||||
return list(map(int, a.split(",")))
|
||||
|
||||
def float_list(a):
|
||||
return list(map(float, a.split(",")))
|
||||
|
||||
def integer_or_float_list(a):
|
||||
lst = []
|
||||
for n in a.split(","):
|
||||
if n.count(":") == 1:
|
||||
start, end = map(int, n.split(":"))
|
||||
lst.extend(range(start, end))
|
||||
elif n.count(":") == 2:
|
||||
start, end, step = map(int, n.split(":"))
|
||||
lst.extend(range(start, end, step))
|
||||
elif "." in n:
|
||||
lst.append(float(n))
|
||||
else:
|
||||
lst.append(int(n))
|
||||
return lst
|
||||
|
||||
parser = argparse.ArgumentParser(description="SpTritonOps")
|
||||
|
||||
parser.add_argument(
|
||||
"--ops",
|
||||
default="dense_dense_mm,bsr_dense_mm,bsr_scatter_mm6",
|
||||
type=str,
|
||||
)
|
||||
parser.add_argument("--b", default="0", type=int)
|
||||
|
||||
parser.add_argument("--m", default="1024", type=integer_list)
|
||||
parser.add_argument("--k", default=None, type=integer_list)
|
||||
parser.add_argument("--n", default=None, type=integer_list)
|
||||
parser.add_argument("--bm", default="16", type=integer_list)
|
||||
parser.add_argument("--bk", default=None, type=integer_list)
|
||||
parser.add_argument("--tile_m", default=None, type=integer_list)
|
||||
parser.add_argument("--tile_n", default=None, type=integer_list)
|
||||
parser.add_argument("--split_n", default=None, type=integer_list)
|
||||
parser.add_argument("--group_size", default=None, type=integer_list)
|
||||
parser.add_argument("--num_warps", default=None, type=integer_list)
|
||||
parser.add_argument("--num_stages", default=None, type=integer_list)
|
||||
parser.add_argument("--sparsity", default="0.5", type=integer_or_float_list)
|
||||
parser.add_argument("--dtype", default="float16", type=str)
|
||||
parser.add_argument("--device", default="cuda", type=str)
|
||||
parser.add_argument("--repeat", default="1", type=int)
|
||||
parser.add_argument("--outfile", default="stdout", type=str)
|
||||
parser.add_argument("--star", default=False, action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.outfile == "stdout":
|
||||
outfile = sys.stdout
|
||||
elif args.outfile == "stderr":
|
||||
outfile = sys.stderr
|
||||
else:
|
||||
outfile = open(args.outfile, "a")
|
||||
|
||||
ops = args.ops.split(",")
|
||||
|
||||
b = args.b
|
||||
|
||||
m_list = args.m or [1024]
|
||||
n_list = args.n or [None]
|
||||
k_list = args.k or [None]
|
||||
bm_list = args.bm or [16]
|
||||
bk_list = args.bk or [None]
|
||||
split_n_list = args.split_n or [None]
|
||||
tile_m_list = args.tile_m or [None]
|
||||
tile_n_list = args.tile_n or [None]
|
||||
group_size_list = args.group_size or [None]
|
||||
num_warps_list = args.num_warps or [None]
|
||||
num_stages_list = args.num_stages or [None]
|
||||
sparsity_list = args.sparsity or [0.5]
|
||||
dtype = getattr(torch, args.dtype)
|
||||
|
||||
if args.star > 0:
|
||||
import torch.sparse._triton_ops
|
||||
|
||||
assert {len(m_list), len(n_list), len(k_list), len(bm_list), len(bk_list)} == {
|
||||
1
|
||||
}
|
||||
m = m_list[0]
|
||||
n = n_list[0] or m
|
||||
k = k_list[0] or m
|
||||
bm = bm_list[0]
|
||||
bk = bk_list[0] or bm
|
||||
meta = torch.sparse._triton_ops.scatter_mm_meta(m, k, n, bm, bk)
|
||||
assert {
|
||||
split_n_list[0],
|
||||
tile_m_list[0],
|
||||
tile_n_list[0],
|
||||
group_size_list[0],
|
||||
num_warps_list[0],
|
||||
num_stages_list[0],
|
||||
}
|
||||
if split_n_list[0] is None:
|
||||
split_n_list = [meta["SPLIT_N"] // 2, meta["SPLIT_N"], meta["SPLIT_N"] * 2][
|
||||
int(meta["SPLIT_N"] == 1) :
|
||||
]
|
||||
elif split_n_list[0] == 0:
|
||||
split_n_list = [meta["SPLIT_N"]]
|
||||
if tile_m_list[0] is None:
|
||||
tile_m_list = [meta["TILE_M"] // 2, meta["TILE_M"], meta["TILE_M"] * 2][
|
||||
int(meta["TILE_M"] == 16) :
|
||||
]
|
||||
elif tile_m_list[0] == 0:
|
||||
tile_m_list = [meta["TILE_M"]]
|
||||
if tile_n_list[0] is None:
|
||||
tile_n_list = [meta["TILE_N"] // 2, meta["TILE_N"], meta["TILE_N"] * 2][
|
||||
int(meta["TILE_N"] == 16) :
|
||||
]
|
||||
elif tile_n_list[0] == 0:
|
||||
tile_n_list = [meta["TILE_N"]]
|
||||
if group_size_list[0] is None:
|
||||
group_size_list = [
|
||||
meta["GROUP_SIZE"] - 1,
|
||||
meta["GROUP_SIZE"],
|
||||
meta["GROUP_SIZE"] + 1,
|
||||
][int(meta["GROUP_SIZE"] == 1) :]
|
||||
elif group_size_list[0] == 0:
|
||||
group_size_list = [meta["GROUP_SIZE"]]
|
||||
if num_warps_list[0] is None:
|
||||
num_warps_list = [
|
||||
meta["num_warps"] // 2,
|
||||
meta["num_warps"],
|
||||
meta["num_warps"] * 2,
|
||||
][int(meta["num_warps"] == 1) :]
|
||||
elif num_warps_list[0] == 0:
|
||||
num_warps_list = [meta["num_warps"]]
|
||||
if num_stages_list[0] is None:
|
||||
num_stages_list = [
|
||||
meta["num_stages"] - 1,
|
||||
meta["num_stages"],
|
||||
meta["num_stages"] + 1,
|
||||
][int(meta["num_stages"] == 1) :]
|
||||
elif num_stages_list[0] == 0:
|
||||
num_stages_list = [meta["num_stages"]]
|
||||
|
||||
device = args.device
|
||||
dense_dense_mm_sizes = set()
|
||||
target_performance = None
|
||||
performance_rtol = 1e-2
|
||||
|
||||
best_messages = []
|
||||
|
||||
@atexit.register
|
||||
def show_best_messages(best_messages=best_messages):
|
||||
print("TOP 10:")
|
||||
for m in best_messages[-10:]:
|
||||
print(m)
|
||||
sys.stdout.flush()
|
||||
|
||||
for m, k, n, bm, bk, sparsity in itertools.product(
|
||||
m_list, n_list, k_list, bm_list, bk_list, sparsity_list
|
||||
):
|
||||
k = k or m
|
||||
n = n or m
|
||||
bk = bk or bm
|
||||
|
||||
if bm > m or bk > k:
|
||||
continue
|
||||
|
||||
blocksize = (bm, bk)
|
||||
|
||||
if isinstance(sparsity, int):
|
||||
# integer sparsity value corresponds to desired nnz value
|
||||
sparsity = 1 - bk * bm * sparsity / (m * k)
|
||||
|
||||
if sparsity > 1 or sparsity < 0:
|
||||
continue
|
||||
|
||||
x = create_blocked_tensor(
|
||||
b, m, k, blocksize, sparsity, dtype, device
|
||||
).to_sparse_bsr(blocksize)
|
||||
|
||||
# recompute sparsity
|
||||
sparsity = 1 - bk * bm * x._nnz() / (m * k)
|
||||
|
||||
y = make_tensor(k, n, dtype=dtype, device=device)
|
||||
|
||||
bsr_size = f"{b}x{m}x{k}" if b > 0 else f"{k}x{n}"
|
||||
|
||||
for op in ops:
|
||||
if op == "dense_dense_mm":
|
||||
if (m, k, n) in dense_dense_mm_sizes:
|
||||
continue
|
||||
dense_dense_mm_sizes.add((m, k, n))
|
||||
best_tflops = 0
|
||||
for (
|
||||
split_n,
|
||||
num_warps,
|
||||
num_stages,
|
||||
tile_m,
|
||||
tile_n,
|
||||
group_size,
|
||||
) in itertools.product(
|
||||
split_n_list,
|
||||
num_warps_list,
|
||||
num_stages_list,
|
||||
tile_m_list,
|
||||
tile_n_list,
|
||||
group_size_list,
|
||||
):
|
||||
if (
|
||||
(tile_m or 0) > bm
|
||||
or (tile_n or 0) > n // (split_n or 1)
|
||||
or n % (split_n or 1) != 0
|
||||
or (split_n or 0) > n
|
||||
):
|
||||
continue
|
||||
test_func = globals()["test_" + op]
|
||||
meta = (
|
||||
dict(
|
||||
SPLIT_N=split_n,
|
||||
TILE_M=tile_m,
|
||||
TILE_N=tile_n,
|
||||
GROUP_SIZE=group_size,
|
||||
num_stages=num_stages,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
if op == "bsr_scatter_mm6"
|
||||
else dict()
|
||||
)
|
||||
|
||||
meta_str = ";".join(
|
||||
f"{k}={v}" for k, v in meta.items() if v is not None
|
||||
)
|
||||
time_ms_lst = []
|
||||
performance_tflops_lst = []
|
||||
for r in range(args.repeat):
|
||||
try:
|
||||
time_ms, performance_tflops = test_func(x, y, **meta)
|
||||
except triton.compiler.OutOfResources as msg:
|
||||
print(
|
||||
f"op={op}[{meta_str}]({bsr_size},{k}x{n}) dtype={args.dtype} {sparsity=}(nnz={x._nnz()})"
|
||||
f" blocksize={bm}x{bk} OutOfResources",
|
||||
file=outfile,
|
||||
)
|
||||
continue
|
||||
except Exception as msg:
|
||||
msg = str(msg).split("\n", 1)[0]
|
||||
print(
|
||||
f"op={op}[{meta_str}]({bsr_size},{k}x{n}) dtype={args.dtype} {sparsity=}(nnz={x._nnz()})"
|
||||
f" blocksize={bm}x{bk} {msg}",
|
||||
file=outfile,
|
||||
)
|
||||
continue
|
||||
time_ms_lst.append(time_ms)
|
||||
performance_tflops_lst.append(performance_tflops)
|
||||
mark = ""
|
||||
if op == "dense_dense_mm":
|
||||
if target_performance is None:
|
||||
target_performance = performance_tflops
|
||||
elif target_performance is not None:
|
||||
if (
|
||||
abs(1 - performance_tflops / target_performance)
|
||||
< performance_rtol
|
||||
):
|
||||
mark += " @@@"
|
||||
if best_tflops < performance_tflops:
|
||||
best_tflops = performance_tflops
|
||||
best_message = (
|
||||
f"op={op}[{meta_str}]({bsr_size},x{n}) dtype={args.dtype} {sparsity=:.4f}(nnz={x._nnz()})"
|
||||
f" blocksize={bm}x{bk} time={time_ms:.3f} ms performance={performance_tflops:.3f} TFLOPS"
|
||||
)
|
||||
if best_message not in best_messages:
|
||||
best_messages.append(best_message)
|
||||
mark += " !!!"
|
||||
print(
|
||||
f"op={op}[{meta_str}]({bsr_size},x{n}) dtype={args.dtype} {sparsity=:.4f}(nnz={x._nnz()})"
|
||||
f" blocksize={bm}x{bk}"
|
||||
f" time={time_ms:.3f} ms performance={performance_tflops:.3f} TFLOPS{mark}",
|
||||
file=outfile,
|
||||
)
|
||||
outfile.flush()
|
||||
if args.repeat > 1:
|
||||
avg_time_ms = sum(time_ms_lst) / len(time_ms_lst)
|
||||
avg_performance_tflops = sum(performance_tflops_lst) / len(
|
||||
performance_tflops_lst
|
||||
)
|
||||
print(
|
||||
f"op={op}[{meta_str}]({bsr_size},{k}x{n}) dtype={args.dtype} {sparsity=}(nnz={x._nnz()})"
|
||||
f" blocksize={bm}x{bk}"
|
||||
f" time={time_ms:.3f} ms performance={performance_tflops:.3f} TFLOPS [AVERAGE]",
|
||||
file=outfile,
|
||||
)
|
||||
outfile.flush()
|
||||
if op not in {"bsr_scatter_mm6"}:
|
||||
break
|
||||
|
|
@ -3697,6 +3697,103 @@ class TestSparseCompressedTritonKernels(TestCase):
|
|||
res_tri_grid = sampled_addmm(bsr, mat1, mat2, alpha=alpha, beta=beta, max_grid=grid)
|
||||
self.assertEqual(res_tri, res_tri_grid)
|
||||
|
||||
@onlyCUDA
|
||||
@skipIfRocm
|
||||
@dtypes(torch.half, torch.bfloat16, torch.float)
|
||||
@dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
|
||||
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
|
||||
def test_triton_scatter_mm(self, device, dtype):
|
||||
from torch.sparse._triton_ops import scatter_mm
|
||||
from functools import partial
|
||||
tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)
|
||||
sizes = [8, 16]
|
||||
for m, k, n in itertools.product(sizes, sizes, sizes):
|
||||
blocks = torch.stack([tensor(m, k), tensor(m, k)])
|
||||
others = torch.stack([tensor(k, n), tensor(k, n)])
|
||||
|
||||
expected = torch.stack([blocks[0] @ others[0] + blocks[1] @ others[0],
|
||||
blocks[0] @ others[1],
|
||||
blocks[1] @ others[1]])
|
||||
|
||||
indices_data = (
|
||||
'scatter_mm',
|
||||
torch.tensor([0, 2, 3, 4], dtype=torch.int32, device=device),
|
||||
torch.tensor([[0, 0], [1, 0], [0, 1], [1, 1]], dtype=torch.int32, device=device))
|
||||
|
||||
result = scatter_mm(blocks, others, indices_data=indices_data)
|
||||
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
other = tensor(2 * k, 2 * n)
|
||||
|
||||
expected = torch.cat([
|
||||
torch.cat([blocks[1], blocks[0]], dim=1),
|
||||
torch.cat([torch.zeros_like(blocks[0]), blocks[1]], dim=1)], dim=0) @ other
|
||||
|
||||
indices_data = (
|
||||
'bsr_strided_mm',
|
||||
torch.tensor([0, 2, 4, 5, 6], dtype=torch.int32, device=device),
|
||||
torch.tensor([0, n, 2 * n * m, 2 * n * m + n], dtype=torch.int32, device=device),
|
||||
torch.tensor([1, 0, 1, 0, 1, 1], dtype=torch.int32, device=device),
|
||||
torch.tensor([0, 2 * k * n, n, 2 * k * n + n, 2 * k * n, 2 * k * n + n],
|
||||
dtype=torch.int32, device=device),
|
||||
dict(SPLIT_N=2, is_compressed=False, TILE_M=m, TILE_N=n, GROUP_SIZE=1)
|
||||
)
|
||||
|
||||
result = scatter_mm(blocks, other, indices_data=indices_data)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
@parametrize("blocksize", [2, '2x3', 16, '16x32', 32, 64])
|
||||
@onlyCUDA
|
||||
@skipIfRocm
|
||||
@dtypes(torch.half, torch.bfloat16, torch.float)
|
||||
@dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
|
||||
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
|
||||
def test_triton_bsr_scatter_mm(self, device, dtype, blocksize):
|
||||
import triton
|
||||
from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data
|
||||
from functools import partial
|
||||
if isinstance(blocksize, str):
|
||||
blocksize = tuple(map(int, blocksize.split('x')))
|
||||
else:
|
||||
blocksize = (blocksize,) * 2
|
||||
# Note that each value in a non-zero block is in range blocksize * [low^2, high^2).
|
||||
tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)
|
||||
|
||||
# NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
|
||||
batches = [()]
|
||||
sizes = [blocksize[0], 2 * blocksize[0], 4 * blocksize[0]]
|
||||
sizes_K = [blocksize[1], 2 * blocksize[1]]
|
||||
|
||||
for bd, bs, M, K, N, has_zero_row_block in itertools.product(batches, batches, sizes, sizes_K, sizes, (False, True)):
|
||||
bsr_dense = tensor(bs + (M, K))
|
||||
if has_zero_row_block:
|
||||
if M > blocksize[0]:
|
||||
bsr_dense[:blocksize[0]].zero_()
|
||||
else:
|
||||
continue
|
||||
bsr = bsr_dense.to_sparse_bsr(blocksize)
|
||||
dense = tensor(bd + (K, N))
|
||||
expected = bsr.to_dense() @ dense
|
||||
|
||||
for indices_format in ('bsr_strided_mm', 'bsr_strided_mm_compressed', 'scatter_mm'):
|
||||
if indices_format in {'bsr_strided_mm', 'bsr_strided_mm_compressed'}:
|
||||
SPLIT_N_list = [N]
|
||||
while SPLIT_N_list[-1] > 1:
|
||||
SPLIT_N_list.append(max(1, SPLIT_N_list[-1] // 2))
|
||||
else:
|
||||
SPLIT_N_list = [1]
|
||||
for SPLIT_N in SPLIT_N_list:
|
||||
indices_data = bsr_scatter_mm_indices_data(
|
||||
bsr, dense, indices_format=indices_format, SPLIT_N=SPLIT_N)
|
||||
try:
|
||||
result = bsr_scatter_mm(bsr, dense, indices_data=indices_data)
|
||||
except triton.compiler.OutOfResources:
|
||||
# ensure that there was at least one succesful test:
|
||||
assert SPLIT_N < SPLIT_N_list[0]
|
||||
break
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
|
||||
# e.g., TestSparseCSRCPU and TestSparseCSRCUDA
|
||||
instantiate_device_type_tests(TestSparseCSR, globals())
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import math
|
||||
import torch
|
||||
|
||||
from torch.utils._triton import has_triton
|
||||
|
||||
|
||||
|
|
@ -209,6 +208,491 @@ def tile_to_blocksize(t, blocksize):
|
|||
return t.view(new_shape).transpose(-3, -2)
|
||||
|
||||
|
||||
def scatter_mm(blocks, others, indices_data, *, accumulators=None):
|
||||
"""Scattered matrix multiplication of tensors.
|
||||
|
||||
A scattered matrix multiplication is defined as a series of matrix
|
||||
multiplications applied to input tensors according to the input
|
||||
and output mappings specified by indices data.
|
||||
|
||||
The following indices data formats are supported for defining a
|
||||
scattered matrix multiplication operation (:attr:`indices_data[0]`
|
||||
holds the name of the indices data format as specified below):
|
||||
|
||||
- ``"scatter_mm"`` - matrix multiplications scattered in batches
|
||||
of tensors.
|
||||
|
||||
If :attr:`blocks` is a :math:`(* \times M \times K) tensor,
|
||||
:attr:`others` is a :math:`(* \times K \times N)` tensor,
|
||||
:attr:`accumulators` is a :math:`(* \times M \times N)` tensor,
|
||||
and :attr:`indices = indices_data['indices']` is a :math:`(*
|
||||
\times 3)` tensor, then the operation is equivalent to the
|
||||
following code::
|
||||
|
||||
c_offsets, pq = indices_data[1:]
|
||||
for r in range(len(c_offsets) - 1):
|
||||
for g in range(c_offsets[r], c_offsets[r + 1]):
|
||||
p, q = pq[g]
|
||||
accumulators[r] += blocks[p] @ others[q]
|
||||
|
||||
- ``"bsr_strided_mm"`` - matrix multiplications scattered in
|
||||
batches of tensors and a tensor.
|
||||
|
||||
If :attr:`blocks` is a :math:`(* \times Ms \times Ks) tensor,
|
||||
:attr:`others` is a :math:`(K \times N)` tensor,
|
||||
:attr:`accumulators` is a :math:`(M \times N)` tensor, then
|
||||
the operation is equivalent to the following code::
|
||||
|
||||
c_indices, r_offsets, p_offsets, q_offsets, meta = indices_data[1:]
|
||||
for i, r in enumerate(r_offsets):
|
||||
r0, r1 = divmod(r, N)
|
||||
for g in range(c_indices[i], c_indices[i+1]):
|
||||
p = p_offsets[g]
|
||||
q0, q1 = divmod(q_offsets[g], N)
|
||||
accumulators[r0:r0 + Ms, r1:r1 + Ns] += blocks[p] @ others[q0:q0 + Ks, q1:q1 + Ns]
|
||||
|
||||
where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are
|
||||
integer multiples of ``Ms`` and ``Ks``, respectively.
|
||||
|
||||
- ``"bsr_strided_mm_compressed"`` - matrix multiplications
|
||||
scattered in batches of tensors and a tensor. A memory and
|
||||
processor efficient version of ``"bsr_strided_mm"`` format.
|
||||
If :attr:`blocks` is a :math:`(* \times Ms \times Ks) tensor,
|
||||
:attr:`others` is a :math:`(K \times N)` tensor,
|
||||
:attr:`accumulators` is a :math:`(M \times N)` tensor, then
|
||||
the operation is equivalent to the following code::
|
||||
|
||||
c_indices, r_offsets, q_offsets, meta = indices_data[1:]
|
||||
for r in r_offsets:
|
||||
m = (r // N) // Ms
|
||||
n = (r % N) // Ns
|
||||
r0, r1 = divmod(r, N)
|
||||
c0, c1 = c_indices[m], c_indices[m + 1]
|
||||
for i, p in enumerate(range(c0, c1)):
|
||||
q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i]
|
||||
q0, q1 = divmod(q, N)
|
||||
accumulators[r0:r0 + Ms, r1:r1 + Ns] += blocks[p] @ others[q0:q0 + Ks, q1:q1 + Ns]
|
||||
|
||||
where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are
|
||||
integer multiples of ``Ms`` and ``Ks``, respectively.
|
||||
|
||||
Notice that the order of ``r_offsets`` items can be arbitrary;
|
||||
this property enables defining swizzle operators via
|
||||
rearrangements of ``r_offsets`` items..
|
||||
|
||||
Auxilary functions are provided for pre-computing
|
||||
:attr:`indices_data`. For example,
|
||||
:func:`bsr_scatter_mm_indices_data` is used to define indices data
|
||||
for matrix multiplication of BSR and strided tensors.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
blocks (Tensor): a 3-D tensor of first matrices to be multiplied
|
||||
|
||||
others (Tensor): a tensor of second matrices to be multiplied. If
|
||||
``indices_data[0]=="scatter_mm"``, the tensor is a 1-D batch
|
||||
tensor of second input matrices to be multiplied. Otherwise, the
|
||||
second input matrices are slices of the :attr:`others` tensor.
|
||||
indices_data (tuple): a format data that defines the inputs and
|
||||
outputs of scattered matrix multiplications.
|
||||
|
||||
Keyword arguments
|
||||
-----------------
|
||||
|
||||
accumulators (Tensor, optional): a tensor of matrix product
|
||||
accumulators. If ``indices_data[0]=="scatter_mm"``, the tensor
|
||||
is a 1-D batch tensor of output matrices. Otherwise, output
|
||||
matrices are slices of the :attr:`accumulators` tensor.
|
||||
|
||||
"""
|
||||
indices_format = indices_data[0]
|
||||
|
||||
assert blocks.ndim == 3
|
||||
P, Ms, Ks = blocks.shape
|
||||
|
||||
if indices_format == 'scatter_mm':
|
||||
c_offsets, pq = indices_data[1:]
|
||||
|
||||
assert others.ndim == 3
|
||||
Q, Ks_, Ns = others.shape
|
||||
assert Ks == Ks_
|
||||
|
||||
if accumulators is None:
|
||||
R = c_offsets.shape[0] - 1
|
||||
accumulators = torch.zeros((R, Ms, Ns), dtype=blocks.dtype, device=blocks.device)
|
||||
else:
|
||||
R, Ms_, Ns_ = accumulators.shape
|
||||
assert Ms_ == Ms
|
||||
assert Ns_ == Ns
|
||||
|
||||
if Ms % 16 or Ks % 16 or Ns % 16 or _scatter_mm2 is None:
|
||||
for r in range(c_offsets.shape[0] - 1):
|
||||
g0 = c_offsets[r]
|
||||
g1 = c_offsets[r + 1]
|
||||
for g in range(g0, g1):
|
||||
p, q = pq[g]
|
||||
accumulators[r] += blocks[p] @ others[q]
|
||||
else:
|
||||
_scatter_mm2(blocks, others, c_offsets, pq, accumulators)
|
||||
return accumulators
|
||||
|
||||
elif indices_format == 'bsr_strided_mm':
|
||||
|
||||
assert others.ndim == 2
|
||||
K, N = others.shape
|
||||
assert K % Ks == 0
|
||||
|
||||
c_indices, r_offsets, p_offsets, q_offsets, meta = indices_data[1:]
|
||||
SPLIT_N = meta['SPLIT_N']
|
||||
|
||||
if accumulators is None:
|
||||
M = Ms + (r_offsets.max().item() + 1) // N
|
||||
accumulators = torch.zeros((M, N), dtype=blocks.dtype, device=blocks.device)
|
||||
else:
|
||||
M, N_ = accumulators.shape
|
||||
assert N_ == N
|
||||
|
||||
Ns = N // SPLIT_N
|
||||
|
||||
if Ms % 16 or Ks % 16 or Ns % 16 or _scatter_mm6 is None:
|
||||
accumulators.zero_()
|
||||
for r in range(r_offsets.shape[0]):
|
||||
r_ = r_offsets[r].item()
|
||||
g0 = c_indices[r].item()
|
||||
g1 = c_indices[r + 1].item()
|
||||
r0, r1 = divmod(r_, N)
|
||||
acc = accumulators[r0:r0 + Ms, r1:r1 + Ns]
|
||||
for g in range(g0, g1):
|
||||
p, q = p_offsets[g], q_offsets[g]
|
||||
q0, q1 = divmod(q.item(), N)
|
||||
acc += blocks[p] @ others[q0:q0 + Ks, q1:q1 + Ns]
|
||||
else:
|
||||
_scatter_mm6(blocks, others, c_indices, r_offsets, p_offsets, q_offsets, meta, accumulators)
|
||||
return accumulators
|
||||
|
||||
elif indices_format == 'bsr_strided_mm_compressed':
|
||||
|
||||
assert others.ndim == 2
|
||||
K, N = others.shape
|
||||
assert K % Ks == 0
|
||||
|
||||
c_indices, r_offsets, q_offsets, meta = indices_data[1:]
|
||||
SPLIT_N = meta['SPLIT_N']
|
||||
|
||||
if accumulators is None:
|
||||
M = Ms + (r_offsets.max().item() + 1) // N
|
||||
accumulators = torch.zeros((M, N), dtype=blocks.dtype, device=blocks.device)
|
||||
else:
|
||||
M, N_ = accumulators.shape
|
||||
assert N_ == N
|
||||
|
||||
Ns = N // SPLIT_N
|
||||
|
||||
if Ms % 16 or Ks % 16 or Ns % 16 or _scatter_mm6 is None:
|
||||
for j in range(len(r_offsets)):
|
||||
r0, r1 = divmod(r_offsets[j].item(), N)
|
||||
m = r0 // Ms
|
||||
n = r1 // Ns
|
||||
c0 = c_indices[m].item()
|
||||
c1 = c_indices[m + 1].item()
|
||||
acc = accumulators[r0:r0 + Ms, r1:r1 + Ns]
|
||||
for i, p in enumerate(range(c0, c1)):
|
||||
q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i].item()
|
||||
q0, q1 = divmod(q, N)
|
||||
acc += blocks[p] @ others[q0:q0 + Ks, q1:q1 + Ns]
|
||||
else:
|
||||
p_offsets = torch.empty((0, ), dtype=q_offsets.dtype, device=q_offsets.device)
|
||||
_scatter_mm6(blocks, others, c_indices, r_offsets, p_offsets, q_offsets, meta, accumulators)
|
||||
return accumulators
|
||||
|
||||
else:
|
||||
raise NotImplementedError(indices_format)
|
||||
|
||||
|
||||
def scatter_mm_meta(M, K, N, Ms, Ks,
|
||||
GROUP_SIZE=None, TILE_M=None, TILE_N=None, SPLIT_N=None, num_warps=None, num_stages=None, **extra):
|
||||
if {TILE_M, TILE_N, SPLIT_N, num_warps, num_stages, GROUP_SIZE} == {None}:
|
||||
# The following parameters are optimized for the performance
|
||||
# equilibrium points of bsr-dense and dense-dense matrix
|
||||
# multiplications when using GPU cards NVIDIA A100 and NVIDIA
|
||||
# GeForce RTX 2060 SUPER. For points far from the performance
|
||||
# equilibrium points as well as for other GPU cards, the
|
||||
# optimal parameters are likely different from what specified
|
||||
# below.
|
||||
device_name = torch.cuda.get_device_name()
|
||||
is_A100 = 'A100' in device_name
|
||||
if (M, K, N) == (256,) * 3:
|
||||
if (Ms, Ks) == (16, 16):
|
||||
SPLIT_N=1;TILE_M=16;TILE_N=16;GROUP_SIZE=4;num_stages=1;num_warps=4 # noqa: E225,E231,E702
|
||||
elif (Ms, Ks) == (32, 32):
|
||||
SPLIT_N=2;TILE_M=32;TILE_N=16;GROUP_SIZE=4;num_stages=1;num_warps=4 # noqa: E225,E231,E702
|
||||
elif (Ms, Ks) == (64, 64):
|
||||
SPLIT_N=1;TILE_M=32;TILE_N=32;GROUP_SIZE=4;num_stages=1;num_warps=4 # noqa: E225,E231,E702
|
||||
elif (Ms, Ks) == (128, 128):
|
||||
SPLIT_N=1;TILE_M=32;TILE_N=32;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702
|
||||
elif (M, K, N) == (512,) * 3:
|
||||
if (Ms, Ks) == (16, 16):
|
||||
SPLIT_N=8;TILE_M=16;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=2 # noqa: E225,E231,E702
|
||||
if is_A100:
|
||||
SPLIT_N=1;TILE_M=16;TILE_N=32;GROUP_SIZE=2;num_stages=1;num_warps=1 # noqa: E225,E231,E702
|
||||
elif (Ms, Ks) == (32, 32):
|
||||
SPLIT_N=8;TILE_M=32;TILE_N=64;GROUP_SIZE=4;num_stages=1;num_warps=2 # noqa: E225,E231,E702
|
||||
if is_A100:
|
||||
SPLIT_N=4;TILE_M=16;TILE_N=32;GROUP_SIZE=2;num_stages=1;num_warps=1 # noqa: E225,E231,E702
|
||||
elif (Ms, Ks) == (64, 64):
|
||||
SPLIT_N=4;TILE_M=32;TILE_N=128;GROUP_SIZE=4;num_stages=1;num_warps=4 # noqa: E225,E231,E702
|
||||
if is_A100:
|
||||
SPLIT_N=1;TILE_M=16;TILE_N=32;GROUP_SIZE=2;num_stages=1;num_warps=1 # noqa: E225,E231,E702
|
||||
elif (Ms, Ks) == (128, 128):
|
||||
SPLIT_N=8;TILE_M=64;TILE_N=64;GROUP_SIZE=4;num_stages=1;num_warps=4 # noqa: E225,E231,E702
|
||||
elif (M, K, N) == (1024,) * 3:
|
||||
if (Ms, Ks) == (16, 16):
|
||||
SPLIT_N=4;TILE_M=16;TILE_N=128;GROUP_SIZE=2;num_stages=1;num_warps=1 # noqa: E225,E231,E702
|
||||
if is_A100:
|
||||
SPLIT_N=1;TILE_M=16;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=2 # noqa: E225,E231,E702
|
||||
elif (Ms, Ks) == (32, 32):
|
||||
SPLIT_N=8;TILE_M=32;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=1 # noqa: E225,E231,E702
|
||||
if is_A100:
|
||||
SPLIT_N=2;TILE_M=32;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=2 # noqa: E225,E231,E702
|
||||
elif (Ms, Ks) == (64, 64):
|
||||
SPLIT_N=16;TILE_M=64;TILE_N=64;GROUP_SIZE=4;num_stages=1;num_warps=2 # noqa: E225,E231,E702
|
||||
if is_A100:
|
||||
SPLIT_N=2;TILE_M=32;TILE_N=128;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702
|
||||
elif (Ms, Ks) == (128, 128):
|
||||
SPLIT_N=16;TILE_M=64;TILE_N=64;GROUP_SIZE=4;num_stages=1;num_warps=4 # noqa: E225,E231,E702
|
||||
if is_A100:
|
||||
SPLIT_N=8;TILE_M=64;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702
|
||||
elif (Ms, Ks) == (256, 256):
|
||||
SPLIT_N=16;TILE_M=64;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702
|
||||
elif (M, K, N) == (2048,) * 3:
|
||||
if (Ms, Ks) == (16, 16):
|
||||
SPLIT_N=4;TILE_M=16;TILE_N=128;GROUP_SIZE=8;num_stages=1;num_warps=1 # noqa: E225,E231,E702
|
||||
if is_A100:
|
||||
SPLIT_N=8;TILE_M=16;TILE_N=64;GROUP_SIZE=1;num_stages=1;num_warps=2 # noqa: E225,E231,E702
|
||||
elif (Ms, Ks) == (32, 32):
|
||||
SPLIT_N=4;TILE_M=32;TILE_N=64;GROUP_SIZE=4;num_stages=1;num_warps=1 # noqa: E225,E231,E702
|
||||
if is_A100:
|
||||
SPLIT_N=16;TILE_M=32;TILE_N=64;GROUP_SIZE=1;num_stages=1;num_warps=2 # noqa: E225,E231,E702
|
||||
elif (Ms, Ks) == (64, 64):
|
||||
SPLIT_N=4;TILE_M=64;TILE_N=128;GROUP_SIZE=4;num_stages=1;num_warps=4 # noqa: E225,E231,E702
|
||||
if is_A100:
|
||||
SPLIT_N=8;TILE_M=64;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702
|
||||
elif (Ms, Ks) == (128, 128):
|
||||
SPLIT_N=8;TILE_M=64;TILE_N=64;GROUP_SIZE=4;num_stages=1;num_warps=4 # noqa: E225,E231,E702
|
||||
if is_A100:
|
||||
SPLIT_N=32;TILE_M=64;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702
|
||||
elif (Ms, Ks) == (256, 256):
|
||||
SPLIT_N=4;TILE_M=64;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702
|
||||
elif (M, K, N) == (4096,) * 3:
|
||||
if (Ms, Ks) == (16, 16):
|
||||
SPLIT_N=2;TILE_M=16;TILE_N=256;GROUP_SIZE=2;num_stages=1;num_warps=2 # noqa: E225,E231,E702
|
||||
if is_A100:
|
||||
SPLIT_N=4;TILE_M=16;TILE_N=128;GROUP_SIZE=2;num_stages=1;num_warps=2 # noqa: E225,E231,E702
|
||||
elif (Ms, Ks) == (32, 32):
|
||||
SPLIT_N=2;TILE_M=32;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=1 # noqa: E225,E231,E702
|
||||
if is_A100:
|
||||
SPLIT_N=4;TILE_M=32;TILE_N=64;GROUP_SIZE=4;num_stages=3;num_warps=2 # noqa: E225,E231,E702
|
||||
elif (Ms, Ks) == (64, 64):
|
||||
SPLIT_N=2;TILE_M=64;TILE_N=128;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702
|
||||
if is_A100:
|
||||
SPLIT_N=4;TILE_M=64;TILE_N=64;GROUP_SIZE=2;num_stages=3;num_warps=2 # noqa: E225,E231,E702
|
||||
elif (Ms, Ks) == (128, 128):
|
||||
if is_A100:
|
||||
SPLIT_N=2;TILE_M=128;TILE_N=128;GROUP_SIZE=1;num_stages=1;num_warps=8 # noqa: E225,E231,E702
|
||||
elif (M, K, N) == (8192,) * 3:
|
||||
if (Ms, Ks) == (16, 16):
|
||||
if is_A100:
|
||||
SPLIT_N=1;TILE_M=16;TILE_N=128;GROUP_SIZE=2;num_stages=1;num_warps=2 # noqa: E225,E231,E702
|
||||
elif (Ms, Ks) == (32, 32):
|
||||
if is_A100:
|
||||
SPLIT_N=1;TILE_M=32;TILE_N=128;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702
|
||||
elif (Ms, Ks) == (64, 64):
|
||||
if is_A100:
|
||||
SPLIT_N=4;TILE_M=64;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702
|
||||
elif (Ms, Ks) == (128, 128):
|
||||
if is_A100:
|
||||
SPLIT_N=4;TILE_M=128;TILE_N=128;GROUP_SIZE=2;num_stages=3;num_warps=8 # noqa: E225,E231,E702
|
||||
elif (Ms, Ks) == (256, 256):
|
||||
if is_A100:
|
||||
SPLIT_N=8;TILE_M=256;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=16 # noqa: E225,E231,E702
|
||||
elif (Ms, Ks) == (512, 512):
|
||||
if is_A100:
|
||||
SPLIT_N=1;TILE_M=128;TILE_N=32;GROUP_SIZE=2;num_stages=1;num_warps=8 # noqa: E225,E231,E702
|
||||
elif (M, K, N) == (16384,) * 3:
|
||||
if (Ms, Ks) == (16, 16):
|
||||
if is_A100:
|
||||
SPLIT_N=1;TILE_M=16;TILE_N=256;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702
|
||||
elif (Ms, Ks) == (32, 32):
|
||||
if is_A100:
|
||||
SPLIT_N=2;TILE_M=32;TILE_N=128;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702
|
||||
|
||||
if SPLIT_N is None:
|
||||
# Assume NVIDIA GeForce RTX 2060 SUPER:
|
||||
# With the probality of 92% (99.9% when N > 512), the
|
||||
# performance will not be worse more than 2% from the
|
||||
# performance when using an optimal value. Otherwise, when N
|
||||
# <= 512, using the following heuristics may give upto 15%
|
||||
# lower performance.
|
||||
SPLIT_N = {16: 1, 32: 2, 64: 4, 128: 8, 256: 16, 512: 8, 1024: 16, 4096: 32, 8192: 64}.get(N, 16)
|
||||
if Ms >= 512 and N >= 2048:
|
||||
SPLIT_N = 1
|
||||
Ns = N // SPLIT_N
|
||||
if TILE_M is None:
|
||||
TILE_M = min(64 if Ns < 512 else 32, Ms)
|
||||
if TILE_N is None:
|
||||
TILE_N = min(64 if Ns < 512 else 32, Ns)
|
||||
num_stages = num_stages or 1
|
||||
if num_warps is None:
|
||||
if min(M, N) > 1024:
|
||||
num_warps = {16: 1, 32: 1, 64: 2}.get(Ms, 4)
|
||||
elif min(M, N) == 1024:
|
||||
num_warps = {16: 1, 32: 1, 64: 2}.get(Ms, 4)
|
||||
elif min(M, N) == 256:
|
||||
num_warps = {16: 1, 32: 4}.get(Ms, 4)
|
||||
else:
|
||||
num_warps = {16: 1, 32: 2}.get(Ms, 4)
|
||||
GROUP_SIZE = GROUP_SIZE or 4
|
||||
|
||||
assert TILE_M <= Ms, dict(TILE_M=TILE_M, Ms=Ms)
|
||||
assert TILE_N <= Ns, dict(TILE_B=TILE_N, Ns=Ns)
|
||||
assert Ms <= M, dict(M=M, Ms=Ms)
|
||||
assert Ns <= N, dict(N=N, Ns=Ns)
|
||||
assert Ks <= K, dict(K=K, Ks=Ks)
|
||||
|
||||
return dict(TILE_M=TILE_M, TILE_N=TILE_N, GROUP_SIZE=GROUP_SIZE,
|
||||
num_stages=num_stages, num_warps=num_warps, SPLIT_N=SPLIT_N, **extra)
|
||||
|
||||
|
||||
def bsr_scatter_mm_indices_data(bsr, other, indices_format='bsr_strided_mm_compressed', **meta_input):
|
||||
"""Computes indices data for :func:`scatter_mm` used in BSR and
|
||||
strided tensor matrix multiplication.
|
||||
"""
|
||||
assert bsr.dense_dim() == 0
|
||||
assert bsr.ndim == 2 # no batch dims
|
||||
crow_indices = bsr.crow_indices()
|
||||
col_indices = bsr.col_indices()
|
||||
blocksize = bsr.values().shape[-2:]
|
||||
M, K = bsr.shape
|
||||
Ms, Ks = blocksize
|
||||
K_, N = other.shape
|
||||
assert K_ == K
|
||||
|
||||
meta = scatter_mm_meta(M, K, N, Ms, Ks, **meta_input)
|
||||
if 'allow_tf32' not in meta_input:
|
||||
meta.update(allow_tf32=bsr.dtype in {torch.float16, torch.bfloat16})
|
||||
|
||||
if indices_format == 'bsr_strided_mm_compressed':
|
||||
meta.update(is_compressed=True)
|
||||
SPLIT_N = meta['SPLIT_N']
|
||||
Ns = N // SPLIT_N
|
||||
q_offsets_lst = []
|
||||
b = torch.arange(SPLIT_N, dtype=torch.int32, device=bsr.device) * Ns
|
||||
for m in range(M // Ms):
|
||||
r0 = crow_indices[m].item()
|
||||
r1 = crow_indices[m + 1].item()
|
||||
if r1 == r0:
|
||||
continue
|
||||
q_offsets_lst.append((col_indices[r0:r1] * (Ks * N)).repeat(SPLIT_N) + b.repeat_interleave(r1 - r0))
|
||||
q_offsets = torch.cat(q_offsets_lst)
|
||||
crow_indices_diff = crow_indices.diff()
|
||||
non_zero_row_indices = crow_indices_diff.nonzero()
|
||||
a = non_zero_row_indices * (Ms * N)
|
||||
r_offsets = (a + b).view(-1)
|
||||
c_indices = crow_indices
|
||||
# swizzle operation: mm elements with longer sums are computed first:
|
||||
nnz_per_row = crow_indices_diff[non_zero_row_indices].repeat_interleave(SPLIT_N)
|
||||
nnz_per_row, indices = nnz_per_row.sort(descending=True, stable=True)
|
||||
r_offsets = r_offsets[indices]
|
||||
return (indices_format, c_indices, r_offsets, q_offsets, meta)
|
||||
elif indices_format == 'bsr_strided_mm':
|
||||
meta.update(is_compressed=False)
|
||||
SPLIT_N = meta['SPLIT_N']
|
||||
Ns = N // SPLIT_N
|
||||
p_offsets_lst = []
|
||||
q_offsets_lst = []
|
||||
b = torch.arange(SPLIT_N, dtype=torch.int32, device=bsr.device) * Ns
|
||||
for m in range(M // Ms):
|
||||
r0 = crow_indices[m].item()
|
||||
r1 = crow_indices[m + 1].item()
|
||||
if r1 == r0:
|
||||
continue
|
||||
p_offsets_lst.append(torch.arange(r0, r1, dtype=torch.int32, device=bsr.device).repeat(SPLIT_N))
|
||||
q_offsets_lst.append((col_indices[r0:r1] * (Ks * N)).repeat(SPLIT_N) + b.repeat_interleave(r1 - r0))
|
||||
q_offsets = torch.cat(q_offsets_lst)
|
||||
crow_indices_diff = crow_indices.diff()
|
||||
non_zero_row_indices = crow_indices_diff.nonzero()
|
||||
a = non_zero_row_indices * (Ms * N)
|
||||
r_offsets = (a + b).view(-1)
|
||||
c_indices = torch.cat((crow_indices[:1],
|
||||
torch.cumsum(crow_indices_diff[non_zero_row_indices].repeat_interleave(SPLIT_N), 0)))
|
||||
p_offsets = torch.cat(p_offsets_lst)
|
||||
return (indices_format, c_indices, r_offsets, p_offsets, q_offsets, meta)
|
||||
|
||||
elif indices_format == 'scatter_mm':
|
||||
Ns = Ms
|
||||
c_indices = [0]
|
||||
pq_offsets = []
|
||||
# todo: eliminate inner for-loops for efficiency
|
||||
for m in range(M // Ms):
|
||||
r0 = crow_indices[m].item()
|
||||
r1 = crow_indices[m + 1].item()
|
||||
for n in range(N // Ns):
|
||||
c_indices.append(c_indices[-1] + r1 - r0)
|
||||
for t in range(r1 - r0):
|
||||
p = r0 + t
|
||||
q = col_indices[p].item() * (N // Ns) + n
|
||||
pq_offsets.append([p, q])
|
||||
|
||||
return (indices_format,
|
||||
torch.tensor(c_indices, dtype=torch.int32, device=crow_indices.device),
|
||||
torch.tensor(pq_offsets, dtype=torch.int32, device=crow_indices.device))
|
||||
|
||||
raise NotImplementedError(indices_format)
|
||||
|
||||
|
||||
def bsr_scatter_mm(bsr, other, indices_data=None):
|
||||
"""BSR @ strided -> strided
|
||||
"""
|
||||
|
||||
assert bsr.ndim == 2
|
||||
assert other.ndim == 2
|
||||
|
||||
Ms, Ks, Ns = bsr.shape[-2], bsr.shape[-1], other.shape[1]
|
||||
blocksize = bsr.values().shape[-2:]
|
||||
|
||||
if indices_data is None:
|
||||
indices_data = bsr_scatter_mm_indices_data(bsr, other, indices_format='bsr_strided_mm_compressed')
|
||||
|
||||
indices_format = indices_data[0]
|
||||
|
||||
if bsr._nnz() == 0:
|
||||
result = torch.zeros((Ms, Ns), dtype=bsr.dtype, device=bsr.device)
|
||||
elif indices_format in {'bsr_strided_mm_compressed', 'bsr_strided_mm'}:
|
||||
result = torch.zeros((Ms, Ns), dtype=bsr.dtype, device=bsr.device)
|
||||
scatter_mm(bsr.values(), other, indices_data, accumulators=result)
|
||||
elif indices_format == 'scatter_mm':
|
||||
accumulators = torch.zeros((Ms // blocksize[0] * Ns // blocksize[0], blocksize[0], blocksize[0]),
|
||||
dtype=bsr.dtype, device=bsr.device)
|
||||
|
||||
others = (other.transpose(0, 1)
|
||||
.view(Ns // blocksize[0], blocksize[0], Ks // blocksize[1], blocksize[1])
|
||||
.movedim((2, 0, 3, 1), (0, 1, 2, 3)) # equivalent to .transpose(1, 2).transpose(2, 3).transpose(0, 1)
|
||||
.flatten(0, 1)
|
||||
)
|
||||
|
||||
scatter_mm(bsr.values(), others, indices_data, accumulators=accumulators)
|
||||
|
||||
result = (accumulators
|
||||
.unflatten(0, (Ms // blocksize[0], Ns // blocksize[0]))
|
||||
.movedim((0, 1, 2, 3), (2, 0, 3, 1)) # equivalent to .transpose(0, 1).transpose(2, 3).transpose(1, 2)
|
||||
.reshape(Ns, Ms)
|
||||
.transpose(0, 1))
|
||||
else:
|
||||
raise NotImplementedError(indices_format)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
if has_triton():
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
|
@ -887,8 +1371,217 @@ if has_triton():
|
|||
torch.nn.functional.dropout(sdpa.values(), p=dropout_p, inplace=True)
|
||||
sdpa = bsr_dense_mm(sdpa, value)
|
||||
return sdpa
|
||||
|
||||
@triton.jit
|
||||
def _scatter_mm2_kernel(
|
||||
M: tl.constexpr, K: tl.constexpr, N: tl.constexpr,
|
||||
blocks_ptr, blocks_stride_P, blocks_stride_M, blocks_stride_K,
|
||||
others_ptr, others_stride_Q, others_stride_K, others_stride_N,
|
||||
accumulators_ptr, accumulators_stride_R, accumulators_stride_M, accumulators_stride_N,
|
||||
pq_offsets_ptr, pq_offsets_stride,
|
||||
pq_ptr, pq_stride_T, pq_stride_1,
|
||||
dot_out_dtype: tl.constexpr,
|
||||
TILE_M: tl.constexpr,
|
||||
TILE_N: tl.constexpr,
|
||||
allow_tf32: tl.constexpr):
|
||||
|
||||
Ms = M // TILE_M
|
||||
Ns = N // TILE_N
|
||||
|
||||
pid_t = tl.program_id(axis=0)
|
||||
|
||||
pid = tl.program_id(axis=1)
|
||||
pid_m = pid // Ms
|
||||
pid_n = pid % Ms
|
||||
|
||||
rm = (pid_m * TILE_M + tl.arange(0, TILE_M))
|
||||
rn = (pid_n * TILE_N + tl.arange(0, TILE_N))
|
||||
rk = tl.arange(0, K)
|
||||
|
||||
A_ptr = blocks_ptr + (rm[:, None] * blocks_stride_M + rk[None, :] * blocks_stride_K)
|
||||
B_ptr = others_ptr + (rk[:, None] * others_stride_K + rn[None, :] * others_stride_N)
|
||||
|
||||
g0 = tl.load(pq_offsets_ptr + pid_t * pq_offsets_stride)
|
||||
g1 = tl.load(pq_offsets_ptr + (pid_t + 1) * pq_offsets_stride)
|
||||
|
||||
if g0 == g1:
|
||||
return
|
||||
|
||||
acc_block = tl.zeros((TILE_M, TILE_N), dtype=dot_out_dtype)
|
||||
|
||||
for i in range(g0, g1):
|
||||
p = tl.load(pq_ptr + i * pq_stride_T)
|
||||
q = tl.load(pq_ptr + i * pq_stride_T + pq_stride_1)
|
||||
A = tl.load(A_ptr + p * blocks_stride_P)
|
||||
B = tl.load(B_ptr + q * others_stride_Q)
|
||||
acc_block += tl.dot(A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
|
||||
|
||||
C_ptr = accumulators_ptr + pid_t * accumulators_stride_R + (
|
||||
rm[:, None] * accumulators_stride_M + rn[None, :] * accumulators_stride_N)
|
||||
tl.store(C_ptr, acc_block.to(accumulators_ptr.dtype.element_ty))
|
||||
|
||||
def _scatter_mm2(
|
||||
blocks: torch.Tensor,
|
||||
others: torch.Tensor,
|
||||
pq_offsets: torch.Tensor,
|
||||
pq_indices: torch.Tensor,
|
||||
accumulators: torch.Tensor
|
||||
):
|
||||
P, M, K = blocks.shape
|
||||
Q, _, N = others.shape
|
||||
R, _, _ = accumulators.shape
|
||||
|
||||
meta = dict(TILE_M=max(16, M // 4), TILE_N=max(16, N // 4), num_stages=1, num_warps=2)
|
||||
|
||||
def grid(META):
|
||||
return (pq_offsets.shape[0] - 1, triton.cdiv(M, META['TILE_M']) * triton.cdiv(N, META['TILE_N']), 1)
|
||||
|
||||
dot_out_dtype = {torch.float16: tl.float32,
|
||||
torch.bfloat16: tl.float32,
|
||||
torch.float32: tl.float64,
|
||||
torch.float64: tl.float64}[accumulators.dtype]
|
||||
if 'allow_tf32' not in meta:
|
||||
meta.update(allow_tf32=dot_out_dtype == tl.float32)
|
||||
_scatter_mm2_kernel[grid](
|
||||
M, K, N,
|
||||
blocks, blocks.stride(0), blocks.stride(1), blocks.stride(2),
|
||||
others, others.stride(0), others.stride(1), others.stride(2),
|
||||
accumulators, accumulators.stride(0), accumulators.stride(1), accumulators.stride(2),
|
||||
pq_offsets, pq_offsets.stride(0),
|
||||
pq_indices, pq_indices.stride(0), pq_indices.stride(1),
|
||||
dot_out_dtype=dot_out_dtype,
|
||||
**meta
|
||||
)
|
||||
|
||||
@triton.jit
|
||||
def _scatter_mm6_kernel(
|
||||
Ms: tl.constexpr, Ks: tl.constexpr, N: tl.constexpr,
|
||||
blocks_ptr, blocks_stride_P, blocks_stride_M, blocks_stride_K,
|
||||
others_ptr, others_stride_K, others_stride_N,
|
||||
accumulators_ptr, accumulators_stride_M, accumulators_stride_N,
|
||||
c_indices_ptr, r_offsets_ptr,
|
||||
p_offsets_ptr, q_offsets_ptr,
|
||||
is_compressed: tl.constexpr,
|
||||
dot_out_dtype: tl.constexpr,
|
||||
SPLIT_N: tl.constexpr,
|
||||
TILE_M: tl.constexpr,
|
||||
TILE_N: tl.constexpr,
|
||||
GROUP_SIZE: tl.constexpr,
|
||||
allow_tf32: tl.constexpr):
|
||||
Ns = N // SPLIT_N
|
||||
BLOCKS_M = Ms // TILE_M
|
||||
BLOCKS_N = Ns // TILE_N
|
||||
|
||||
pid_t = tl.program_id(axis=0)
|
||||
pid = tl.program_id(axis=1)
|
||||
|
||||
num_pid_in_group = GROUP_SIZE * BLOCKS_N
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE
|
||||
group_size_m = min(BLOCKS_M - first_pid_m, GROUP_SIZE)
|
||||
pid_m = first_pid_m + (pid % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
rm = (pid_m * TILE_M + tl.arange(0, TILE_M))
|
||||
rn = (pid_n * TILE_N + tl.arange(0, TILE_N))
|
||||
rk = tl.arange(0, Ks)
|
||||
A_ptr = blocks_ptr + (rm[:, None] * blocks_stride_M + rk[None, :] * blocks_stride_K)
|
||||
B_ptr = others_ptr + (rk[:, None] * others_stride_K + rn[None, :] * others_stride_N)
|
||||
|
||||
# When is_compressed is True, r is the only variable that
|
||||
# depends on pid_t. This property allows sorting r values
|
||||
# before calling the kernel. The sorting of r is equivalent to
|
||||
# defining swizzle operator outside of the kernel.
|
||||
r = tl.load(r_offsets_ptr + pid_t)
|
||||
|
||||
if is_compressed:
|
||||
m = (r // N) // Ms
|
||||
n = (r % N) // Ns
|
||||
r0 = tl.load(c_indices_ptr + m)
|
||||
r1 = tl.load(c_indices_ptr + m + 1)
|
||||
g0 = n * r1 + (SPLIT_N - n) * r0
|
||||
nnz = r1 - r0
|
||||
else:
|
||||
g0 = tl.load(c_indices_ptr + pid_t)
|
||||
g1 = tl.load(c_indices_ptr + pid_t + 1)
|
||||
nnz = g1 - g0
|
||||
|
||||
q_ptr = q_offsets_ptr + g0
|
||||
acc_block = tl.zeros((TILE_M, TILE_N), dtype=dot_out_dtype)
|
||||
|
||||
if is_compressed:
|
||||
A_ptr += r0 * blocks_stride_P
|
||||
for _ in range(nnz):
|
||||
q = tl.load(q_ptr)
|
||||
B = tl.load(B_ptr + q)
|
||||
A = tl.load(A_ptr)
|
||||
acc_block += tl.dot(A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
|
||||
A_ptr += blocks_stride_P
|
||||
q_ptr += 1
|
||||
else:
|
||||
p_ptr = p_offsets_ptr + g0
|
||||
for _ in range(nnz):
|
||||
q = tl.load(q_ptr)
|
||||
B = tl.load(B_ptr + q)
|
||||
p = tl.load(p_ptr)
|
||||
A = tl.load(A_ptr + p * blocks_stride_P)
|
||||
p_ptr += 1
|
||||
q_ptr += 1
|
||||
acc_block += tl.dot(A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
|
||||
|
||||
C_ptr = accumulators_ptr + r + (
|
||||
rm[:, None] * accumulators_stride_M + rn[None, :] * accumulators_stride_N)
|
||||
tl.store(C_ptr, acc_block.to(accumulators_ptr.dtype.element_ty))
|
||||
|
||||
def _scatter_mm6(
|
||||
blocks: torch.Tensor,
|
||||
others: torch.Tensor,
|
||||
c_indices: torch.Tensor,
|
||||
r_offsets: torch.Tensor,
|
||||
p_offsets: torch.Tensor,
|
||||
q_offsets: torch.Tensor,
|
||||
meta: dict,
|
||||
accumulators: torch.Tensor
|
||||
):
|
||||
SPLIT_N = meta['SPLIT_N']
|
||||
P, Ms, Ks = blocks.shape
|
||||
K_, N = others.shape
|
||||
M, N_ = accumulators.shape
|
||||
assert N_ == N
|
||||
Ns = N // SPLIT_N
|
||||
|
||||
def grid(META):
|
||||
return (r_offsets.shape[0], triton.cdiv(Ms, META['TILE_M']) * triton.cdiv(Ns, META['TILE_N']))
|
||||
|
||||
dot_out_dtype = {torch.float16: tl.float32,
|
||||
torch.bfloat16: tl.float32,
|
||||
torch.float32: tl.float64,
|
||||
torch.float64: tl.float64}[accumulators.dtype]
|
||||
if 'allow_tf32' not in meta:
|
||||
meta.update(allow_tf32=dot_out_dtype == tl.float32)
|
||||
|
||||
assert c_indices.stride(0) == 1
|
||||
assert r_offsets.stride(0) == 1
|
||||
assert p_offsets.stride(0) == 1
|
||||
assert q_offsets.stride(0) == 1
|
||||
|
||||
_scatter_mm6_kernel[grid](
|
||||
Ms, Ks, N,
|
||||
blocks, blocks.stride(0), blocks.stride(1), blocks.stride(2),
|
||||
others, others.stride(0), others.stride(1),
|
||||
accumulators, accumulators.stride(0), accumulators.stride(1),
|
||||
c_indices,
|
||||
r_offsets,
|
||||
p_offsets,
|
||||
q_offsets,
|
||||
dot_out_dtype=dot_out_dtype,
|
||||
**meta
|
||||
)
|
||||
|
||||
else:
|
||||
bsr_softmax = None # type: ignore[assignment]
|
||||
bsr_dense_mm = None # type: ignore[assignment]
|
||||
sampled_addmm = None # type: ignore[assignment]
|
||||
_scaled_dot_product_attention = None # type: ignore[assignment]
|
||||
_scatter_mm2 = None # type: ignore[assignment]
|
||||
_scatter_mm6 = None # type: ignore[assignment]
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user