import math
import torch

from torch.utils._triton import has_triton


def check(cond, msg):
    if not cond:
        raise ValueError(msg)


def check_bsr_layout(f_name, t):
    check(
        t.layout == torch.sparse_bsr,
        f"{f_name}(): only BSR sparse format is supported for the sparse argument.",
    )


def check_device(f_name, t, device):
    check(
        t.device == device and t.device.type == "cuda",
        f"{f_name}(): all inputs are expected to be on the same GPU device.",
    )


def check_mm_compatible_shapes(f_name, lhs, rhs):
    check(
        lhs.dim() >= 2 and rhs.dim() >= 2,
        f"{f_name}(): all inputs involved in the matrix product are expected to be at least 2D, "
        f"but got lhs.dim() == {lhs.dim()} and rhs.dim() == {rhs.dim()}."
    )

    m, kl = lhs.shape[-2:]
    kr, n = rhs.shape[-2:]

    check(
        kl == kr,
        f"{f_name}(): arguments' sizes involved in the matrix product are not compatible for matrix multiplication, "
        f"got lhs.shape[-1] == {kl} which is not equal to rhs.shape[-2] == {kr}.",
    )


def check_dtype(f_name, t, dtype, *additional_dtypes):
    check(
        t.dtype == dtype
        and t.dtype in ((torch.half, torch.bfloat16, torch.float) + tuple(additional_dtypes)),
        f"{f_name}(): all inputs are expected to be of the same dtype "
        f"and one of (half, bfloat16, float32) or {additional_dtypes}, "
        f"but got dtype == {t.dtype}.",
    )


def check_blocksize(f_name, blocksize):
    assert len(blocksize) == 2

    def is_power_of_two(v):
        return not (v & (v - 1))

    def is_compatible_blocksize(b):
        res = True
        for blocksize in b:
            # Triton loads only blocks which are at least 16 and powers of 2.
            res = (blocksize >= 16 and is_power_of_two(blocksize)) and res
        return res

    check(
        is_compatible_blocksize(blocksize),
        f"{f_name}(): sparse inputs' blocksize ({blocksize[0]}, {blocksize[1]}) "
        "should be at least 16 and a power of 2 in each dimension.",
    )


def make_triton_contiguous(t):
    if (t.stride(-2) > 1 or t.dtype is torch.float32) and t.stride(-1) > 1:
        return t.contiguous()
    else:
        return t


def broadcast_batch_dims(f_name, *tensors):
    try:
        return torch.broadcast_shapes(*(t.shape[:-2] for t in tensors))
    except Exception:
        check(False, f"{f_name}(): inputs' batch dimensions are not broadcastable!")


def slicer(dim, slice_range, *tensors):
    for t in tensors:
        slices = [slice(None)] * t.dim()
        slices[dim] = slice_range
        yield t[slices]


def multidim_slicer(dims, slices, *tensors):
    for t in tensors:
        s = [slice(None)] * t.dim()
        for d, d_slice in zip(dims, slices):
            if d is not None:
                s[d] = d_slice
        yield t[s]


def ptr_stride_extractor(*tensors):
    for t in tensors:
        yield t
        yield from t.stride()


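# grid_partitioner chops an oversized launch grid into CUDA-launchable chunks and
# yields, for every chunk, the (reversed) grid together with the matching slices of
# each tensor in tensor_dims_map.  Illustrative example (hypothetical sizes): with
# full_grid = (4, 100_000) and grid_blocks = (4, 65_535) it yields two launches,
# one with grid (65535, 4) and one with grid (34465, 4), each paired with the
# correspondingly sliced tensors.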
def grid_partitioner(full_grid, grid_blocks, tensor_dims_map):
    assert 0 <= len(full_grid) <= 3
    assert 0 <= len(grid_blocks) <= 3

    import itertools

    def generate_grid_points():
        for fg, mg in zip(full_grid, grid_blocks):
            yield range(0, fg, mg)

    def generate_sliced_tensors(slices):
        for t, t_dims in tensor_dims_map.items():
            yield next(multidim_slicer(t_dims, slices, t))

    for grid_point in itertools.product(*generate_grid_points()):
        grid = [min(fg - gp, mg) for fg, gp, mg in zip(full_grid, grid_point, grid_blocks)]
        slices = [slice(gp, gp + g) for gp, g in zip(grid_point, grid)]
        # grid_points are iterated in a "contiguous" order, i.e.
        # left dimensions traversed slower than right dimensions.
        # This order is reversed for CUDA grids.
        yield grid[::-1], *generate_sliced_tensors(slices)


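# launch_kernel wraps grid_partitioner: full_grid is given from the slowest- to the
# fastest-varying dimension, grid_blocks (if provided) is clamped to the CUDA grid
# size limits, and `kernel` is invoked once per chunk with the sliced tensors.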
def launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks=None):
    # cuda_max_grid = (2 ** 31 - 1, 2 ** 16 - 1, 2 ** 16 - 1)
    cuda_max_grid = (2147483647, 65535, 65535)[::-1]
    if grid_blocks is None:
        grid_blocks = cuda_max_grid
    else:

        def valid_grid_dim(g, mg):
            if g is None:
                return mg
            else:
                # grid must be at least 1 and no greater than mg
                return max(1, min(g, mg))

        grid_blocks = tuple(
            valid_grid_dim(g, mg) for g, mg in zip(grid_blocks, cuda_max_grid)
        )  # type: ignore[assignment]

    for grid, *sliced_tensors in grid_partitioner(full_grid, grid_blocks, tensor_dims_map):
        kernel(grid, *sliced_tensors)


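# prepare_inputs broadcasts the batch dimensions of the BSR tensor and the dense
# tensors against each other and squashes them into a single leading batch dimension.
# Illustrative example (hypothetical sizes): an unbatched (m, n) BSR tensor combined
# with a dense tensor of shape (B, m, n) comes out as crow_indices/col_indices with a
# leading batch dimension of size B, values of shape (B, nnz, row_block, col_block)
# and a dense tensor of shape (B, m, n).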
def prepare_inputs(bsr, *dense_tensors):
    # Introduce fake batch dimension if not present for convenience.
    crow_indices = bsr.crow_indices().unsqueeze(0)
    col_indices = bsr.col_indices().unsqueeze(0)
    values = make_triton_contiguous(bsr.values().unsqueeze(0))
    tensors = [make_triton_contiguous(t.unsqueeze(0)) for t in dense_tensors]

    # Compute broadcasted batch dimension
    batch_dims_broadcasted = torch.broadcast_shapes(values.shape[:-3], *(t.shape[:-2] for t in tensors))

    # Broadcast batch dimensions and squash.
    # The result can be either a view or a copy.
    def batch_broadcast_and_squash(t, batch_dims, invariant_dims):
        return t.broadcast_to(batch_dims + invariant_dims).flatten(
            0, len(batch_dims) - 1
        )

    crow_indices = batch_broadcast_and_squash(
        crow_indices, batch_dims_broadcasted, (-1,)
    )

    col_indices = batch_broadcast_and_squash(
        col_indices, batch_dims_broadcasted, (-1,)
    )
    values = batch_broadcast_and_squash(
        values, batch_dims_broadcasted, values.shape[-3:]
    )
    tensors = [
        batch_broadcast_and_squash(t, batch_dims_broadcasted, t.shape[-2:]) for t in tensors
    ]

    return crow_indices, col_indices, values, *tensors


def broadcast_batch_dims_bsr(f_name, bsr, *tensors):
    batch_shape = broadcast_batch_dims(f_name, bsr, *tensors)

    crow_indices = bsr.crow_indices().broadcast_to(batch_shape + (-1,))
    col_indices = bsr.col_indices().broadcast_to(batch_shape + (-1,))
    values = bsr.values().broadcast_to(batch_shape + bsr.values().shape[-3:])
    size = batch_shape + bsr.shape[-2:]
    return torch.sparse_compressed_tensor(crow_indices, col_indices, values, size=size, layout=bsr.layout)


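# Shape example (illustrative): for t of shape (..., M, N) and blocksize (BM, BN),
# tile_to_blocksize returns a view of shape (..., M // BM, N // BN, BM, BN), i.e.
# a grid of (BM, BN) tiles; e.g. a (64, 48) tensor tiled with (16, 16) becomes
# (4, 3, 16, 16).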
# NOTE: this function will ALWAYS create a view
def tile_to_blocksize(t, blocksize):
    *rest, m, n = t.shape
    new_shape = rest + [
        m // blocksize[0],
        blocksize[0],
        n // blocksize[1],
        blocksize[1],
    ]
    # using .view instead of .reshape to ensure that the result is
    # indeed a view:
    return t.view(new_shape).transpose(-3, -2)


if has_triton():
    import triton
    import triton.language as tl
    from typing import Optional, Tuple

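    # _sampled_addmm_kernel: one program instance per (block row, batch).  For every
    # nonzero block in its BSR row it accumulates mat1[row tile, :] @ mat2[:, col tile]
    # over K in TILE_K-wide chunks, then stores alpha * acc (plus beta * the existing
    # block unless IS_BETA_ZERO) back into `values`, i.e. the product is only
    # evaluated at the positions given by the sparsity pattern.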
    @triton.jit
    def _sampled_addmm_kernel(
        alpha,
        beta,
        IS_BETA_ZERO: tl.constexpr,
        BLOCKSIZE_ROW: tl.constexpr,
        BLOCKSIZE_COL: tl.constexpr,
        k,
        TILE_K: tl.constexpr,
        values_ptr,
        values_batch_stride,
        values_nnz_stride,
        values_row_block_stride,
        values_col_block_stride,
        crow_indices_ptr,
        crow_indices_batch_stride,
        crow_indices_stride,
        col_indices_ptr,
        col_indices_batch_stride,
        col_indices_stride,
        mat1_ptr,
        mat1_batch_stride,
        mat1_tiled_row_stride,
        mat1_tiled_col_stride,
        mat1_row_block_stride,
        mat1_col_block_stride,
        mat2_ptr,
        mat2_batch_stride,
        mat2_tiled_row_stride,
        mat2_tiled_col_stride,
        mat2_row_block_stride,
        mat2_col_block_stride,
        acc_dtype: tl.constexpr,
        allow_tf32: tl.constexpr,
    ):
        batch_pid = tl.program_id(axis=1)
        row_block_pid = tl.program_id(axis=0)

        crow_indices_offset_ptr = (
            crow_indices_ptr
            + crow_indices_batch_stride * batch_pid
            + crow_indices_stride * row_block_pid
        )
        nnz_offset = tl.load(crow_indices_offset_ptr)
        nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)

        # Compute nnz for the row with number row_block_pid.
        # If it is zero, skip the row.
        row_nnz = nnz_offset_next - nnz_offset
        if row_nnz == 0:
            return

        row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
        col_block_arange = tl.arange(0, BLOCKSIZE_COL)

        # Pointers are set to the first block of the current row.
        values_block_ptrs = (
            values_ptr
            + values_batch_stride * batch_pid
            + values_nnz_stride * nnz_offset
            + values_row_block_stride * row_block_arange[:, None]
            + values_col_block_stride * col_block_arange[None, :]
        )

        col_index_nnz_ptr = (
            col_indices_ptr
            + col_indices_batch_stride * batch_pid
            + col_indices_stride * nnz_offset
        )

        # Advance mat1 to the current tiled row, ignore columns.
        mat1_block_ptrs = (
            mat1_ptr
            + mat1_batch_stride * batch_pid
            + mat1_tiled_row_stride * row_block_pid
            + mat1_row_block_stride * row_block_arange[:, None]
        )

        # Advance mat2 in batch and block col dimension.
        mat2_block_ptrs = (
            mat2_ptr
            + mat2_batch_stride * batch_pid
            + mat2_col_block_stride * col_block_arange[None, :]
        )

        k_tile_arange = tl.arange(0, TILE_K)
        for _ in range(row_nnz):
            acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)

            # find column block index
            col_block = tl.load(col_index_nnz_ptr)

            for k_tile in range(0, k, TILE_K):
                k_offsets = k_tile + k_tile_arange
                mask_k = k_offsets < k

                mat1_block = tl.load(
                    mat1_block_ptrs
                    + mat1_col_block_stride * k_offsets[None, :],
                    mask=mask_k[None, :], other=0.0
                )

                mat2_block = tl.load(
                    mat2_block_ptrs
                    + mat2_tiled_col_stride * col_block
                    + mat2_row_block_stride * k_offsets[:, None],
                    mask=mask_k[:, None], other=0.0
                )

                acc_block += tl.dot(mat1_block, mat2_block, allow_tf32=allow_tf32, out_dtype=acc_dtype)

            if IS_BETA_ZERO:
                acc_block *= alpha
            else:
                acc_block = alpha * acc_block + beta * tl.load(values_block_ptrs)

            # write result
            tl.store(values_block_ptrs, acc_block.to(values_ptr.dtype.element_ty))

            # advance val/col_index ptrs to the next block in the row.
            values_block_ptrs += values_nnz_stride
            col_index_nnz_ptr += col_indices_stride

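    # _bsr_strided_dense_rowspace_kernel: one program instance per
    # (block row, block column, batch) tile of the output.  It walks the nonzero
    # blocks of the BSR row, multiplies each (BLOCKSIZE_ROW, BLOCKSIZE_COL) block
    # with the dense tile selected by col_indices and accumulates the sum into a
    # (BLOCKSIZE_ROW, BLOCKSIZE_ROW) output tile.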
    @triton.jit
    def _bsr_strided_dense_rowspace_kernel(
        BLOCKSIZE_ROW: tl.constexpr,
        BLOCKSIZE_COL: tl.constexpr,
        # values prologue
        values_ptr,
        values_batch_stride,
        values_nnz_stride,
        values_row_block_stride,
        values_col_block_stride,
        # values epilogue
        # crow_indices prologue
        crow_indices_ptr,
        crow_indices_batch_stride,
        crow_indices_stride,
        # crow_indices epilogue
        # col_indices prologue
        col_indices_ptr,
        col_indices_batch_stride,
        col_indices_stride,
        # col_indices epilogue
        # dense prologue
        dense_ptr,
        dense_batch_stride,
        dense_tiled_row_stride,
        dense_tiled_col_stride,
        dense_row_block_stride,
        dense_col_block_stride,
        # dense epilogue
        # output prologue
        output_ptr,
        output_batch_stride,
        output_tiled_row_stride,
        output_tiled_col_stride,
        output_row_block_stride,
        output_col_block_stride,
        # output epilogue
        acc_dtype: tl.constexpr,
        allow_tf32: tl.constexpr,
        GROUP_SIZE_ROW: tl.constexpr,
    ):
        batch_pid = tl.program_id(axis=2)
        row_block_pid = tl.program_id(axis=0)
        col_block_pid = tl.program_id(axis=1)
        n_block_rows = tl.num_programs(axis=0)
        n_block_cols = tl.num_programs(axis=1)

        row_block_pid, col_block_pid = tl.swizzle2d(
            row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW
        )

        crow_indices_offset_ptr = (
            crow_indices_ptr
            + crow_indices_batch_stride * batch_pid
            + crow_indices_stride * row_block_pid
        )
        nnz_offset = tl.load(crow_indices_offset_ptr)
        nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)

        # Compute nnz for the row with number row_block_pid.
        # If it is zero, skip the row.
        row_nnz = nnz_offset_next - nnz_offset
        if row_nnz == 0:
            return

        row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
        col_block_arange = tl.arange(0, BLOCKSIZE_COL)

        # Pointers are set to the first block of the current row.
        values_block_ptrs = (
            values_ptr
            + values_batch_stride * batch_pid
            + values_nnz_stride * nnz_offset
            + values_row_block_stride * row_block_arange[:, None]
            + values_col_block_stride * col_block_arange[None, :]
        )

        # NOTE: dense is advanced into all dimensions but the tiled row one.
        # That will be advanced in the loop according to values in col_indices.
        dense_block_ptrs = (
            dense_ptr
            + dense_batch_stride * batch_pid
            + dense_tiled_col_stride * col_block_pid
            + dense_row_block_stride * col_block_arange[:, None]
            + dense_col_block_stride * row_block_arange[None, :]
        )

        # Pointers are set to exact write-to locations
        output_ptrs = (
            output_ptr
            + output_batch_stride * batch_pid
            + output_tiled_row_stride * row_block_pid
            + output_tiled_col_stride * col_block_pid
            + output_row_block_stride * row_block_arange[:, None]
            + output_col_block_stride * row_block_arange[None, :]
        )

        # Set pointer to the first nonzero element in the current row
        col_index_nnz_ptr = (
            col_indices_ptr
            + col_indices_batch_stride * batch_pid
            + col_indices_stride * nnz_offset
        )

        output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_ROW), dtype=acc_dtype)
        for _ in range(row_nnz):
            values_block = tl.load(values_block_ptrs)

            # find which row of dense needs to get loaded
            # for multiplication with values_block.
            dense_row_idx = tl.load(col_index_nnz_ptr)
            dense_block = tl.load(dense_block_ptrs + dense_tiled_row_stride * dense_row_idx)

            # do block mm
            output_acc_block += tl.dot(values_block, dense_block, allow_tf32=allow_tf32, out_dtype=acc_dtype)

            # move val/col_index ptrs to the next block in the row
            values_block_ptrs += values_nnz_stride
            col_index_nnz_ptr += col_indices_stride

        # write back the result
        tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))


    def _run_dense_rowspace_kernel(
        blocksize, values, crow_indices, col_indices, dense, output, max_grid
    ):
        n_batches = dense.size(0)
        n_block_rows = crow_indices.size(-1) - 1
        n_block_cols = dense.size(-3)

        full_grid = (n_batches, n_block_cols, n_block_rows)
        if max_grid is not None:
            grid_blocks = tuple(max_grid[:3][::-1]) + (None,) * (3 - len(max_grid[:3]))
        else:
            grid_blocks = None
        tensor_dims_map = {
            values: (0, None, None),
            crow_indices: (0, None, -1),
            col_indices: (0, None, None),
            dense: (0, -3, None),
            output: (0, -3, -4)
        }
        if values.dtype in (torch.half, torch.bfloat16):
            acc_dtype = tl.float32
            allow_tf32 = True
        else:
            acc_dtype = tl.float64
            allow_tf32 = False

        def kernel(grid, *sliced_tensors):
            _bsr_strided_dense_rowspace_kernel[grid](
                *blocksize,
                *ptr_stride_extractor(*sliced_tensors),
                acc_dtype=acc_dtype,
                allow_tf32=allow_tf32,
                GROUP_SIZE_ROW=4,
                num_stages=1,
                num_warps=4
            )

        launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)


    def _run_sampled_addmm_kernel(
        alpha, beta, is_beta_zero,
        blocksize, k, tile_k,
        values, crow_indices, col_indices,
        mat1, mat2,
        max_grid
    ):
        n_batches = values.size(0)
        n_block_rows = crow_indices.size(-1) - 1

        full_grid = (n_batches, n_block_rows)
        if max_grid is not None:
            grid_blocks = tuple(max_grid[:2][::-1]) + (None,) * (2 - len(max_grid[:2]))
        else:
            grid_blocks = None
        tensor_dims_map = {
            values: (0, None),
            crow_indices: (0, -1),
            col_indices: (0, None),
            mat1: (0, -4),
            mat2: (0, None),
        }
        if values.dtype in (torch.half, torch.bfloat16):
            acc_dtype = tl.float32
            allow_tf32 = True
        else:
            acc_dtype = tl.float64
            allow_tf32 = False

        def kernel(grid, *sliced_tensors):
            _sampled_addmm_kernel[grid](
                alpha, beta, is_beta_zero,
                *blocksize, k, tile_k,
                *ptr_stride_extractor(*sliced_tensors),
                acc_dtype=acc_dtype,
                allow_tf32=allow_tf32,
                num_stages=1,
                num_warps=4
            )

        launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)


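    # sampled_addmm computes alpha * (mat1 @ mat2) evaluated only at the sparsity
    # pattern of `input`, plus beta * input, and returns a BSR tensor with the same
    # pattern.  Illustrative usage sketch (hypothetical shapes and dtypes; assumes a
    # CUDA device with Triton available):
    #
    #   mask = torch.eye(64, dtype=torch.half, device="cuda").to_sparse_bsr((16, 16))
    #   mat1 = torch.randn(64, 32, dtype=torch.half, device="cuda")
    #   mat2 = torch.randn(32, 64, dtype=torch.half, device="cuda")
    #   res = sampled_addmm(mask, mat1, mat2, beta=0.0)  # BSR, same pattern as mask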
    def sampled_addmm(
        input: torch.Tensor,
        mat1: torch.Tensor,
        mat2: torch.Tensor,
        *,
        beta=1.0,
        alpha=1.0,
        out: Optional[torch.Tensor] = None,
        skip_checks: bool = False,
        max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
    ):
        f_name = "sampled_addmm"

        check_bsr_layout(f_name, input)
        input_broadcasted = broadcast_batch_dims_bsr(f_name, input, mat1, mat2)

        if not skip_checks:
            check_device(f_name, mat1, input.device)
            check_device(f_name, mat2, input.device)
            if beta != 0.0 and input.dtype is torch.bool:
                check(
                    False,
                    f"{f_name}(): having beta == {beta} not equal to 0.0 with boolean mask is not allowed."
                )
            if input.dtype is not torch.bool:
                check_dtype(f_name, mat1, input.dtype)
                check_dtype(f_name, mat2, input.dtype)
            else:
                check_dtype(f_name, mat1, mat2.dtype)
            check_mm_compatible_shapes(f_name, mat1, mat2)
            if out is not None:
                check_bsr_layout(f_name, out)
                check_device(f_name, out, mat1.device)
                check_dtype(f_name, out, input.dtype)
                check(
                    out.shape == input_broadcasted.shape
                    and out._nnz() == input._nnz(),
                    f"{f_name}(): Expects `out` to be of shape {input_broadcasted.shape} "
                    f"and with nnz equal to {input_broadcasted._nnz()} "
                    f"but got out.shape = {out.shape} and out.nnz = {out._nnz()}"
                )

        if out is None:
            out = input_broadcasted.to(mat1.dtype, copy=True)
        else:
            out.copy_(input_broadcasted)

        if out.numel() == 0 or out._nnz() == 0:
            return out

        blocksize = out.values().shape[-2:]
        m = mat1.size(-2)
        n = mat2.size(-1)
        k = mat1.size(-1)

        # NOTE: (m, 0) @ (0, n) == zeros(m, n)
        if alpha == 0.0 or k == 0:
            out.values().mul_(beta)
            return out

        # prepare inputs by reshaping them to be kernel-compatible
        out_backup = out
        crow_indices, col_indices, values, mat1, mat2 = prepare_inputs(out, mat1, mat2)

        mat1 = tile_to_blocksize(mat1, (blocksize[0], k))
        mat2 = tile_to_blocksize(mat2, (k, blocksize[1]))
        tile_k = max(*blocksize)

        _run_sampled_addmm_kernel(
            alpha, beta, beta == 0.0,
            blocksize, k, tile_k,
            values, crow_indices, col_indices,
            mat1, mat2,
            max_grid
        )

        # If nnz x block strides are not the same in out_backup.values and values,
        # it means that out_backup.values and values are not the views of each other,
        # so we have to copy.
        if out_backup.values().stride()[-3:] != values.stride()[-3:]:
            out_backup.values().copy_(values.reshape(out_backup.values().shape))
        return out_backup


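    # bsr_dense_mm computes bsr @ dense and returns a strided (dense) tensor.
    # Illustrative usage sketch (hypothetical shapes and dtypes; assumes a CUDA
    # device with Triton available):
    #
    #   lhs = torch.randn(64, 64, dtype=torch.half, device="cuda").to_sparse_bsr((16, 16))
    #   rhs = torch.randn(64, 32, dtype=torch.half, device="cuda")
    #   res = bsr_dense_mm(lhs, rhs)  # numerically close to lhs.to_dense() @ rhs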
    def bsr_dense_mm(
        bsr: torch.Tensor,
        dense: torch.Tensor,
        *,
        out: Optional[torch.Tensor] = None,
        skip_checks: bool = False,
        max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
    ):
        f_name = "bsr_dense_mm"
        if not skip_checks:
            check_bsr_layout(f_name, bsr)
            check_device(f_name, bsr, dense.device)
            check_dtype(f_name, bsr, dense.dtype)
            check_mm_compatible_shapes(f_name, bsr, dense)

            m = bsr.size(-2)
            n = dense.size(-1)
            row_block, col_block = bsr.values().shape[-2:]
            check(
                not n % row_block,
                f"bsr_dense_mm(): dense.size(-1) == {n} should be divisible by "
                f"blocksize[0] == {row_block}.",
            )
            check_blocksize(f_name, (row_block, col_block))
        else:
            m, kl = bsr.shape[-2:]
            kr, n = dense.shape[-2:]

        original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense)

        if out is not None and not skip_checks:
            expected_out_shape = original_batch_dims_broadcasted + (m, n)
            check(
                out.shape == expected_out_shape,
                "bsr_dense_mm(): `out` argument has wrong shape, "
                f"expected {expected_out_shape}, but got {out.shape}.",
            )
            check(
                out.is_contiguous() or out.transpose(-2, -1).is_contiguous(),
                "bsr_dense_mm(): only row-major/col-major `out` arguments are supported, "
                "i.e. (out.is_contiguous() or out.transpose(-2, -1).is_contiguous()) "
                "should be True.",
            )

        # Allocate out
        if out is None:
            out = dense.new_empty(original_batch_dims_broadcasted + (m, n))

        # Short circuit if lhs is zero
        if bsr._nnz() == 0:
            return out.zero_()

        blocksize = bsr.values().shape[-2:]

        # NOTE: out is contiguous, so prepare_inputs will create a view.
        # out gets modified in-place, so we store a backup copy.
        out_backup = out

        # prepare inputs by reshaping them to be kernel-compatible.
        crow_indices, col_indices, values, dense, out = prepare_inputs(bsr, dense, out)

        # "Blockify" the row dimension of dense with blocksize[1]
        # since dense is on the rhs of matmul
        dense = tile_to_blocksize(dense, blocksize[::-1])
        # "Blockify" the row dimension of out with blocksize[0]
        # which is inherited from the bsr input.
        # NOTE: tile_to_blocksize will create a view.
        # NOTE: out.blocksize[-1] == dense.blocksize[-1],
        # so it could be any value in [1, dense.shape[-1]).
        # We need to probably use the largest possible blocksize
        # so that it fits into SRAM.
        out = tile_to_blocksize(out, (blocksize[0], blocksize[0]))

        # Launch kernel
        _run_dense_rowspace_kernel(blocksize, values, crow_indices, col_indices, dense, out, max_grid)

        return out_backup


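    # _bsr_softmax_kernel: one program instance per (block row, row within the block,
    # batch).  It performs a numerically stable softmax over all values stored in that
    # row, scanning them in TILE-wide passes: first the row maximum, then the
    # denominator, then the normalized outputs written back in place.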
    @triton.jit
    def _bsr_softmax_kernel(
        crow_indices_ptr,
        crow_indices_batch_stride,
        crow_indices_stride,
        values_ptr,
        values_batch_stride,
        values_row_block_stride,
        values_nnz_col_block_stride,
        row_block, col_block,
        MAX_ROW_NNZ: tl.constexpr,
        TILE: tl.constexpr
    ):
        batch_pid = tl.program_id(axis=2)
        row_block_offset_pid = tl.program_id(axis=1)
        row_block_pid = tl.program_id(axis=0)

        crow_indices_offset_ptr = (
            crow_indices_ptr
            + crow_indices_batch_stride * batch_pid
            + crow_indices_stride * row_block_pid
        )
        nnz_offset = tl.load(crow_indices_offset_ptr)
        nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)

        # Compute nnz for the row with number row_block_pid.
        # If it is zero, skip the row.
        row_nnz = nnz_offset_next - nnz_offset
        if row_nnz == 0:
            return

        row_arange = tl.arange(0, TILE)
        mask = row_arange < row_nnz * col_block

        curr_row_values_ptrs = (
            values_ptr
            + values_batch_stride * batch_pid
            + values_row_block_stride * row_block_offset_pid
            + nnz_offset * col_block
        )

        # find max in the row
        row_tile = tl.load(curr_row_values_ptrs + row_arange, mask=mask, other=-float('inf')).to(tl.float32)
        max_row_value = tl.max(row_tile, axis=0)
        for _ in range(TILE, MAX_ROW_NNZ, TILE):
            row_arange += TILE
            mask = row_arange < row_nnz * col_block
            row_tile = tl.load(curr_row_values_ptrs + row_arange, mask=mask, other=-float('inf')).to(tl.float32)
            curr_max_row_value = tl.max(row_tile, axis=0)
            max_row_value = tl.where(max_row_value > curr_max_row_value, max_row_value, curr_max_row_value)

        # find denominator for stable softmax
        num = tl.exp(row_tile - max_row_value)
        denom = tl.sum(num, axis=0)
        for _ in range(TILE, MAX_ROW_NNZ, TILE):
            row_arange -= TILE
            mask = row_arange < row_nnz * col_block
            row_tile = tl.load(curr_row_values_ptrs + row_arange, mask=mask, other=-float('inf')).to(tl.float32)
            num = tl.exp(row_tile - max_row_value)
            denom += tl.sum(num, axis=0)

        # populate output
        tl.store(curr_row_values_ptrs + row_arange, (num / denom).to(values_ptr.dtype.element_ty), mask=mask)
        for _ in range(TILE, MAX_ROW_NNZ, TILE):
            row_arange += TILE
            mask = row_arange < row_nnz * col_block
            row_tile = tl.load(curr_row_values_ptrs + row_arange, mask=mask, other=-float('inf')).to(tl.float32)
            num = tl.exp(row_tile - max_row_value)
            tl.store(curr_row_values_ptrs + row_arange, (num / denom).to(values_ptr.dtype.element_ty), mask=mask)


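    # bsr_softmax applies a row-wise, numerically stable softmax over the specified
    # values of a BSR tensor (unspecified elements do not participate) and returns a
    # BSR tensor with the same sparsity pattern.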
    def bsr_softmax(input, max_row_nnz=None):
        f_name = "bsr_softmax"

        check_bsr_layout(f_name, input)
        check_dtype(f_name, input, input.dtype)

        if input._nnz() == 0 or input.numel() == 0:
            return input.clone()

        m, n = input.shape[-2:]
        nnz = input._nnz()
        row_block, col_block = input.values().shape[-2:]

        if max_row_nnz is None:
            max_row_nnz = triton.next_power_of_2(n)
        else:
            max_row_nnz = triton.next_power_of_2(max_row_nnz)

        crow_indices = input.crow_indices().unsqueeze(0).flatten(0, -2)
        # reshape values from
        # (b1, ..., bn, nnz, row_block, col_block) to
        # (b1 * ... * bn, row_block, nnz * col_block).
        # This simplifies batch dim manipulation and unlocks
        # the possibility to access all nnzs in any given row.
        if input.values().transpose(-3, -2).is_contiguous():
            # Need to clone to avoid `contiguous` returning a view.
            values = input.values().clone()
        else:
            values = input.values()
        values = values.transpose(-3, -2).contiguous().unsqueeze(0).flatten(0, -4).reshape(-1, row_block, nnz * col_block)
        full_grid = (values.shape[0], row_block, m // row_block)
        grid_blocks = None
        tensor_dims_map = {
            # We span nnz number of blocks, not nnz + 1,
            # hence crow_indices[..., :-1]
            crow_indices[..., :-1]: (0, None, -1),
            values: (0, None, None),
        }

        def kernel(grid, *sliced_tensors):
            _bsr_softmax_kernel[grid](
                *ptr_stride_extractor(*sliced_tensors),
                row_block, col_block,
                max_row_nnz,
                # Triton's max numel is bounded by 2 ** 17.
                min(2 ** 17, max_row_nnz)
            )

        launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)

        values = values.reshape(-1, row_block, nnz, col_block).transpose(-3, -2).reshape(*input.values().shape)

        return torch.sparse_compressed_tensor(
            input.crow_indices().clone(),
            input.col_indices().clone(),
            values,
            size=input.shape,
            layout=input.layout
        )

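    # _scaled_dot_product_attention implements masked attention where attn_mask is a
    # BSR tensor defining the sparsity pattern: softmax(scale * (Q @ K^T) restricted
    # to the pattern) @ V, built from sampled_addmm, bsr_softmax and bsr_dense_mm
    # above.  Illustrative usage sketch (hypothetical shapes and dtypes; assumes a
    # CUDA device with Triton available):
    #
    #   q = torch.randn(1, 4, 64, 32, dtype=torch.half, device="cuda")
    #   k = torch.randn(1, 4, 64, 32, dtype=torch.half, device="cuda")
    #   v = torch.randn(1, 4, 64, 32, dtype=torch.half, device="cuda")
    #   mask = torch.ones(64, 64, dtype=torch.half, device="cuda").to_sparse_bsr((16, 16))
    #   out = _scaled_dot_product_attention(q, k, v, attn_mask=mask)  # (1, 4, 64, 32)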
    def _scaled_dot_product_attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None
    ):
        f_name = "_scaled_dot_product_attention"
        check(
            not is_causal,
            f"{f_name}(): is_causal == True is not supported."
        )
        check(
            attn_mask is not None,
            f"{f_name}(): attn_mask == None is not supported."
        )
        assert attn_mask is not None

        check(
            attn_mask.layout == torch.sparse_bsr,
            f"{f_name}(): "
            f"attn_mask.layout must be {torch.sparse_bsr}, but got "
            f"attn_mask.layout == {attn_mask.layout}."
        )

        check_device(f_name, key, query.device)
        check_device(f_name, value, query.device)
        check_device(f_name, attn_mask, query.device)

        check_dtype(f_name, key, query.dtype)
        check_dtype(f_name, value, query.dtype)
        if attn_mask.dtype is not torch.bool:
            check_dtype(f_name, attn_mask, query.dtype)

        sdpa = sampled_addmm(attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False)
        if scale is None and query.size(-1) == 0 or scale == 0.0:
            check(
                False,
                f"{f_name}(): current value of scale == {scale} "
                "results in division by zero."
            )
        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
        sdpa.values().mul_(scale_factor)
        sdpa = bsr_softmax(sdpa)
        torch.nn.functional.dropout(sdpa.values(), p=dropout_p, inplace=True)
        sdpa = bsr_dense_mm(sdpa, value)
        return sdpa
else:
    bsr_softmax = None  # type: ignore[assignment]
    bsr_dense_mm = None  # type: ignore[assignment]
    sampled_addmm = None  # type: ignore[assignment]
    _scaled_dot_product_attention = None  # type: ignore[assignment]