[ROCm][TunableOp] resolve the rocBLAS version dynamically (#147363)

Dynamically queries the rocBLAS version at runtime (via `rocblas_get_version_string`) instead of relying on preprocessor-time version macros, which may be stale.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147363
Approved by: https://github.com/pruthvistony, https://github.com/naromero77amd, https://github.com/jeffdaily
This commit is contained in:
Arash Pakbin 2025-02-21 06:50:19 +00:00 committed by PyTorch MergeBot
parent 86ae672b6a
commit c74b59fc1f
2 changed files with 17 additions and 13 deletions

View File

@ -227,15 +227,10 @@ TuningResultsValidator::TuningResultsValidator() {
} }
// rocblas // rocblas
{ {
#define STRINGIFY(s) #s size_t rocblas_version_size;
#define XSTRINGIFY(s) STRINGIFY(s) rocblas_get_version_string_size(&rocblas_version_size);
std::string rocblas_version = c10::str( std::string rocblas_version(rocblas_version_size - 1, '\0');
XSTRINGIFY(ROCBLAS_VERSION_MAJOR), ".", rocblas_get_version_string(rocblas_version.data(), rocblas_version_size);
XSTRINGIFY(ROCBLAS_VERSION_MINOR), ".",
XSTRINGIFY(ROCBLAS_VERSION_PATCH), "-",
XSTRINGIFY(ROCBLAS_VERSION_TWEAK));
#undef XSTRINGIFY
#undef STRINGIFY
RegisterValidator( RegisterValidator(
"ROCBLAS_VERSION", "ROCBLAS_VERSION",
[rocblas_version]() { return rocblas_version; }, [rocblas_version]() { return rocblas_version; },

View File

@ -82,6 +82,13 @@ def tunableop_matmul(device, dtype):
C = torch.matmul(A, B) C = torch.matmul(A, B)
del os.environ["PYTORCH_TUNABLEOP_ENABLED"] del os.environ["PYTORCH_TUNABLEOP_ENABLED"]
def get_tunableop_validators():
assert len(torch.cuda.tunable.get_validators()) > 0
validators = {}
for key, value in torch.cuda.tunable.get_validators():
validators[key] = value
return validators
class TestLinalg(TestCase): class TestLinalg(TestCase):
def setUp(self): def setUp(self):
super(self.__class__, self).setUp() super(self.__class__, self).setUp()
@ -4603,10 +4610,7 @@ class TestLinalg(TestCase):
filename3 = "tunableop_results_tmp2.csv" filename3 = "tunableop_results_tmp2.csv"
ordinal = torch.cuda.current_device() ordinal = torch.cuda.current_device()
assert filename1 == f"tunableop_results{ordinal}.csv" assert filename1 == f"tunableop_results{ordinal}.csv"
assert len(torch.cuda.tunable.get_validators()) > 0 validators = get_tunableop_validators()
validators = {}
for key, value in torch.cuda.tunable.get_validators():
validators[key] = value
if torch.version.hip: if torch.version.hip:
assert "HIPBLASLT_VERSION" in validators assert "HIPBLASLT_VERSION" in validators
assert re.match(r'^\d+-[a-z0-9]+$', validators["HIPBLASLT_VERSION"]) assert re.match(r'^\d+-[a-z0-9]+$', validators["HIPBLASLT_VERSION"])
@ -4948,6 +4952,11 @@ class TestLinalg(TestCase):
C = torch.matmul(A, B) C = torch.matmul(A, B)
self.assertEqual(len(torch.cuda.tunable.get_validators()), validator_num_lines) self.assertEqual(len(torch.cuda.tunable.get_validators()), validator_num_lines)
validators = get_tunableop_validators()
self.assertTrue("ROCBLAS_VERSION" in validators)
# format: [major].[minor].[patch].[tweak].[commit id]
self.assertTrue(re.match(r'^\d+.\d+.\d+.\d+.[a-z0-9]+$', validators["ROCBLAS_VERSION"]))
# disable TunableOp # disable TunableOp
torch.cuda.tunable.enable(False) torch.cuda.tunable.enable(False)