mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Ensures pyrefly ignores only silence one error code. After this, only ~40 files are left to clean up.

pyrefly check
lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166448
Approved by: https://github.com/Skylion007
238 lines · 7.3 KiB · Python
# mypy: allow-untyped-defs
import os
import sys
import warnings
from contextlib import contextmanager
from typing import Optional

import torch
from torch.backends import (
    __allow_nonbracketed_mutation,
    _FP32Precision,
    _get_fp32_precision_getter,
    _set_fp32_precision_setter,
    ContextProp,
    PropModule,
)


try:
    from torch._C import _cudnn
except ImportError:
    _cudnn = None  # type: ignore[assignment]

# Write:
#
#   torch.backends.cudnn.enabled = False
#
# to globally disable CuDNN/MIOpen

__cudnn_version: Optional[int] = None

if _cudnn is not None:

    def _init():
        global __cudnn_version
        if __cudnn_version is None:
            # pyrefly: ignore [missing-attribute]
            __cudnn_version = _cudnn.getVersionInt()
            # pyrefly: ignore [missing-attribute]
            runtime_version = _cudnn.getRuntimeVersion()
            # pyrefly: ignore [missing-attribute]
            compile_version = _cudnn.getCompileVersion()
            runtime_major, runtime_minor, _ = runtime_version
            compile_major, compile_minor, _ = compile_version
            # Different major versions are always incompatible.
            # Starting with cuDNN 7, minor versions are backwards-compatible.
            # Not sure about MIOpen (ROCm), so always do a strict check.
            if runtime_major != compile_major:
                cudnn_compatible = False
            # pyrefly: ignore [missing-attribute]
            elif runtime_major < 7 or not _cudnn.is_cuda:
                cudnn_compatible = runtime_minor == compile_minor
            else:
                cudnn_compatible = runtime_minor >= compile_minor
            if not cudnn_compatible:
                if os.environ.get("PYTORCH_SKIP_CUDNN_COMPATIBILITY_CHECK", "0") == "1":
                    return True
                base_error_msg = (
                    f"cuDNN version incompatibility: "
                    f"PyTorch was compiled against {compile_version} "
                    f"but found runtime version {runtime_version}. "
                    f"PyTorch already comes bundled with cuDNN. "
                    f"One option to resolve this error is to ensure PyTorch "
                    f"can find the bundled cuDNN. "
                )

                if "LD_LIBRARY_PATH" in os.environ:
                    ld_library_path = os.environ.get("LD_LIBRARY_PATH", "")
                    if any(
                        substring in ld_library_path for substring in ["cuda", "cudnn"]
                    ):
                        raise RuntimeError(
                            f"{base_error_msg}"
                            f"It looks like your LD_LIBRARY_PATH contains an incompatible version of cuDNN. "
                            f"Please either remove it from the path or install cuDNN {compile_version}."
                        )
                    else:
                        raise RuntimeError(
                            f"{base_error_msg}"
                            f"One possibility is that there is a "
                            f"conflicting cuDNN in LD_LIBRARY_PATH."
                        )
                else:
                    raise RuntimeError(base_error_msg)

        return True

else:

    def _init():
        return False


def version():
    """Return the version of cuDNN."""
    if not _init():
        return None
    return __cudnn_version
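
# Usage sketch: `version()` returns the cuDNN version as a single int
# (e.g. 90100, a hypothetical value), or None when cuDNN/MIOpen is
# unavailable:
#
#     if torch.backends.cudnn.version() is None:
#         print("cuDNN/MIOpen not found")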


CUDNN_TENSOR_DTYPES = {
    torch.half,
    torch.float,
    torch.double,
}
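
# Note: `is_acceptable` below only approves CUDA tensors with one of the
# dtypes above; tensors of any other dtype or device are rejected.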


def is_available():
    r"""Return a bool indicating if cuDNN is currently available."""
    return torch._C._has_cudnn


def is_acceptable(tensor):
    if not torch._C._get_cudnn_enabled():
        return False
    if tensor.device.type != "cuda" or tensor.dtype not in CUDNN_TENSOR_DTYPES:
        return False
    if not is_available():
        warnings.warn(
            "PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild "
            "PyTorch making sure the library is visible to the build system.",
            stacklevel=2,
        )
        return False
    if not _init():
        warnings.warn(
            "cuDNN/MIOpen library not found. Check your {libpath}".format(
                libpath={"darwin": "DYLD_LIBRARY_PATH", "win32": "PATH"}.get(
                    sys.platform, "LD_LIBRARY_PATH"
                )
            ),
            stacklevel=2,
        )
        return False
    return True
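
# A minimal usage sketch (the tensor here is hypothetical): callers can gate
# cuDNN-specific code paths on `is_acceptable`:
#
#     x = torch.randn(8, 3, 32, 32, device="cuda")
#     if torch.backends.cudnn.is_acceptable(x):
#         ...  # safe to dispatch this tensor to a cuDNN kernel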


def set_flags(
    _enabled=None,
    _benchmark=None,
    _benchmark_limit=None,
    _deterministic=None,
    _allow_tf32=None,
    _fp32_precision="none",
):
    orig_flags = (
        torch._C._get_cudnn_enabled(),
        torch._C._get_cudnn_benchmark(),
        None if not is_available() else torch._C._cuda_get_cudnn_benchmark_limit(),
        torch._C._get_cudnn_deterministic(),
        torch._C._get_cudnn_allow_tf32(),
        torch._C._get_fp32_precision_getter("cuda", "all"),
    )
    if _enabled is not None:
        torch._C._set_cudnn_enabled(_enabled)
    if _benchmark is not None:
        torch._C._set_cudnn_benchmark(_benchmark)
    if _benchmark_limit is not None and is_available():
        torch._C._cuda_set_cudnn_benchmark_limit(_benchmark_limit)
    if _deterministic is not None:
        torch._C._set_cudnn_deterministic(_deterministic)
    if _allow_tf32 is not None:
        torch._C._set_cudnn_allow_tf32(_allow_tf32)
    if _fp32_precision is not None:
        torch._C._set_fp32_precision_setter("cuda", "all", _fp32_precision)
    return orig_flags
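
# `set_flags` returns the previous values in parameter order, so a caller can
# restore them afterwards, e.g. (a sketch; the `flags` context manager below
# wraps exactly this pattern):
#
#     orig = set_flags(_benchmark=True)
#     try:
#         ...
#     finally:
#         set_flags(*orig)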


@contextmanager
def flags(
    enabled=False,
    benchmark=False,
    benchmark_limit=10,
    deterministic=False,
    allow_tf32=True,
    fp32_precision="none",
):
    with __allow_nonbracketed_mutation():
        orig_flags = set_flags(
            enabled,
            benchmark,
            benchmark_limit,
            deterministic,
            allow_tf32,
            fp32_precision,
        )
    try:
        yield
    finally:
        # recover the previous values
        with __allow_nonbracketed_mutation():
            set_flags(*orig_flags)
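
# Usage sketch for the `flags` context manager (`model` and `x` are
# hypothetical); the previous global flags are restored on exit, even if
# the body raises:
#
#     with torch.backends.cudnn.flags(enabled=True, benchmark=True):
#         out = model(x)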


# The magic here is to allow us to intercept code like this:
#
#     torch.backends.<cudnn|mkldnn>.enabled = True


class CudnnModule(PropModule):
    def __init__(self, m, name):
        super().__init__(m, name)

    enabled = ContextProp(torch._C._get_cudnn_enabled, torch._C._set_cudnn_enabled)
    deterministic = ContextProp(
        torch._C._get_cudnn_deterministic, torch._C._set_cudnn_deterministic
    )
    benchmark = ContextProp(
        torch._C._get_cudnn_benchmark, torch._C._set_cudnn_benchmark
    )
    benchmark_limit = None
    if is_available():
        benchmark_limit = ContextProp(
            torch._C._cuda_get_cudnn_benchmark_limit,
            torch._C._cuda_set_cudnn_benchmark_limit,
        )
    allow_tf32 = ContextProp(
        torch._C._get_cudnn_allow_tf32, torch._C._set_cudnn_allow_tf32
    )
    conv = _FP32Precision("cuda", "conv")
    rnn = _FP32Precision("cuda", "rnn")
    fp32_precision = ContextProp(
        _get_fp32_precision_getter("cuda", "all"),
        _set_fp32_precision_setter("cuda", "all"),
    )


# This is the sys.modules replacement trick, see
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = CudnnModule(sys.modules[__name__], __name__)
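
# After the swap above, attribute reads and writes on this module are routed
# through CudnnModule's ContextProp descriptors, so (a sketch):
#
#     torch.backends.cudnn.benchmark = True  # calls torch._C._set_cudnn_benchmark(True)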

# Add type annotations for the replaced module
enabled: bool
deterministic: bool
benchmark: bool
allow_tf32: bool
benchmark_limit: int