From 034e951b0cfb02d7b55327cd482e58cf2695dca0 Mon Sep 17 00:00:00 2001
From: Nikita Vedeneev
Date: Fri, 31 Oct 2025 10:18:28 +0000
Subject: [PATCH] [CUDA][cuBLASLt] addmm -- extend bias fusions to cases with
 (1 by n) shapes (#166307)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166307
Approved by: https://github.com/eqy
---
 aten/src/ATen/native/cuda/Blas.cpp | 10 +++++---
 test/test_linalg.py                |  7 ++++--
 test/test_matmul_cuda.py           | 39 ++++++++++++++++++++++++++----
 3 files changed, 46 insertions(+), 10 deletions(-)

diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index f29be23acd5..9c8a3c708ed 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -170,10 +170,14 @@ static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const
 #if defined(CUDA_VERSION) || defined(USE_ROCM)
   const auto scalar_type = mat1.scalar_type();
   return (beta.toComplexDouble() == 1.0
-    // self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
-    // is to use lt interface only when self is bias.
-    && self.dim() == 1 && self.sizes()[0] == mat2_sizes[1] && self.is_contiguous()
     && result.dim() == 2 && result.is_contiguous()
+    // Conditions for bias to be fusable
+    && (
+         self.is_contiguous() &&
+         // NOTE: fine to have 1-len dims to the left of the right-most one
+         (self.dim() == 1 || self.squeeze().dim() == 1) &&
+         self.sizes().back() == mat2_sizes[1]
+       )
     && ( // some dtype restrictions
 #ifndef USE_ROCM
     scalar_type == at::ScalarType::Double ||
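The bias condition above accepts any contiguous bias whose only non-1 dimension is the
trailing one and matches mat2's column count; beta == 1.0, a contiguous 2-D result and the
dtype checks still apply. A rough Python restatement of just the bias part (the helper name
and sizes below are illustrative, not part of the patch):

    import torch

    def bias_is_lt_fusable(bias: torch.Tensor, n: int) -> bool:
        # Mirrors the bias clause of isInputCompliesAddmmCudaLt: contiguous,
        # effectively 1-D (any extra dims may only have length 1), and the
        # trailing dim must equal mat2_sizes[1] (== n here).
        return (
            bias.is_contiguous()
            and (bias.dim() == 1 or bias.squeeze().dim() == 1)
            and bias.shape[-1] == n
        )

    n = 130
    assert bias_is_lt_fusable(torch.randn(n), n)          # (n,): accepted before this change
    assert bias_is_lt_fusable(torch.randn(1, n), n)       # (1, n): newly accepted
    assert not bias_is_lt_fusable(torch.randn(4, n), n)   # full 2-D bias: not fusable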
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 01a6dd5c8ec..032c5264196 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -7328,9 +7328,11 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
         m2 = torch.randn(50, 25, device=device).to(dtype)
         self._test_addmm_addmv(func, M, m1, m2, activation=activation)
 
-        # vector-shaped bias and beta=1 result in epilogue fusion in CUDA
+        # vector-shaped bias (possibly with extra leading 1-len dims)
+        # and beta=1 result in epilogue fusion in CUDA
         V = torch.randn(25, device=device).to(dtype)
         self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation)
+        self._test_addmm_addmv(func, V.unsqueeze(0), m1, m2, beta=1, activation=activation)
 
         # Test 0-strided
         M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25)
@@ -7357,8 +7359,9 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
             self._test_addmm_addmv(func, M, m1, m2, transpose_out=t4, activation=activation)
 
             if t1:
-                # use vector V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1)
+                # use vector/(1 by n)-shaped V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1)
                 self._test_addmm_addmv(func, V, m1, m2, beta=1, transpose_out=t4, activation=activation,)
+                self._test_addmm_addmv(func, V.unsqueeze(0), m1, m2, beta=1, transpose_out=t4, activation=activation,)
 
     @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
                         torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py
index c5ae0dd7242..5e54a851812 100644
--- a/test/test_matmul_cuda.py
+++ b/test/test_matmul_cuda.py
@@ -5,6 +5,7 @@ import time
 import unittest
 from itertools import product
 from functools import partial
+from typing import Callable
 
 import torch
 
@@ -90,14 +91,21 @@ class TestMatmulCuda(InductorTestCase):
         torch.backends.cuda.matmul.allow_tf32 = True
         super().tearDown()
 
-    def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False, fp16_accumulate: bool = False):
+    def cublas_addmm(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        reduced_precision: bool = False,
+        fp16_accumulate: bool = False,
+        bias_shape_modifier: Callable | None = None,
+    ):
         #
         # Check for catastrophic cuBLAS inaccuracy by measuring the deviation between
         # results from the CUDA invocation of torch.addmm and the CPU invocation
         # (which does not use CUDA backend).
         #
         # Get dims
-        n, m, p = (size + 1, size, size + 2)
+        m, k, n = (size + 1, size, size + 2)
         # Disable reduced precision reductions in BFloat16 to bypass some kernels
         # which fail the threshold check
         orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
@@ -109,10 +117,12 @@ class TestMatmulCuda(InductorTestCase):
         # Make random tensors on CPU (seed set on common_utils.py import)
         # (Not using numpy because it does not support bfloat16)
         make_arg = partial(make_tensor, dtype=dtype, device="cpu")
+
+        bias_shape_modifier = (lambda shape: shape) if bias_shape_modifier is None else bias_shape_modifier
+        m_input = make_arg(bias_shape_modifier((m, n)))
+        m_1 = make_arg((m, k))
+        m_2 = make_arg((k, n))
         m_beta = make_arg(1)
-        m_input = make_arg((n, p))
-        m_1 = make_arg((n, m))
-        m_2 = make_arg((m, p))
         # scale to abate overflows in fp16 accum
         if fp16_accumulate:
             m_1 = m_1 / 100
@@ -179,6 +189,25 @@ class TestMatmulCuda(InductorTestCase):
         with blas_library_context(backend):
             self.cublas_addmm(size, dtype, True)
 
+
+    @onlyCUDA
+    # imported 'tol' as 'xtol' to avoid aliasing in code above
+    @toleranceOverride({torch.float16: xtol(atol=1e-3, rtol=1e-4),
+                        torch.bfloat16: xtol(atol=1e-3, rtol=1e-4),
+                        torch.float32: xtol(atol=1e-3, rtol=1e-4)})
+    @dtypes(torch.bfloat16, torch.float16, torch.float32)
+    @parametrize("size", [128])
+    @parametrize("backend", ["cublas", "cublaslt"])
+    def test_cublas_addmm_bias_shapes(self, size: int, dtype: torch.dtype, backend):
+        with blas_library_context(backend):
+            # full (m, n) 2D bias
+            self.cublas_addmm(size, dtype, bias_shape_modifier=lambda shape: shape)
+            # (1, n)-shaped bias which broadcasts over rows
+            self.cublas_addmm(size, dtype, bias_shape_modifier=lambda shape: (1, shape[-1]))
+            # 1D (n,)-shaped bias which broadcasts over rows
+            self.cublas_addmm(size, dtype, bias_shape_modifier=lambda shape: (shape[-1],))
+
+
     @onlyCUDA
     @dtypes(torch.float16)
     # m == 4 chooses OUTPUT_TYPE reduction on H200
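For reference, a minimal usage sketch of the case this patch targets (assumes a CUDA build
and device; sizes, dtype and variable names below are arbitrary, not taken from the patch).
With beta=1, a (n,)-shaped bias already qualified for the cuBLASLt epilogue fusion; a
(1, n)-shaped bias now qualifies as well, and both give the same row-broadcast result:

    import torch

    m, k, n = 129, 128, 130
    mat1 = torch.randn(m, k, device="cuda", dtype=torch.float16)
    mat2 = torch.randn(k, n, device="cuda", dtype=torch.float16)

    bias_1d = torch.randn(n, device="cuda", dtype=torch.float16)  # shape (n,): fused before this patch
    bias_2d = bias_1d.unsqueeze(0)                                 # shape (1, n): eligible after this patch

    # beta=1 plus a row-broadcasting bias is the epilogue-fusion case.
    out_1d = torch.addmm(bias_1d, mat1, mat2, beta=1)
    out_2d = torch.addmm(bias_2d, mat1, mat2, beta=1)
    torch.testing.assert_close(out_1d, out_2d)

Whether the fused path is actually taken still depends on the other checks in
isInputCompliesAddmmCudaLt (dtype, result layout, and the build's cuBLAS/ROCm configuration).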