mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This reverts commit 39bfba3f56.
Reverted https://github.com/pytorch/pytorch/pull/137427 on behalf of https://github.com/jcaip due to this PR breaks AO tests ([comment](https://github.com/pytorch/pytorch/pull/137427#issuecomment-2435906592))
58 lines
1.2 KiB
Python
58 lines
1.2 KiB
Python
# mypy: allow-untyped-defs
|
|
from typing import Optional
|
|
|
|
import torch
|
|
|
|
|
|
__all__ = [
|
|
"version",
|
|
"is_available",
|
|
"get_max_alg_id",
|
|
]
|
|
|
|
try:
|
|
from torch._C import _cusparselt
|
|
except ImportError:
|
|
_cusparselt = None # type: ignore[assignment]
|
|
|
|
__cusparselt_version: Optional[int] = None
|
|
__MAX_ALG_ID: Optional[int] = None
|
|
|
|
if _cusparselt is not None:
|
|
|
|
def _init():
|
|
global __cusparselt_version
|
|
global __MAX_ALG_ID
|
|
if __cusparselt_version is None:
|
|
__cusparselt_version = _cusparselt.getVersionInt()
|
|
if __cusparselt_version == 400:
|
|
__MAX_ALG_ID = 4
|
|
elif __cusparselt_version == 502:
|
|
__MAX_ALG_ID = 5
|
|
elif __cusparselt_version == 602:
|
|
__MAX_ALG_ID = 37
|
|
return True
|
|
|
|
else:
|
|
|
|
def _init():
|
|
return False
|
|
|
|
|
|
def version() -> Optional[int]:
|
|
"""Return the version of cuSPARSELt"""
|
|
if not _init():
|
|
return None
|
|
return __cusparselt_version
|
|
|
|
|
|
def is_available() -> bool:
|
|
r"""Return a bool indicating if cuSPARSELt is currently available."""
|
|
return torch._C._has_cusparselt
|
|
|
|
|
|
def get_max_alg_id() -> Optional[int]:
|
|
if not _init():
|
|
return None
|
|
return __MAX_ALG_ID
|