mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ROCm][TunableOp] Unit test for TunableOp BLAS logging. (#148982)
Add a unit test for the new TunableOp BLAS logging feature. Requires the following PR to be merged first: https://github.com/pytorch/pytorch/pull/148979 Pull Request resolved: https://github.com/pytorch/pytorch/pull/148982 Approved by: https://github.com/jeffdaily
This commit is contained in:
parent
71daeddde2
commit
37bb7f79c6
|
|
@ -78,12 +78,17 @@ def set_tunableop_defaults():
|
|||
torch.cuda.tunable.set_max_tuning_iterations(100)
|
||||
torch.cuda.tunable.set_rotating_buffer_size(-1)
|
||||
|
||||
def tunableop_matmul(device, dtype):
|
||||
def tunableop_matmul(device, dtype, offline=False):
|
||||
# Helper function to test TunableOp in a subprocess
|
||||
# requires helper function since lambda function
|
||||
# not supported by multiprocessing module
|
||||
import os
|
||||
os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"
|
||||
|
||||
if offline:
|
||||
torch.cuda.tunable.tuning_enable(False)
|
||||
torch.cuda.tunable.record_untuned_enable(True)
|
||||
|
||||
torch.cuda.tunable.set_max_tuning_duration(1)
|
||||
A = torch.randn((17, 17), device=device, dtype=dtype)
|
||||
B = torch.randn((17, 17), device=device, dtype=dtype)
|
||||
|
|
@ -5661,6 +5666,108 @@ class TestLinalg(TestCase):
|
|||
except (FileNotFoundError, PermissionError):
|
||||
pass
|
||||
|
||||
@onlyCUDA
@skipCUDAIfNotRocm
@dtypes(torch.float16)
def test_blaslog_tunableop(self, device, dtype):
    """Check that PYTORCH_TUNABLEOP_BLAS_LOG=1 adds a BLAS-parameters column.

    GEMMs are recorded in subprocesses and the BLAS parameters are then
    expected to appear in both
        * the tunableop_untuned CSV file (offline tuning), and
        * the tunableop_results CSV file (online tuning).

    NOTE: The GEMMs run in subprocesses because, in the main process,
    PYTORCH_TUNABLEOP_BLAS_LOG has already been deactivated and its
    value is sticky.
    """
    import csv
    import multiprocessing as mp
    import os

    blas_log_env = "PYTORCH_TUNABLEOP_BLAS_LOG"

    set_tunableop_defaults()
    ordinal = torch.cuda.current_device()

    result_filename = f"tunableop_results{ordinal}.csv"
    untuned_filename = f"tunableop_untuned{ordinal}.csv"

    # Test in try-finally block to avoid leaking state
    # if test is interrupted.
    try:
        # Assign through os.environ rather than calling os.putenv():
        # the assignment still calls putenv() (so spawned children see
        # the variable) and also keeps the os.environ mapping in sync,
        # which lets the cleanup below remove it without relying on a
        # swallowed KeyError.
        os.environ[blas_log_env] = "1"

        # force=True needed according to:
        # https://docs.python.org/3/library/multiprocessing.html#multiprocessing.set_start_method
        # This is because a different test in this process could have
        # already set the start method. Setting it once covers both
        # subprocesses launched below.
        mp.set_start_method("spawn", force=True)

        # Offline tuning case in a subprocess
        p = mp.Process(target=tunableop_matmul, args=(device, dtype, True))
        p.start()
        p.join()

        # Make sure the untuned file exists and that it is not empty
        self.assertTrue(os.path.exists(untuned_filename))
        self.assertGreater(os.path.getsize(untuned_filename), 0)

        # Check that the BLAS PARAMS are in the CSV file
        with open(untuned_filename) as file:
            reader = csv.reader(file)
            first_row = next(reader)
            # Check for extra column
            self.assertGreater(len(first_row), 3)
            # Check for YAML entry to the right of
            # BLAS PARAMS
            self.assertIn("{ function:", first_row[2])

        # Online tuning case in a subprocess
        p = mp.Process(target=tunableop_matmul, args=(device, dtype, False))
        p.start()
        p.join()

        # Make sure the results file exists and that it is not empty
        self.assertTrue(os.path.exists(result_filename))
        self.assertGreater(os.path.getsize(result_filename), 0)

        # Check that the BLAS PARAMS are in the CSV file
        with open(result_filename) as file:
            reader = csv.reader(file)
            for _ in range(5):  # Skip the first 5 lines for the validator
                next(reader, None)
            # Fail with a readable message (instead of a bare
            # StopIteration) if no tuning result follows the validators.
            first_row = next(reader, None)
            self.assertIsNotNone(
                first_row, f"no tuning result rows found in {result_filename}"
            )
            # Check for extra column
            self.assertGreater(len(first_row), 5)
            # Check for YAML entry to the right of
            # BLAS PARAMS
            self.assertIn("{ function:", first_row[4])

    finally:
        # undo all the environment variables set
        os.environ.pop(blas_log_env, None)

        # clean up, remove any files that were generated
        for filename in [untuned_filename, result_filename]:
            try:
                os.remove(filename)
            # NB: The file is locked on Windows
            except (FileNotFoundError, PermissionError):
                pass
|
||||
|
||||
@dtypes(torch.float, torch.complex64)
|
||||
def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
|
||||
a = torch.empty((256, 512), device=device, dtype=dtype, requires_grad=True).unsqueeze(0)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user