[ROCm] [TunableOp] Unit tests for scaled GEMM and GEMM with bias (#147890)
Two more unit tests for TunableOp:
- Scaled GEMM
- GEMM with bias

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147890
Approved by: https://github.com/jeffdaily
This commit is contained in:
parent
b13ad1a193
commit
84e60eece8
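For reference, the tuning workflow these tests exercise is driven through the
torch.cuda.tunable API. A minimal sketch of that workflow (all calls below
appear in the diff; the shapes are arbitrary):

    import torch

    # Enable TunableOp so each new GEMM shape is tuned on first use.
    torch.cuda.tunable.enable()
    # A single tuning iteration keeps runs short but still exercises tuning.
    torch.cuda.tunable.set_max_tuning_iterations(1)

    # Any GEMM-shaped op (here, a bias-fused linear) triggers tuning.
    x = torch.rand(3, 7, dtype=torch.bfloat16, device="cuda")
    w = torch.rand(5, 7, dtype=torch.bfloat16, device="cuda")
    b = torch.rand(5, dtype=torch.bfloat16, device="cuda")
    torch.nn.functional.linear(x, w, b)

    # Each tuned GEMM adds one entry to the results table.
    print(len(torch.cuda.tunable.get_results()))

    torch.cuda.tunable.enable(False)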
@@ -20,7 +20,8 @@ from torch.testing._internal.common_utils import \
     TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU, iter_indices,
     make_fullrank_matrices_with_distinct_singular_values,
     freeze_rng_state, IS_ARM64, IS_SANDCASTLE, TEST_OPT_EINSUM, parametrize, skipIfTorchDynamo,
-    setBlasBackendsToDefaultFinally, setLinalgBackendsToDefaultFinally, serialTest)
+    setBlasBackendsToDefaultFinally, setLinalgBackendsToDefaultFinally, serialTest,
+    runOnRocmArch, MI300_ARCH)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, dtypes, has_cusolver, has_hipsolver,
      onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride,
@@ -5149,6 +5150,105 @@ class TestLinalg(TestCase):
         # Clean up, remove file that was generated
         os.remove(filename)
 
+    @onlyCUDA
+    @dtypes(torch.bfloat16)
+    def test_gemm_bias_tunableop(self, device, dtype):
+        # Test GEMM-with-bias tuning
+        set_tunableop_defaults()
+        torch.cuda.tunable.enable()
+        # Use a single tuning iteration to keep the test short while still exercising the code path
+        torch.cuda.tunable.set_max_tuning_iterations(1)
+
+        # Reference number of results before tuning
+        ref_num_results = len(torch.cuda.tunable.get_results())
+
+        m = 3
+        n = 5
+        k = 7
+        X = torch.rand(m, k, dtype=dtype, device=device)
+        matA = torch.rand(n, k, dtype=dtype, device=device)
+        bias = torch.rand(n, dtype=dtype, device=device)
+
+        torch.nn.functional.linear(X, matA, bias)
+
+        # Total number of cumulative results after tuning
+        total_num_results = len(torch.cuda.tunable.get_results())
+
+        # There must be exactly one new tuning result
+        self.assertEqual((total_num_results - ref_num_results), 1)
+
+        # Disable TunableOp
+        torch.cuda.tunable.enable(False)
+
+        # Clean up: remove any file that was generated
+        try:
+            import os
+            filename = torch.cuda.tunable.get_filename()
+            os.remove(filename)
+        except FileNotFoundError:
+            pass
+
+    @onlyCUDA
+    @skipCUDAIfNotRocm
+    @runOnRocmArch(MI300_ARCH)
+    @dtypes(torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)
+    def test_scaled_gemm_tunableop(self, device, dtype):
+        # Test scaled GEMM tuning.
+        # We do not test the full set of scaled GEMM parameters, since
+        # hipBLASLt does not support all combinations. Parameters that
+        # are not tested here:
+        # - amax
+        # - use_fast_accum
+        # - bias dtypes other than torch.half
+        #
+        # Refer to test/test_matmul_cuda for the supported combinations
+        # that PyTorch tests.
+        set_tunableop_defaults()
+        torch.cuda.tunable.enable()
+        # Use a single tuning iteration to keep the test short while still exercising the code path
+        torch.cuda.tunable.set_max_tuning_iterations(1)
+
+        # Reference number of results before tuning
+        ref_num_results = len(torch.cuda.tunable.get_results())
+
+        # Scaled GEMM parameters
+        fillA = 0.25
+        fillB = 0.75
+        m = n = k = 16
+        scaleA = torch.tensor(0.8, device=device)
+        scaleB = torch.tensor(0.9, device=device)
+
+        dtypeA = dtypeB = dtype
+        matA = torch.full((k, m), fillA, dtype=dtypeA, device=device)
+        matB = torch.full((n, k), fillB, dtype=dtypeB, device=device).t()
+
+        # out_dtype = dtype
+        torch._scaled_mm(matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=dtype)
+        # out_dtype = float32
+        torch._scaled_mm(matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=torch.float32)
+        # out_dtype = bfloat16
+        torch._scaled_mm(matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=torch.bfloat16)
+        # out_dtype = float16
+        torch._scaled_mm(matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=torch.half)
+
+        # Total number of cumulative results after tuning
+        total_num_results = len(torch.cuda.tunable.get_results())
+
+        # There must be four new tuning results, one per out_dtype
+        self.assertEqual((total_num_results - ref_num_results), 4)
+
+        # Disable TunableOp
+        torch.cuda.tunable.enable(False)
+
+        # Clean up: remove any file that was generated
+        try:
+            import os
+            filename = torch.cuda.tunable.get_filename()
+            os.remove(filename)
+        except FileNotFoundError:
+            pass
+
     @dtypes(torch.float, torch.complex64)
     def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
         a = torch.empty((256, 512), device=device, dtype=dtype, requires_grad=True).unsqueeze(0)
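The comment block in test_scaled_gemm_tunableop above notes that bias dtypes
other than torch.half are left untested. For reference, a minimal sketch of
that bias variant (not part of this commit; shapes and fill values are
illustrative, mirroring the test, and the fnuz dtype assumes a ROCm
MI300-class device):

    import torch

    m = n = k = 16
    matA = torch.full((m, k), 0.25, dtype=torch.float8_e4m3fnuz, device="cuda")
    # The second operand must be column-major, hence the transpose.
    matB = torch.full((n, k), 0.75, dtype=torch.float8_e4m3fnuz, device="cuda").t()
    scaleA = torch.tensor(0.8, device="cuda")
    scaleB = torch.tensor(0.9, device="cuda")
    bias = torch.zeros(n, dtype=torch.half, device="cuda")

    out = torch._scaled_mm(matA, matB, scale_a=scaleA, scale_b=scaleB,
                           bias=bias, out_dtype=torch.half)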