[ROCm] [TunableOp] Unit tests for scaled GEMM and GEMM with bias (#147890)

Two more unit tests for TunableOp:
- Scaled GEMM
- GEMM with bias
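
Both tests follow the same pattern as the existing TunableOp unit tests: enable tuning with a single tuning iteration, run the target operation once, and verify that a new tuning result was recorded. A rough sketch of that pattern, with illustrative shapes, dtypes, and device string rather than the exact values used in the tests:

```python
import os
import torch
import torch.nn.functional as F

torch.cuda.tunable.enable()
torch.cuda.tunable.set_max_tuning_iterations(1)  # a single iteration keeps tuning short
before = len(torch.cuda.tunable.get_results())

# Run the op to be tuned, e.g. a GEMM with bias via linear()
x = torch.rand(3, 7, dtype=torch.bfloat16, device="cuda")
w = torch.rand(5, 7, dtype=torch.bfloat16, device="cuda")
b = torch.rand(5, dtype=torch.bfloat16, device="cuda")
F.linear(x, w, b)

after = len(torch.cuda.tunable.get_results())
assert after - before == 1  # exactly one new tuning result

torch.cuda.tunable.enable(False)
os.remove(torch.cuda.tunable.get_filename())  # remove the generated results file
```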

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147890
Approved by: https://github.com/jeffdaily
Author: Nichols A. Romero
Date: 2025-02-26 22:41:21 +00:00
Committed by: PyTorch MergeBot
parent b13ad1a193
commit 84e60eece8

@@ -20,7 +20,8 @@ from torch.testing._internal.common_utils import \
TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU, iter_indices,
make_fullrank_matrices_with_distinct_singular_values,
freeze_rng_state, IS_ARM64, IS_SANDCASTLE, TEST_OPT_EINSUM, parametrize, skipIfTorchDynamo,
-setBlasBackendsToDefaultFinally, setLinalgBackendsToDefaultFinally, serialTest)
+setBlasBackendsToDefaultFinally, setLinalgBackendsToDefaultFinally, serialTest,
+runOnRocmArch, MI300_ARCH)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, dtypes, has_cusolver, has_hipsolver,
onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride,
@@ -5149,6 +5150,105 @@ class TestLinalg(TestCase):
# Clean up, remove file that was generated
os.remove(filename)
@onlyCUDA
@dtypes(torch.bfloat16)
def test_gemm_bias_tunableop(self, device, dtype):
# Test GEMM and bias tuning
set_tunableop_defaults()
torch.cuda.tunable.enable()
# use a single tuning iteration to keep the test short but still exercise the code
torch.cuda.tunable.set_max_tuning_iterations(1)
# Reference number of results
ref_num_results = len(torch.cuda.tunable.get_results())
m = 3
n = 5
k = 7
X = torch.rand(m, k, dtype=dtype, device=device)
matA = torch.rand(n, k, dtype=dtype, device=device)
bias = torch.rand(n, dtype=dtype, device=device)
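# linear(X, matA, bias) computes X @ matA.T + bias, i.e. the GEMM-with-bias path being tuned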
torch.nn.functional.linear(X, matA, bias)
# This stores the total number of cumulative results
total_num_results = len(torch.cuda.tunable.get_results())
# There must be a new tuning result
self.assertEqual((total_num_results - ref_num_results), 1)
# disable TunableOp
torch.cuda.tunable.enable(False)
# clean up, remove any file that was generated
try:
import os
filename = torch.cuda.tunable.get_filename()
os.remove(filename)
except FileNotFoundError:
pass
@onlyCUDA
@skipCUDAIfNotRocm
@runOnRocmArch(MI300_ARCH)
@dtypes(torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)
def test_scaled_gemm_tunableop(self, device, dtype):
# Test Scaled GEMM tuning.
# We do not test the full set of scaled GEMM parameters, since
# hipBLASLt does not support all combinations.
# Here is a short list of extra parameters that are not tested
# - amax
# - use_fast_accum
# - bias dtypes other than torch.half
#
# Refer to test/test_matmul_cuda for the supported combinations that
# are tested by PyTorch
set_tunableop_defaults()
torch.cuda.tunable.enable()
# use a single tuning iteration to keep the test short but still exercise the code
torch.cuda.tunable.set_max_tuning_iterations(1)
# Reference number of results
ref_num_results = len(torch.cuda.tunable.get_results())
# Scaled GEMM parameters
fillA = 0.25
fillB = 0.75
m = n = k = 16
scaleA = torch.tensor(0.8, device=device)
scaleB = torch.tensor(0.9, device=device)
dtypeA = dtypeB = dtype
matA = torch.full((k, m), fillA, dtype=dtypeA, device=device)
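# matB is built as (n, k) and transposed so the second operand is column-major, as torch._scaled_mm expects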
matB = torch.full((n, k), fillB, dtype=dtypeB, device=device).t()
# out_dtype = dtype
torch._scaled_mm(matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=dtype)
# out_dtype = float32
torch._scaled_mm(matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=torch.float32)
# out_dtype = bfloat16
torch._scaled_mm(matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=torch.bfloat16)
# out_dtype = float16
torch._scaled_mm(matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=torch.half)
# This stores the total number of cumulative results
total_num_results = len(torch.cuda.tunable.get_results())
# There must be four new tuning results
self.assertEqual((total_num_results - ref_num_results), 4)
# disable TunableOp
torch.cuda.tunable.enable(False)
# clean up, remove any file that was generated
try:
import os
filename = torch.cuda.tunable.get_filename()
os.remove(filename)
except FileNotFoundError:
pass
@dtypes(torch.float, torch.complex64)
def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
a = torch.empty((256, 512), device=device, dtype=dtype, requires_grad=True).unsqueeze(0)