[ROCm][TunableOp] resolve the rocBLAS version dynamically (#147363)

Dynamically queries the rocBLAS version at runtime instead of relying on preprocessor-time definitions, which may be stale.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147363
Approved by: https://github.com/pruthvistony, https://github.com/naromero77amd, https://github.com/jeffdaily
Authored by Arash Pakbin on 2025-02-21 06:50:19 +00:00, committed by PyTorch MergeBot
parent 86ae672b6a
commit c74b59fc1f
2 changed files with 17 additions and 13 deletions


@@ -227,15 +227,10 @@ TuningResultsValidator::TuningResultsValidator() {
   }
   // rocblas
   {
-#define STRINGIFY(s) #s
-#define XSTRINGIFY(s) STRINGIFY(s)
-    std::string rocblas_version = c10::str(
-        XSTRINGIFY(ROCBLAS_VERSION_MAJOR), ".",
-        XSTRINGIFY(ROCBLAS_VERSION_MINOR), ".",
-        XSTRINGIFY(ROCBLAS_VERSION_PATCH), "-",
-        XSTRINGIFY(ROCBLAS_VERSION_TWEAK));
-#undef XSTRINGIFY
-#undef STRINGIFY
+    size_t rocblas_version_size;
+    rocblas_get_version_string_size(&rocblas_version_size);
+    std::string rocblas_version(rocblas_version_size - 1, '\0');
+    rocblas_get_version_string(rocblas_version.data(), rocblas_version_size);
     RegisterValidator(
         "ROCBLAS_VERSION",
         [rocblas_version]() { return rocblas_version; },
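For reference, a rough standalone sketch (not part of this patch) of the runtime query the new validator code relies on. It assumes the rocblas/rocblas.h header path (some ROCm releases install it as rocblas.h) and linking against rocBLAS; the size reported by rocblas_get_version_string_size includes the terminating NUL, which is why the std::string is sized to size - 1.

// Standalone sketch: query the rocBLAS version string at runtime.
// Assumption: header installed as <rocblas/rocblas.h>; adjust for older ROCm.
#include <rocblas/rocblas.h>

#include <iostream>
#include <string>

int main() {
  size_t size = 0;
  if (rocblas_get_version_string_size(&size) != rocblas_status_success) {
    return 1;
  }
  // The reported size includes the trailing '\0', so the string payload
  // needs size - 1 characters.
  std::string version(size - 1, '\0');
  if (rocblas_get_version_string(version.data(), size) != rocblas_status_success) {
    return 1;
  }
  // Prints something like "<major>.<minor>.<patch>.<tweak>.<commit id>",
  // the format the updated test checks for.
  std::cout << "rocBLAS version: " << version << "\n";
  return 0;
}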


@@ -82,6 +82,13 @@ def tunableop_matmul(device, dtype):
     C = torch.matmul(A, B)
     del os.environ["PYTORCH_TUNABLEOP_ENABLED"]
 
+def get_tunableop_validators():
+    assert len(torch.cuda.tunable.get_validators()) > 0
+    validators = {}
+    for key, value in torch.cuda.tunable.get_validators():
+        validators[key] = value
+    return validators
+
 class TestLinalg(TestCase):
     def setUp(self):
         super(self.__class__, self).setUp()
@@ -4603,10 +4610,7 @@ class TestLinalg(TestCase):
         filename3 = "tunableop_results_tmp2.csv"
         ordinal = torch.cuda.current_device()
         assert filename1 == f"tunableop_results{ordinal}.csv"
-        assert len(torch.cuda.tunable.get_validators()) > 0
-        validators = {}
-        for key, value in torch.cuda.tunable.get_validators():
-            validators[key] = value
+        validators = get_tunableop_validators()
         if torch.version.hip:
             assert "HIPBLASLT_VERSION" in validators
             assert re.match(r'^\d+-[a-z0-9]+$', validators["HIPBLASLT_VERSION"])
@@ -4948,6 +4952,11 @@ class TestLinalg(TestCase):
         C = torch.matmul(A, B)
         self.assertEqual(len(torch.cuda.tunable.get_validators()), validator_num_lines)
+        validators = get_tunableop_validators()
+        self.assertTrue("ROCBLAS_VERSION" in validators)
+        # format: [major].[minor].[patch].[tweak].[commit id]
+        self.assertTrue(re.match(r'^\d+.\d+.\d+.\d+.[a-z0-9]+$', validators["ROCBLAS_VERSION"]))
         # disable TunableOp
         torch.cuda.tunable.enable(False)