mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ROCm][TunableOp] resolve the rocBLAS version dynamically (#147363)
Dynamically gets rocBLAS version instead of relying on some preprocessing-time definitions which may be stale. Pull Request resolved: https://github.com/pytorch/pytorch/pull/147363 Approved by: https://github.com/pruthvistony, https://github.com/naromero77amd, https://github.com/jeffdaily
This commit is contained in:
parent
86ae672b6a
commit
c74b59fc1f
|
|
@ -227,15 +227,10 @@ TuningResultsValidator::TuningResultsValidator() {
|
||||||
}
|
}
|
||||||
// rocblas
|
// rocblas
|
||||||
{
|
{
|
||||||
#define STRINGIFY(s) #s
|
size_t rocblas_version_size;
|
||||||
#define XSTRINGIFY(s) STRINGIFY(s)
|
rocblas_get_version_string_size(&rocblas_version_size);
|
||||||
std::string rocblas_version = c10::str(
|
std::string rocblas_version(rocblas_version_size - 1, '\0');
|
||||||
XSTRINGIFY(ROCBLAS_VERSION_MAJOR), ".",
|
rocblas_get_version_string(rocblas_version.data(), rocblas_version_size);
|
||||||
XSTRINGIFY(ROCBLAS_VERSION_MINOR), ".",
|
|
||||||
XSTRINGIFY(ROCBLAS_VERSION_PATCH), "-",
|
|
||||||
XSTRINGIFY(ROCBLAS_VERSION_TWEAK));
|
|
||||||
#undef XSTRINGIFY
|
|
||||||
#undef STRINGIFY
|
|
||||||
RegisterValidator(
|
RegisterValidator(
|
||||||
"ROCBLAS_VERSION",
|
"ROCBLAS_VERSION",
|
||||||
[rocblas_version]() { return rocblas_version; },
|
[rocblas_version]() { return rocblas_version; },
|
||||||
|
|
|
||||||
|
|
@ -82,6 +82,13 @@ def tunableop_matmul(device, dtype):
|
||||||
C = torch.matmul(A, B)
|
C = torch.matmul(A, B)
|
||||||
del os.environ["PYTORCH_TUNABLEOP_ENABLED"]
|
del os.environ["PYTORCH_TUNABLEOP_ENABLED"]
|
||||||
|
|
||||||
|
def get_tunableop_validators():
|
||||||
|
assert len(torch.cuda.tunable.get_validators()) > 0
|
||||||
|
validators = {}
|
||||||
|
for key, value in torch.cuda.tunable.get_validators():
|
||||||
|
validators[key] = value
|
||||||
|
return validators
|
||||||
|
|
||||||
class TestLinalg(TestCase):
|
class TestLinalg(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(self.__class__, self).setUp()
|
super(self.__class__, self).setUp()
|
||||||
|
|
@ -4603,10 +4610,7 @@ class TestLinalg(TestCase):
|
||||||
filename3 = "tunableop_results_tmp2.csv"
|
filename3 = "tunableop_results_tmp2.csv"
|
||||||
ordinal = torch.cuda.current_device()
|
ordinal = torch.cuda.current_device()
|
||||||
assert filename1 == f"tunableop_results{ordinal}.csv"
|
assert filename1 == f"tunableop_results{ordinal}.csv"
|
||||||
assert len(torch.cuda.tunable.get_validators()) > 0
|
validators = get_tunableop_validators()
|
||||||
validators = {}
|
|
||||||
for key, value in torch.cuda.tunable.get_validators():
|
|
||||||
validators[key] = value
|
|
||||||
if torch.version.hip:
|
if torch.version.hip:
|
||||||
assert "HIPBLASLT_VERSION" in validators
|
assert "HIPBLASLT_VERSION" in validators
|
||||||
assert re.match(r'^\d+-[a-z0-9]+$', validators["HIPBLASLT_VERSION"])
|
assert re.match(r'^\d+-[a-z0-9]+$', validators["HIPBLASLT_VERSION"])
|
||||||
|
|
@ -4948,6 +4952,11 @@ class TestLinalg(TestCase):
|
||||||
C = torch.matmul(A, B)
|
C = torch.matmul(A, B)
|
||||||
self.assertEqual(len(torch.cuda.tunable.get_validators()), validator_num_lines)
|
self.assertEqual(len(torch.cuda.tunable.get_validators()), validator_num_lines)
|
||||||
|
|
||||||
|
validators = get_tunableop_validators()
|
||||||
|
self.assertTrue("ROCBLAS_VERSION" in validators)
|
||||||
|
# format: [major].[minor].[patch].[tweak].[commit id]
|
||||||
|
self.assertTrue(re.match(r'^\d+.\d+.\d+.\d+.[a-z0-9]+$', validators["ROCBLAS_VERSION"]))
|
||||||
|
|
||||||
# disable TunableOp
|
# disable TunableOp
|
||||||
torch.cuda.tunable.enable(False)
|
torch.cuda.tunable.enable(False)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user