From c74b59fc1f466c81efdb2d5400b971a7fbae847e Mon Sep 17 00:00:00 2001 From: Arash Pakbin Date: Fri, 21 Feb 2025 06:50:19 +0000 Subject: [PATCH] [ROCm][TunableOp] resolve the rocBLAS version dynamically (#147363) Query the rocBLAS version at runtime instead of relying on compile-time preprocessor definitions, which may be stale. Pull Request resolved: https://github.com/pytorch/pytorch/pull/147363 Approved by: https://github.com/pruthvistony, https://github.com/naromero77amd, https://github.com/jeffdaily --- aten/src/ATen/cuda/tunable/Tunable.cpp | 13 ++++--------- test/test_linalg.py | 17 +++++++++++++---- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/aten/src/ATen/cuda/tunable/Tunable.cpp b/aten/src/ATen/cuda/tunable/Tunable.cpp index 03f7d77ac9e..c60c39d5a82 100644 --- a/aten/src/ATen/cuda/tunable/Tunable.cpp +++ b/aten/src/ATen/cuda/tunable/Tunable.cpp @@ -227,15 +227,10 @@ TuningResultsValidator::TuningResultsValidator() { } // rocblas { -#define STRINGIFY(s) #s -#define XSTRINGIFY(s) STRINGIFY(s) - std::string rocblas_version = c10::str( - XSTRINGIFY(ROCBLAS_VERSION_MAJOR), ".", - XSTRINGIFY(ROCBLAS_VERSION_MINOR), ".", - XSTRINGIFY(ROCBLAS_VERSION_PATCH), "-", - XSTRINGIFY(ROCBLAS_VERSION_TWEAK)); -#undef XSTRINGIFY -#undef STRINGIFY + size_t rocblas_version_size; + rocblas_get_version_string_size(&rocblas_version_size); + std::string rocblas_version(rocblas_version_size - 1, '\0'); + rocblas_get_version_string(rocblas_version.data(), rocblas_version_size); RegisterValidator( "ROCBLAS_VERSION", [rocblas_version]() { return rocblas_version; }, diff --git a/test/test_linalg.py b/test/test_linalg.py index 4aaf480b428..c40cec6df46 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -82,6 +82,13 @@ def tunableop_matmul(device, dtype): C = torch.matmul(A, B) del os.environ["PYTORCH_TUNABLEOP_ENABLED"] +def get_tunableop_validators(): + assert len(torch.cuda.tunable.get_validators()) > 0 + validators = {} + for key, value in
torch.cuda.tunable.get_validators(): + validators[key] = value + return validators + class TestLinalg(TestCase): def setUp(self): super(self.__class__, self).setUp() @@ -4603,10 +4610,7 @@ class TestLinalg(TestCase): filename3 = "tunableop_results_tmp2.csv" ordinal = torch.cuda.current_device() assert filename1 == f"tunableop_results{ordinal}.csv" - assert len(torch.cuda.tunable.get_validators()) > 0 - validators = {} - for key, value in torch.cuda.tunable.get_validators(): - validators[key] = value + validators = get_tunableop_validators() if torch.version.hip: assert "HIPBLASLT_VERSION" in validators assert re.match(r'^\d+-[a-z0-9]+$', validators["HIPBLASLT_VERSION"]) @@ -4948,6 +4952,11 @@ class TestLinalg(TestCase): C = torch.matmul(A, B) self.assertEqual(len(torch.cuda.tunable.get_validators()), validator_num_lines) + validators = get_tunableop_validators() + self.assertTrue("ROCBLAS_VERSION" in validators) + # format: [major].[minor].[patch].[tweak].[commit id] + self.assertTrue(re.match(r'^\d+.\d+.\d+.\d+.[a-z0-9]+$', validators["ROCBLAS_VERSION"])) + # disable TunableOp torch.cuda.tunable.enable(False)