[CUDA][cuBLASLt] addmm -- extend bias fusions to cases with (1 by n) shapes (#166307)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166307 Approved by: https://github.com/eqy
This commit is contained in: parent 160ab53dd5, commit 034e951b0c
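In effect, with beta == 1 the addmm bias can now be a (1 by n) row as well as a plain length-n vector and still qualify for the cuBLASLt epilogue-fusion path, subject to the dtype and layout checks in the diff below. A minimal sketch of the call pattern (illustrative shapes, dtype, and device only, not taken from this PR):

import torch

mat1 = torch.randn(64, 128, device="cuda", dtype=torch.half)
mat2 = torch.randn(128, 32, device="cuda", dtype=torch.half)
bias_vec = torch.randn(32, device="cuda", dtype=torch.half)   # shape (n,): already eligible before this change
bias_row = bias_vec.unsqueeze(0)                              # shape (1, n): newly eligible for fusion

out_vec = torch.addmm(bias_vec, mat1, mat2, beta=1)
out_row = torch.addmm(bias_row, mat1, mat2, beta=1)
torch.testing.assert_close(out_vec, out_row)  # both broadcast the bias across the rows of mat1 @ mat2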
@@ -170,10 +170,14 @@ static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const
 #if defined(CUDA_VERSION) || defined(USE_ROCM)
   const auto scalar_type = mat1.scalar_type();
   return (beta.toComplexDouble() == 1.0
-    // self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
-    // is to use lt interface only when self is bias.
-    && self.dim() == 1 && self.sizes()[0] == mat2_sizes[1] && self.is_contiguous()
     && result.dim() == 2 && result.is_contiguous()
+    // Conditions for bias to be fusable
+    && (
+      self.is_contiguous() &&
+      // NOTE: fine to have 1-len dims to the left from the right-most one
+      (self.dim() == 1 || self.squeeze().dim() == 1) &&
+      self.sizes().back() == mat2_sizes[1]
+    )
     && ( // some dtype restrictions
 #ifndef USE_ROCM
       scalar_type == at::ScalarType::Double ||
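Read on its own, the new bias condition accepts a contiguous self that is effectively one-dimensional (any extra dims to the left must have length 1) and whose last size equals mat2's column count. A rough Python rendering of just that shape check, for illustration only (not the actual dispatch code):

def bias_is_fusable(bias, mat2):
    # mirrors: self.is_contiguous()
    #          && (self.dim() == 1 || self.squeeze().dim() == 1)
    #          && self.sizes().back() == mat2_sizes[1]
    return (
        bias.is_contiguous()
        and (bias.dim() == 1 or bias.squeeze().dim() == 1)
        and bias.shape[-1] == mat2.shape[1]
    )

For example, biases of shape (n,) or (1, n) pass this check, while a full (m, n) matrix with m > 1 does not.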
@@ -7328,9 +7328,11 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
         m2 = torch.randn(50, 25, device=device).to(dtype)
         self._test_addmm_addmv(func, M, m1, m2, activation=activation)

-        # vector-shaped bias and beta=1 result in epilogue fusion in CUDA
+        # vector-shaped bias (or with 1-len dims on the left from the leading dim)
+        # and beta=1 result in epilogue fusion in CUDA
         V = torch.randn(25, device=device).to(dtype)
         self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation)
+        self._test_addmm_addmv(func, V.unsqueeze(0), m1, m2, beta=1, activation=activation)

         # Test 0-strided
         M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25)
@@ -7357,8 +7359,9 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
             self._test_addmm_addmv(func, M, m1, m2, transpose_out=t4, activation=activation)

             if t1:
-                # use vector V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1)
+                # use vector/(1 by k)-shaped V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1)
                 self._test_addmm_addmv(func, V, m1, m2, beta=1, transpose_out=t4, activation=activation,)
+                self._test_addmm_addmv(func, V.unsqueeze(0), m1, m2, beta=1, transpose_out=t4, activation=activation,)

     @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
                         torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
@@ -5,6 +5,7 @@ import time
 import unittest
 from itertools import product
 from functools import partial
+from typing import Callable

 import torch
@@ -90,14 +91,21 @@ class TestMatmulCuda(InductorTestCase):
         torch.backends.cuda.matmul.allow_tf32 = True
         super().tearDown()

-    def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False, fp16_accumulate: bool = False):
+    def cublas_addmm(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        reduced_precision: bool = False,
+        fp16_accumulate: bool = False,
+        bias_shape_modifier: Callable | None = None,
+    ):
         #
         # Check for catastrophic cuBLAS inaccuracy by measuring the deviation between
         # results from the CUDA invocation of torch.addmm and the CPU invocation
         # (which does not use CUDA backend).
         #
         # Get dims
-        n, m, p = (size + 1, size, size + 2)
+        m, k, n = (size + 1, size, size + 2)
         # Disable reduced precision reductions in BFloat16 to bypass some kernels
         # which fail the threshold check
         orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
@@ -109,10 +117,12 @@ class TestMatmulCuda(InductorTestCase):
         # Make random tensors on CPU (seed set on common_utils.py import)
         # (Not using numpy because it does not support bfloat16)
         make_arg = partial(make_tensor, dtype=dtype, device="cpu")

+        bias_shape_modifier = (lambda shape: shape) if bias_shape_modifier is None else bias_shape_modifier
+        m_input = make_arg(bias_shape_modifier((m, n)))
+        m_1 = make_arg((m, k))
+        m_2 = make_arg((k, n))
         m_beta = make_arg(1)
-        m_input = make_arg((n, p))
-        m_1 = make_arg((n, m))
-        m_2 = make_arg((m, p))
         # scale to abate overflows in fp16 accum
         if fp16_accumulate:
             m_1 = m_1 / 100
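The bias_shape_modifier hook introduced above maps the reference (m, n) bias shape to the shape actually allocated, so the same accuracy comparison can be driven with a full-matrix bias, a (1 by n) row bias, or a 1-D vector bias. A small illustration of the three modifiers used in the new test below (example shape only):

shape = (129, 130)  # (m, n) for size=128: m = size + 1, n = size + 2

full_2d = lambda s: s            # -> (129, 130): unchanged 2-D bias
row_1xn = lambda s: (1, s[-1])   # -> (1, 130):   (1 by n) bias, the newly fused case
vector  = lambda s: (s[-1],)     # -> (130,):     plain 1-D bias

for modifier in (full_2d, row_1xn, vector):
    print(modifier(shape))  # each resulting bias broadcasts against the (m, n) addmm output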
@@ -179,6 +189,25 @@ class TestMatmulCuda(InductorTestCase):
         with blas_library_context(backend):
             self.cublas_addmm(size, dtype, True)

+
+    @onlyCUDA
+    # imported 'tol' as 'xtol' to avoid aliasing in code above
+    @toleranceOverride({torch.float16: xtol(atol=1e-3, rtol=1e-4),
+                        torch.bfloat16: xtol(atol=1e-3, rtol=1e-4),
+                        torch.float32: xtol(atol=1e-3, rtol=1e-4)})
+    @dtypes(torch.bfloat16, torch.float16, torch.float32)
+    @parametrize("size", [128])
+    @parametrize("backend", ["cublas", "cublaslt"])
+    def test_cublas_addmm_bias_shapes(self, size: int, dtype: torch.dtype, backend):
+        with blas_library_context(backend):
+            # 2D bias
+            self.cublas_addmm(size, dtype, bias_shape_modifier=lambda shape: shape)
+            # 1D bias which is row-broadcast to 2D
+            self.cublas_addmm(size, dtype, bias_shape_modifier=lambda shape: (1, shape[-1]))
+            # 1D bias which row-broadcasts
+            self.cublas_addmm(size, dtype, bias_shape_modifier=lambda shape: (shape[-1],))
+
     @onlyCUDA
     @dtypes(torch.float16)
     # m == 4 chooses OUTPUT_TYPE reduction on H200