mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ROCm][TunableOp] resolve the rocBLAS version dynamically (#147363)
Dynamically gets rocBLAS version instead of relying on some preprocessing-time definitions which may be stale. Pull Request resolved: https://github.com/pytorch/pytorch/pull/147363 Approved by: https://github.com/pruthvistony, https://github.com/naromero77amd, https://github.com/jeffdaily
This commit is contained in:
parent
86ae672b6a
commit
c74b59fc1f
|
|
@ -227,15 +227,10 @@ TuningResultsValidator::TuningResultsValidator() {
|
|||
}
|
||||
// rocblas
|
||||
{
|
||||
#define STRINGIFY(s) #s
|
||||
#define XSTRINGIFY(s) STRINGIFY(s)
|
||||
std::string rocblas_version = c10::str(
|
||||
XSTRINGIFY(ROCBLAS_VERSION_MAJOR), ".",
|
||||
XSTRINGIFY(ROCBLAS_VERSION_MINOR), ".",
|
||||
XSTRINGIFY(ROCBLAS_VERSION_PATCH), "-",
|
||||
XSTRINGIFY(ROCBLAS_VERSION_TWEAK));
|
||||
#undef XSTRINGIFY
|
||||
#undef STRINGIFY
|
||||
size_t rocblas_version_size;
|
||||
rocblas_get_version_string_size(&rocblas_version_size);
|
||||
std::string rocblas_version(rocblas_version_size - 1, '\0');
|
||||
rocblas_get_version_string(rocblas_version.data(), rocblas_version_size);
|
||||
RegisterValidator(
|
||||
"ROCBLAS_VERSION",
|
||||
[rocblas_version]() { return rocblas_version; },
|
||||
|
|
|
|||
|
|
@ -82,6 +82,13 @@ def tunableop_matmul(device, dtype):
|
|||
C = torch.matmul(A, B)
|
||||
del os.environ["PYTORCH_TUNABLEOP_ENABLED"]
|
||||
|
||||
def get_tunableop_validators():
|
||||
assert len(torch.cuda.tunable.get_validators()) > 0
|
||||
validators = {}
|
||||
for key, value in torch.cuda.tunable.get_validators():
|
||||
validators[key] = value
|
||||
return validators
|
||||
|
||||
class TestLinalg(TestCase):
|
||||
def setUp(self):
|
||||
super(self.__class__, self).setUp()
|
||||
|
|
@ -4603,10 +4610,7 @@ class TestLinalg(TestCase):
|
|||
filename3 = "tunableop_results_tmp2.csv"
|
||||
ordinal = torch.cuda.current_device()
|
||||
assert filename1 == f"tunableop_results{ordinal}.csv"
|
||||
assert len(torch.cuda.tunable.get_validators()) > 0
|
||||
validators = {}
|
||||
for key, value in torch.cuda.tunable.get_validators():
|
||||
validators[key] = value
|
||||
validators = get_tunableop_validators()
|
||||
if torch.version.hip:
|
||||
assert "HIPBLASLT_VERSION" in validators
|
||||
assert re.match(r'^\d+-[a-z0-9]+$', validators["HIPBLASLT_VERSION"])
|
||||
|
|
@ -4948,6 +4952,11 @@ class TestLinalg(TestCase):
|
|||
C = torch.matmul(A, B)
|
||||
self.assertEqual(len(torch.cuda.tunable.get_validators()), validator_num_lines)
|
||||
|
||||
validators = get_tunableop_validators()
|
||||
self.assertTrue("ROCBLAS_VERSION" in validators)
|
||||
# format: [major].[minor].[patch].[tweak].[commit id]
|
||||
self.assertTrue(re.match(r'^\d+.\d+.\d+.\d+.[a-z0-9]+$', validators["ROCBLAS_VERSION"]))
|
||||
|
||||
# disable TunableOp
|
||||
torch.cuda.tunable.enable(False)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user