[CUDA][cuBLASLt] addmm -- extend bias fusions to cases with (1 by n) shapes (#166307)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166307
Approved by: https://github.com/eqy
This commit is contained in:
Nikita Vedeneev 2025-10-31 10:18:28 +00:00 committed by PyTorch MergeBot
parent 160ab53dd5
commit 034e951b0c
3 changed files with 46 additions and 10 deletions

View File

@ -170,10 +170,14 @@ static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const
#if defined(CUDA_VERSION) || defined(USE_ROCM) #if defined(CUDA_VERSION) || defined(USE_ROCM)
const auto scalar_type = mat1.scalar_type(); const auto scalar_type = mat1.scalar_type();
return (beta.toComplexDouble() == 1.0 return (beta.toComplexDouble() == 1.0
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
// is to use lt interface only when self is bias.
&& self.dim() == 1 && self.sizes()[0] == mat2_sizes[1] && self.is_contiguous()
&& result.dim() == 2 && result.is_contiguous() && result.dim() == 2 && result.is_contiguous()
// Conditions for bias to be fusable
&& (
self.is_contiguous() &&
// NOTE: fine to have 1-len dims to the left from the right-most one
(self.dim() == 1 || self.squeeze().dim() == 1) &&
self.sizes().back() == mat2_sizes[1]
)
&& ( // some dtype restrictions && ( // some dtype restrictions
#ifndef USE_ROCM #ifndef USE_ROCM
scalar_type == at::ScalarType::Double || scalar_type == at::ScalarType::Double ||

View File

@ -7328,9 +7328,11 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
m2 = torch.randn(50, 25, device=device).to(dtype) m2 = torch.randn(50, 25, device=device).to(dtype)
self._test_addmm_addmv(func, M, m1, m2, activation=activation) self._test_addmm_addmv(func, M, m1, m2, activation=activation)
# vector-shaped bias and beta=1 result in epilogue fusion in CUDA # vector-shaped bias (or with 1-len dims on the left from the leading dim)
# and beta=1 result in epilogue fusion in CUDA
V = torch.randn(25, device=device).to(dtype) V = torch.randn(25, device=device).to(dtype)
self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation) self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation)
self._test_addmm_addmv(func, V.unsqueeze(0), m1, m2, beta=1, activation=activation)
# Test 0-strided # Test 0-strided
M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25) M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25)
@ -7357,8 +7359,9 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
self._test_addmm_addmv(func, M, m1, m2, transpose_out=t4, activation=activation) self._test_addmm_addmv(func, M, m1, m2, transpose_out=t4, activation=activation)
if t1: if t1:
# use vector V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1) # use vector/(1 by k)-shaped V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1)
self._test_addmm_addmv(func, V, m1, m2, beta=1, transpose_out=t4, activation=activation,) self._test_addmm_addmv(func, V, m1, m2, beta=1, transpose_out=t4, activation=activation,)
self._test_addmm_addmv(func, V.unsqueeze(0), m1, m2, beta=1, transpose_out=t4, activation=activation,)
@precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6, @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8}) torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})

View File

@ -5,6 +5,7 @@ import time
import unittest import unittest
from itertools import product from itertools import product
from functools import partial from functools import partial
from typing import Callable
import torch import torch
@ -90,14 +91,21 @@ class TestMatmulCuda(InductorTestCase):
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
super().tearDown() super().tearDown()
def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False, fp16_accumulate: bool = False): def cublas_addmm(
self,
size: int,
dtype: torch.dtype,
reduced_precision: bool = False,
fp16_accumulate: bool = False,
bias_shape_modifier: Callable | None = None,
):
# #
# Check for catastrophic cuBLAS inaccuracy by measuring the deviation between # Check for catastrophic cuBLAS inaccuracy by measuring the deviation between
# results from the CUDA invocation of torch.addmm and the CPU invocation # results from the CUDA invocation of torch.addmm and the CPU invocation
# (which does not use CUDA backend). # (which does not use CUDA backend).
# #
# Get dims # Get dims
n, m, p = (size + 1, size, size + 2) m, k, n = (size + 1, size, size + 2)
# Disable reduced precision reductions in BFloat16 to bypass some kernels # Disable reduced precision reductions in BFloat16 to bypass some kernels
# which fail the threshold check # which fail the threshold check
orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
@ -109,10 +117,12 @@ class TestMatmulCuda(InductorTestCase):
# Make random tensors on CPU (seed set on common_utils.py import) # Make random tensors on CPU (seed set on common_utils.py import)
# (Not using numpy because it does not support bfloat16) # (Not using numpy because it does not support bfloat16)
make_arg = partial(make_tensor, dtype=dtype, device="cpu") make_arg = partial(make_tensor, dtype=dtype, device="cpu")
bias_shape_modifier = (lambda shape: shape) if bias_shape_modifier is None else bias_shape_modifier
m_input = make_arg(bias_shape_modifier((m, n)))
m_1 = make_arg((m, k))
m_2 = make_arg((k, n))
m_beta = make_arg(1) m_beta = make_arg(1)
m_input = make_arg((n, p))
m_1 = make_arg((n, m))
m_2 = make_arg((m, p))
# scale to abate overflows in fp16 accum # scale to abate overflows in fp16 accum
if fp16_accumulate: if fp16_accumulate:
m_1 = m_1 / 100 m_1 = m_1 / 100
@ -179,6 +189,25 @@ class TestMatmulCuda(InductorTestCase):
with blas_library_context(backend): with blas_library_context(backend):
self.cublas_addmm(size, dtype, True) self.cublas_addmm(size, dtype, True)
@onlyCUDA
# imported 'tol' as 'xtol' to avoid aliasing in code above
@toleranceOverride({torch.float16: xtol(atol=1e-3, rtol=1e-4),
torch.bfloat16: xtol(atol=1e-3, rtol=1e-4),
torch.float32: xtol(atol=1e-3, rtol=1e-4)})
@dtypes(torch.bfloat16, torch.float16, torch.float32)
@parametrize("size", [128])
@parametrize("backend", ["cublas", "cublaslt"])
def test_cublas_addmm_bias_shapes(self, size: int, dtype: torch.dtype, backend):
with blas_library_context(backend):
# 2D bias
self.cublas_addmm(size, dtype, bias_shape_modifier=lambda shape: shape)
# 1D bias which is row-broadcast to 2D
self.cublas_addmm(size, dtype, bias_shape_modifier=lambda shape: (1, shape[-1]))
# 1D bias which row-broadcasts
self.cublas_addmm(size, dtype, bias_shape_modifier=lambda shape: (shape[-1],))
@onlyCUDA @onlyCUDA
@dtypes(torch.float16) @dtypes(torch.float16)
# m == 4 chooses OUTPUT_TYPE reduction on H200 # m == 4 chooses OUTPUT_TYPE reduction on H200