[ROCm][TunableOp] Unit test for TunableOp BLAS logging. (#148982)

Add unit test for new TunableOp BLAS logging feature.
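For context, a minimal sketch of how the feature under test is exercised, assembled from the helper and assertions in the diff below (device, dtype, and shapes are illustrative; TunableOp environment variables are read once at process start and are sticky, so they must be set before the first GEMM):

import os
os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"
os.environ["PYTORCH_TUNABLEOP_BLAS_LOG"] = "1"  # adds the BLAS_PARAMS column

import torch
torch.cuda.tunable.set_max_tuning_duration(1)  # keep tuning short
A = torch.randn((17, 17), device="cuda", dtype=torch.float16)
B = torch.randn((17, 17), device="cuda", dtype=torch.float16)
C = A @ B  # tuned GEMM; results are written to tunableop_results<ordinal>.csv on exit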

Requires this PR to be merged in first: https://github.com/pytorch/pytorch/pull/148979

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148982
Approved by: https://github.com/jeffdaily
Nichols A. Romero 2025-03-19 20:57:15 +00:00 committed by PyTorch MergeBot
parent 71daeddde2
commit 37bb7f79c6


@@ -78,12 +78,17 @@ def set_tunableop_defaults():
     torch.cuda.tunable.set_max_tuning_iterations(100)
     torch.cuda.tunable.set_rotating_buffer_size(-1)
 
-def tunableop_matmul(device, dtype):
+def tunableop_matmul(device, dtype, offline=False):
     # Helper function to test TunableOp in a subprocess.
     # A named helper is required since lambda functions are
     # not supported by the multiprocessing module.
     import os
     os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"
+
+    if offline:
+        torch.cuda.tunable.tuning_enable(False)
+        torch.cuda.tunable.record_untuned_enable(True)
+
     torch.cuda.tunable.set_max_tuning_duration(1)
     A = torch.randn((17, 17), device=device, dtype=dtype)
     B = torch.randn((17, 17), device=device, dtype=dtype)
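The new offline flag switches the helper between recording untuned GEMMs (offline tuning) and tuning them on the spot (online tuning). A minimal sketch of driving it from a fresh process, mirroring the test below; it assumes tunableop_matmul is defined at module level, which the spawn start method requires for pickling:

import multiprocessing as mp
import torch

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    # offline=True records untuned GEMMs into tunableop_untuned<ordinal>.csv
    p = mp.Process(target=tunableop_matmul, args=("cuda", torch.float16, True))
    p.start()
    p.join()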
@@ -5661,6 +5666,108 @@ class TestLinalg(TestCase):
                 except (FileNotFoundError, PermissionError):
                     pass
+    @onlyCUDA
+    @skipCUDAIfNotRocm
+    @dtypes(torch.float16)
+    def test_blaslog_tunableop(self, device, dtype):
+        # Test that PYTORCH_TUNABLEOP_BLAS_LOG=1 adds an extra
+        # column of data with the BLAS parameters in offline
+        # and online tuning.
+        #
+        # We record GEMMs and then check that the BLAS_PARAMS
+        # appear in both the tunableop_untuned CSV file and the
+        # tunableop_results CSV file.
+        #
+        # NOTE: This is done in subprocesses because in the main
+        # process PYTORCH_TUNABLEOP_BLAS_LOG has already been
+        # deactivated and its value is sticky.
+        import os
+        import multiprocessing as mp
+        set_tunableop_defaults()
+
+        ordinal = torch.cuda.current_device()
+        result_filename = f"tunableop_results{ordinal}.csv"
+        untuned_filename = f"tunableop_untuned{ordinal}.csv"
+
+        # Test in a try-finally block to avoid leaking state
+        # if the test is interrupted.
+        try:
+            # Use os.environ rather than os.putenv so that the
+            # variable is inherited by the spawned subprocesses
+            # and can be deleted again in the finally block.
+            os.environ["PYTORCH_TUNABLEOP_BLAS_LOG"] = "1"
+
+            # Offline tuning case in a subprocess.
+            # force=True is needed according to:
+            # https://docs.python.org/3/library/multiprocessing.html#multiprocessing.set_start_method
+            # because a different test in this process could have
+            # already set the start method.
+            mp.set_start_method("spawn", force=True)
+
+            p = mp.Process(target=tunableop_matmul, args=(device, dtype, True))
+            p.start()
+            p.join()
+
+            # Make sure the untuned file exists and its size is nonzero.
+            self.assertTrue(os.path.exists(untuned_filename))
+            self.assertGreater(os.path.getsize(untuned_filename), 0)
+
+            # Check that the BLAS PARAMS are in the CSV file.
+            import csv
+            with open(untuned_filename) as file:
+                reader = csv.reader(file)
+                first_row = next(reader)
+                # Check for the extra column.
+                self.assertGreater(len(first_row), 3)
+                # Check for the YAML entry to the right of the
+                # BLAS PARAMS.
+                self.assertTrue("{ function:" in first_row[2])
+
+            # Online tuning case in a subprocess.
+            # force=True is needed according to:
+            # https://docs.python.org/3/library/multiprocessing.html#multiprocessing.set_start_method
+            # because a different test in this process could have
+            # already set the start method.
+            mp.set_start_method("spawn", force=True)
+
+            p = mp.Process(target=tunableop_matmul, args=(device, dtype, False))
+            p.start()
+            p.join()
+
+            # Make sure the results file exists and its size is nonzero.
+            self.assertTrue(os.path.exists(result_filename))
+            self.assertGreater(os.path.getsize(result_filename), 0)
+
+            # Check that the BLAS PARAMS are in the CSV file.
+            with open(result_filename) as file:
+                reader = csv.reader(file)
+                for _ in range(5):  # Skip the first 5 lines for the validator.
+                    next(reader, None)
+                # Check for the extra column.
+                first_row = next(reader)
+                self.assertGreater(len(first_row), 5)
+                # Check for the YAML entry to the right of the
+                # BLAS PARAMS.
+                self.assertTrue("{ function:" in first_row[4])
+        finally:
+            # Undo all the environment variables set above.
+            try:
+                del os.environ["PYTORCH_TUNABLEOP_BLAS_LOG"]
+            except KeyError:
+                pass
+
+            # Clean up: remove any files that were generated.
+            for filename in [untuned_filename, result_filename]:
+                try:
+                    os.remove(filename)
+                # NB: The file is locked on Windows.
+                except (FileNotFoundError, PermissionError):
+                    pass
 
     @dtypes(torch.float, torch.complex64)
     def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
         a = torch.empty((256, 512), device=device, dtype=dtype, requires_grad=True).unsqueeze(0)
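For reference, a hypothetical helper (not part of this PR) showing the kind of column check the new test performs, assuming the "{ function: ..." YAML-style entry format asserted above:

import csv

def blas_params_entries(filename, column):
    # Yield the BLAS_PARAMS cell from each row that has one.
    with open(filename) as f:
        for row in csv.reader(f):
            if len(row) > column and row[column].lstrip().startswith("{ function:"):
                yield row[column]

# Example: column 2 of the untuned file for device ordinal 0.
for params in blas_params_entries("tunableop_untuned0.csv", 2):
    print(params)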