Migrate from lru_cache to cache (#155613)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155613
Approved by: https://github.com/ezyang
ghstack dependencies: #155612
Authored by Oguz Ulgen on 2025-06-11 08:59:26 -07:00; committed by PyTorch MergeBot
parent f80a61adf5
commit d1947a8707
70 changed files with 157 additions and 157 deletions
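
The change is mechanical: functools.cache, available since Python 3.9, is defined as lru_cache(maxsize=None), so every converted site keeps an unbounded memoization cache with identical behavior; only the spelling changes. A minimal before/after sketch (the function name is illustrative, not taken from this PR):

    import functools

    # Before: unbounded cache spelled via lru_cache.
    @functools.lru_cache(None)
    def parse_version_old(raw: str) -> tuple[int, ...]:
        return tuple(int(part) for part in raw.split("."))

    # After: functools.cache is a thin alias for lru_cache(maxsize=None),
    # so the cached behavior is unchanged.
    @functools.cache
    def parse_version(raw: str) -> tuple[int, ...]:
        return tuple(int(part) for part in raw.split("."))

    assert parse_version("2.8.0") == (2, 8, 0)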

@@ -560,7 +560,7 @@ def nothing(f):
     return f
-@functools.lru_cache(None)
+@functools.cache
 def patch_torch_manual_seed():
     """Make torch manual seed deterministic. Helps with accuracy testing."""

@@ -135,7 +135,7 @@ def contains_tensor_types(type):
 )
-@functools.lru_cache(None)
+@functools.cache
 def non_compute_operator(op):
     schema = op._schema

@@ -550,7 +550,7 @@ def build_summary(args):
     gh_fh.write(comment)
-@functools.lru_cache(None)
+@functools.cache
 def archive_data(archive_name):
     if archive_name is not None:
         prefix_match = re.search(r"\w+(?=_performance)", archive_name)
@@ -570,7 +570,7 @@ def archive_data(archive_name):
     return day, prefix
-@functools.lru_cache(None)
+@functools.cache
 def default_archive_name(dtype):
     _, prefix = archive_data(None)
     return f"{prefix}_performance_{dtype}_{randint(100, 999)}"

@@ -154,7 +154,7 @@ def list_backends(exclude_tags=("debug", "experimental")) -> list[str]:
     return sorted(backends)
-@functools.lru_cache(None)
+@functools.cache
 def _lazy_import():
     from .. import backends
     from ..utils import import_submodule
@@ -168,7 +168,7 @@ def _lazy_import():
     _discover_entrypoint_backends()
-@functools.lru_cache(None)
+@functools.cache
 def _discover_entrypoint_backends():
     # importing here so it will pick up the mocked version in test_backends.py
     from importlib.metadata import entry_points

@@ -201,7 +201,7 @@ def has_tvm():
     return False
-@functools.lru_cache(None)
+@functools.cache
 def llvm_target():
     if sys.platform == "linux":
         cpuinfo = open("/proc/cpuinfo").read()

@@ -247,7 +247,7 @@ class NNModuleToString:
         return model_str
-@functools.lru_cache(None)  # subprocess is expensive
+@functools.cache  # subprocess is expensive
 def _cuda_system_info_comment():
     if not torch.cuda.is_available():
         return "# torch.cuda.is_available()==False, no GPU info collected\n"

@@ -2013,7 +2013,7 @@ def _optimize_assert(
 class TorchPatcher:
     @staticmethod
-    @functools.lru_cache(None)
+    @functools.cache
     def patch():
         # A better way to disable the following would be decorate the source
         # functions with @torch._disable_dynamo. However, this causes issues

@@ -418,7 +418,7 @@ def from_numpy(a):
 # For user stack printing
-@functools.lru_cache(None)
+@functools.cache
 def uninteresting_files():
     import torch._dynamo.external_utils
     import torch._dynamo.polyfills
@@ -623,7 +623,7 @@ class GuardManagerType(enum.Enum):
     DICT_GUARD_MANAGER = 2
-@functools.lru_cache(None)
+@functools.cache
 def code_framelocals_names_reversed_cached(code: types.CodeType):
     return list(reversed(code_framelocals_names(code)))

@@ -33,7 +33,7 @@ def get_loggers() -> list[logging.Logger]:
 # get_step_logger should be lazily called (i.e. at runtime, not at module-load time)
 # so that step numbers are initialized properly. e.g.:
-# @functools.lru_cache(None)
+# @functools.cache
 # def _step_logger():
 #     return get_step_logger(logging.getLogger(...))

@@ -210,7 +210,7 @@ class VariableTrackerCache:
         self.cache.clear()
-@functools.lru_cache(None)
+@functools.cache
 def _step_logger():
     return torchdynamo_logging.get_step_logger(log)

@@ -331,7 +331,7 @@ class TensorifyState:
         return len(cls.force_specializations) == 0
-@functools.lru_cache(None)
+@functools.cache
 def _step_logger():
     return torchdynamo_logging.get_step_logger(log)

@@ -2896,7 +2896,7 @@ Generate the torch object - Dynamo tracing rule (the wrapping variable) map.
 """
-@functools.lru_cache(None)
+@functools.cache
 def get_torch_obj_rule_map() -> dict[Any, type["VariableTracker"]]:
     d: dict[Any, type[VariableTracker]] = {}
     for m in torch_name_rule_map:
@@ -2945,7 +2945,7 @@ Get all torch.Tensor methods which are allowed to be in graph functions.
 """
-@functools.lru_cache(None)
+@functools.cache
 def get_tensor_method():
     disallowed_tensor_methods = {"__new__", "_make_wrapper_subclass", "_make_subclass"}
     s = set()
@@ -3467,7 +3467,7 @@ assert sorted(set(MOD_SKIPLIST)) == MOD_SKIPLIST
 MOD_SKIPLIST = set(MOD_SKIPLIST)
-@functools.lru_cache(None)
+@functools.cache
 def get_legacy_mod_inlinelist():
     inlinelist = {
         _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/"))
@@ -3476,7 +3476,7 @@ def get_legacy_mod_inlinelist():
     return inlinelist
-@functools.lru_cache(None)
+@functools.cache
 def get_mod_inlinelist():
     inlinelist = {
         _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/"))
@@ -3485,7 +3485,7 @@ def get_mod_inlinelist():
     return inlinelist
-@functools.lru_cache(None)
+@functools.cache
 def get_mod_skiplist():
     skiplist = {
         _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/"))
@@ -3739,7 +3739,7 @@ def is_torch_inline_allowed(filename):
     return any(filename.startswith(d) for d in get_mod_inlinelist())
-@functools.lru_cache(None)
+@functools.cache
 def dynamo_dir():
     import torch._dynamo

@@ -2369,7 +2369,7 @@ def is_safe_constant(v):
     )
-@functools.lru_cache(None)
+@functools.cache
 def common_constants():
     return {
         # We zero-one specialize shapes, so specialize these constants
@@ -3111,7 +3111,7 @@ seen_code_map = ExactWeakKeyDictionary()
 # return same dir unless user changes config between calls
-@functools.lru_cache(None)
+@functools.cache
 def _get_debug_dir(root_dir):
     dir_name = (
         "run_"

@@ -475,7 +475,7 @@ class VariableBuilder:
         return cls._type_dispatch_impl(config.trace_numpy)
     @classmethod
-    @functools.lru_cache(None)
+    @functools.cache
     def _type_dispatch_impl(cls, trace_numpy):
         # NB: Careful not to close over self to avoid ref cycle from lru_cache
         entries = [
@@ -576,7 +576,7 @@ class VariableBuilder:
         return self.tx.output.side_effects.track_mutable(value, result)
     @classmethod
-    @functools.lru_cache(None)
+    @functools.cache
     def _id_dispatch(
         cls,
     ) -> dict[int, Callable[["VariableBuilder", Any], VariableTracker]]:
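
The _type_dispatch_impl and _id_dispatch hunks above keep the existing pattern of caching a classmethod rather than an instance method, and the in-source NB comment explains why: a cache on a bound instance method keys on self and keeps every instance alive for the life of the process, which is also what the # noqa: B019 suppressions elsewhere in this commit acknowledge. A small illustrative sketch, with hypothetical class names that are not part of the PR:

    import functools

    class Leaky:
        # Anti-pattern: the cache records each `self` as part of the key,
        # so instances are never garbage collected (flake8-bugbear B019).
        @functools.cache
        def lookup(self, key: str) -> str:
            return key.upper()

    class NonLeaky:
        # Caching a classmethod (or a module-level helper) avoids retaining
        # instances, mirroring the _type_dispatch / _type_dispatch_impl split.
        @classmethod
        @functools.cache
        def shared_table(cls) -> dict[str, str]:
            return {"default": "handler"}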

@@ -169,7 +169,7 @@ class BuiltinVariable(VariableTracker):
         return cls(value, source=source)
     @staticmethod
-    @functools.lru_cache(None)
+    @functools.cache
     def _constant_fold_functions():
         fns = {
             abs,
@@ -239,7 +239,7 @@ class BuiltinVariable(VariableTracker):
         return self.fn in self._constant_fold_functions()
     @staticmethod
-    @functools.lru_cache(None)
+    @functools.cache
     def _fx_graph_functions():
         fns = {
             operator.abs,
@@ -285,7 +285,7 @@ class BuiltinVariable(VariableTracker):
         return fns
     @staticmethod
-    @functools.lru_cache(None)
+    @functools.cache
     def _binops() -> dict[
         Callable[..., object], tuple[list[str], Callable[..., object]]
     ]:
@@ -324,7 +324,7 @@ class BuiltinVariable(VariableTracker):
         return fns
     @staticmethod
-    @functools.lru_cache(None)
+    @functools.cache
     def _binop_handlers():
         # Multiple dispatch mechanism defining custom binop behavior for certain type
         # combinations. Handlers are attempted in order, and will be used if the type checks

@@ -1805,7 +1805,7 @@ class PolyfilledFunctionVariable(VariableTracker):
     }
     @classmethod
-    @functools.lru_cache(None)
+    @functools.cache
     def _get_polyfill_handlers(cls) -> dict[Callable[..., Any], types.FunctionType]:
         return {}

@@ -900,7 +900,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
         self.nn_module_stack_source = source
     @staticmethod
-    @functools.lru_cache(None)
+    @functools.cache
     def _nn_module_method_ids():
         # Allow __setattr__ to fall through to base class handler
         supported = {torch.nn.Module.__setattr__, torch.nn.Module.__init__}

@@ -1040,7 +1040,7 @@ class TensorVariable(VariableTracker):
         return wrap_fx_proxy(tx, proxy)
     @staticmethod
-    @functools.lru_cache(None)
+    @functools.cache
     def _warn_capture_scalar_outputs():
         user_stack = torch._guards.TracingContext.extract_stack()
         user_stack_formatted = "".join(traceback.format_list(user_stack))

@@ -170,7 +170,7 @@ constant_fold_functions_need_guards = dict.fromkeys(constant_fold_functions_need
 constant_fold_functions = dict.fromkeys(constant_fold_functions)
-@functools.lru_cache(None)
+@functools.cache
 def tracing_state_functions() -> dict[Callable[[], Any], Optional[bool]]:
     # Defined as a function to avoid circular import like torch.onnx
     return {
@@ -197,7 +197,7 @@ dispatch_key_set_functions = {
 }
-@functools.lru_cache(None)
+@functools.cache
 def get_overridable_functions():
     from itertools import chain
@@ -432,7 +432,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
         return self.value
     @staticmethod
-    @functools.lru_cache(None)
+    @functools.cache
     def _get_handlers():
         """Build a dict from function -> method to handle it so that we are O(1)
         in terms of the number of function with special handling."""

@@ -199,7 +199,7 @@ banned_attrs = [
 ]
-@functools.lru_cache(None)
+@functools.cache
 def get_prev_stack_var_name():
     from ..bytecode_transformation import unique_id

@@ -149,7 +149,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
         return f"{self.__class__.__name__}({self.value})"
     @staticmethod
-    @functools.lru_cache(None)
+    @functools.cache
     def _constant_fold_classes():
         return {
             torch.device,
@@ -159,7 +159,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
         }
     @staticmethod
-    @functools.lru_cache(None)
+    @functools.cache
     def _in_graph_classes():
         _in_graph_class_list = {
             torch.Tensor,
@@ -177,7 +177,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
         return set(tensortype_to_dtype.keys()) | _in_graph_class_list
     @staticmethod
-    @functools.lru_cache(None)
+    @functools.cache
     def supported_c_new_functions():
         exceptions = [
             getattr(builtins, name).__new__
@@ -843,7 +843,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
         )
     @staticmethod
-    @functools.lru_cache(None)
+    @functools.cache
     def _supported_random_functions():
         fns = {
             random.random,

@@ -1229,7 +1229,7 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
         remote_cache.put(key, cache_data)
     @staticmethod
-    @functools.lru_cache(None)
+    @functools.cache
     def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]:
         """
         Attempts to load the remote cache, returns None on error.

@@ -1052,7 +1052,7 @@ def _count_ops(graph: fx.Graph):
     log.info("%s", sorted(cnt.items(), key=operator.itemgetter(1), reverse=True))
-@functools.lru_cache(None)
+@functools.cache
 def pointwise_ops():
     ops = []
     for attr_name in dir(torch.ops.aten):

@@ -874,7 +874,7 @@ class CppBenchmarkRequest(CPUDeviceBenchmarkMixin, BenchmarkRequest):
         return f"{self.kernel_name=}"
-@functools.lru_cache(None)
+@functools.cache
 def get_tuning_process_pool() -> TuningProcessPool:
     pool = TuningProcessPool()
     atexit.register(pool.shutdown)

@@ -176,7 +176,7 @@ def get_kernel_bin_format(device: str) -> str:
     return ""
-@functools.lru_cache(None)
+@functools.cache
 def get_global_cache_path_impl(global_cache_dir: str) -> Optional[Path]:
     return (
         Path(os.path.join(global_cache_dir, CacheBase.get_system()["hash"]))
@@ -187,7 +187,7 @@ def get_global_cache_path_impl(global_cache_dir: str) -> Optional[Path]:
 class CacheBase:
     @staticmethod
-    @functools.lru_cache(None)
+    @functools.cache
     def get_system() -> dict[str, Any]:
         try:
             from triton.compiler.compiler import triton_key
@@ -226,7 +226,7 @@ class CacheBase:
     @staticmethod
     @clear_on_fresh_inductor_cache
-    @functools.lru_cache(None)
+    @functools.cache
     def get_local_cache_path() -> Path:
         return Path(os.path.join(cache_dir(), "cache", CacheBase.get_system()["hash"]))
@@ -280,7 +280,7 @@ class LocalCache(CacheBase):
 class PersistentCache(CacheBase):
-    @functools.lru_cache(None)  # noqa: B019
+    @functools.cache  # noqa: B019
     def get_global_cache(self) -> dict[str, Any]:
         global_cache_path = self.get_global_cache_path()
         if global_cache_path is None or not global_cache_path.is_file():
@@ -1567,7 +1567,7 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
         pass
-@functools.lru_cache(None)
+@functools.cache
 def split_aot_inductor_output_path(path: str) -> tuple[str, str]:
     """Returns the path where the AOT Inductor compiled kernels are stored."""
     if path.endswith(".so"):
@@ -2283,7 +2283,7 @@ _HEADER_DIR = os.path.join(default_cache_dir(), "precompiled_headers")
 _HEADER_LOCK_DIR = os.path.join(_HEADER_DIR, "locks")
-@functools.lru_cache(None)
+@functools.cache
 def _precompile_header(
     header: str,
     hashable_cmd_line: str,
@@ -2969,7 +2969,7 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
         return glue_code
     @classmethod
-    @functools.lru_cache(None)
+    @functools.cache
     def config_hash(cls) -> str:
         command_gen = CppBuilder(
             name="O",
@@ -3013,7 +3013,7 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
         raise RuntimeError(errmsg)
     @staticmethod
-    @functools.lru_cache(None)
+    @functools.cache
     def find_libautoschedule(name: str) -> str:
         sofile = f"libautoschedule_{name.lower()}.so"
         if "HALIDE_LIB" in os.environ:
@@ -3026,7 +3026,7 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
         return HalideCodeCache._search_for_file(sofile, errmsg)
     @staticmethod
-    @functools.lru_cache(None)
+    @functools.cache
     def find_header(name: str) -> str:
         if "HALIDE_INCLUDE" in os.environ:
             path = os.path.join(os.environ["HALIDE_INCLUDE"], name)
@@ -3300,7 +3300,7 @@ class PyCodeCache:
         cls.modules_no_attr.clear()
     @classmethod
-    @functools.lru_cache(None)
+    @functools.cache
     def stack_frames_for_code(
         cls, path: str, lineno: int
     ) -> Optional[list[dict[str, Any]]]:

@@ -441,7 +441,7 @@ def get_wrapper_codegen_for_device(
         return None
-@functools.lru_cache(None)
+@functools.cache
 def init_backend_registration() -> None:
     from .cpp import CppScheduling
     from .cpp_wrapper_cpu import CppWrapperCpu
@@ -2223,7 +2223,7 @@ class OptimizationContext:
     ops_name: str = ""
-@functools.lru_cache(None)
+@functools.cache
 def jinja2_env() -> Any:
     try:
         import jinja2

@@ -86,7 +86,7 @@ from .cpp_utils import (
 _IS_WINDOWS = sys.platform == "win32"
-@functools.lru_cache(None)
+@functools.cache
 def get_export_declaration():
     return "__declspec(dllexport)" if _IS_WINDOWS else ""

@@ -274,12 +274,12 @@ class CppWrapperCpu(PythonWrapperCodegen):
     ):
         code = self.prefix
-        @functools.lru_cache(None)
+        @functools.cache
         def sizeof(name):
             self.codegen_input_size_var_decl(code, name)
             return f"{name}_size"
-        @functools.lru_cache(None)
+        @functools.cache
         def strideof(name):
             self.codegen_input_stride_var_decl(code, name)
             return f"{name}_stride"
@@ -1469,7 +1469,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
         self.used_cached_memory_formats.add(memory_format_str)
         return f"cached_torch_memory_format_{memory_format_str}"
-    @functools.lru_cache(None)  # noqa: B019
+    @functools.cache  # noqa: B019
     def codegen_int_array_var(
         self,
         int_array: str,

@@ -40,6 +40,6 @@ def get_cuda_version() -> Optional[str]:
         return None
-@functools.lru_cache(None)
+@functools.cache
 def nvcc_exist(nvcc_path: Optional[str] = "nvcc") -> bool:
     return nvcc_path is not None and shutil.which(nvcc_path) is not None

@@ -48,7 +48,7 @@ def _generate_config_filename(request_key: str) -> str:
 @clear_on_fresh_inductor_cache
-@functools.lru_cache(None)
+@functools.cache
 def maybe_fetch_ops() -> Optional[list[Any]]:
     """
     Fetch ops from databases.

@@ -5,7 +5,7 @@ import torch
 from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch
-@functools.lru_cache(None)
+@functools.cache
 def gen_cutlass_presets() -> dict[int, dict[str, list[str]]]:
     """
     Generate cutlass presets for the given CUDA arch.

@@ -56,7 +56,7 @@ def _rename_cutlass_import(content: str, cutlass_modules: list[str]) -> str:
     return content
-@functools.lru_cache(None)
+@functools.cache
 def try_import_cutlass() -> bool:
     """
     We want to support three ways of passing in CUTLASS:
@@ -251,7 +251,7 @@ class CUTLASSArgs:
 @clear_on_fresh_inductor_cache
-@functools.lru_cache(None)
+@functools.cache
 def _gen_ops_cached(arch, version) -> dict[Any, Any]:
     # Note: Cache needs to be specific for cuda architecture and version

@@ -95,7 +95,7 @@ class CKTileGemmOperation:
         return asdict(self).items()
-@functools.lru_cache(None)
+@functools.cache
 def ops():
     """
     Generate the supported instance dataclasses

@@ -416,7 +416,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
         self.code_hash: Optional[str] = None
         # define this in a closure to make cache local to object
-        @functools.lru_cache(None)
+        @functools.cache
         def simplify_indexing(index: sympy.Expr):
             index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges())
             for tree in self.range_trees:

@@ -1311,7 +1311,7 @@ class TritonKernelOverrides(TritonOverrides):
         self._setup_libdevice_routing()
     @classmethod
-    @functools.lru_cache(None)
+    @functools.cache
     def _setup_libdevice_routing(cls):
         """Set up routing to libdevice implementations for fp64 inputs."""

@@ -915,7 +915,7 @@ class PythonWrapperCodegen(CodeGen):
             self.write_get_raw_stream
         )
-        @functools.lru_cache(None)
+        @functools.cache
         def add_import_once(line: str) -> None:
             self.imports.writeline(line)
             if config.triton.autotune_at_compile_time:
@@ -1625,12 +1625,12 @@ class PythonWrapperCodegen(CodeGen):
     ):
         code = self.prefix
-        @functools.lru_cache(None)
+        @functools.cache
         def sizeof(name):
             code.writeline(f"{name}_size = {name}.size()")
             return f"{name}_size"
-        @functools.lru_cache(None)
+        @functools.cache
         def strideof(name):
             code.writeline(f"{name}_stride = {name}.stride()")
             return f"{name}_stride"

@@ -238,12 +238,12 @@ def record_original_output_strides(gm: GraphModule) -> None:
     output_node.meta["original_output_strides"] = output_strides
-@functools.lru_cache(None)
+@functools.cache
 def _step_logger() -> Callable[..., None]:
     return dynamo_logging.get_step_logger(log)
-@functools.lru_cache(None)
+@functools.cache
 def _warn_tf32_disabled() -> None:
     if (
         torch.cuda.is_available()

@@ -614,7 +614,7 @@ class _OutOfProcessFxCompile(_SerializedFxCompile):
         # And forward our collected logs. The cache is cleared when the outer
         # function exits.
-        @functools.lru_cache(None)
+        @functools.cache
         def getLogger(name: str) -> logging.Logger:
             return logging.getLogger(name)

@@ -79,7 +79,7 @@ def reset_counters() -> None:
     call_counter_debug_info.clear()
-@functools.lru_cache(None)
+@functools.cache
 def get_env_val(env_str: str) -> Optional[str]:
     return os.environ.get(env_str, None)

@@ -127,7 +127,7 @@ def install_gcc_via_conda() -> str:
     return cxx_path
-@functools.lru_cache(None)
+@functools.cache
 def check_compiler_exist_windows(compiler: str) -> None:
     """
     Check if compiler is ready, in case end user not activate MSVC environment.
@@ -200,13 +200,13 @@ def convert_cubin_to_obj(
     return obj_file
-@functools.lru_cache(None)
+@functools.cache
 def _is_apple_clang(cpp_compiler: str) -> bool:
     version_string = subprocess.check_output([cpp_compiler, "--version"]).decode("utf8")
     return "Apple" in version_string.splitlines()[0]
-@functools.lru_cache(None)
+@functools.cache
 def _is_clang(cpp_compiler: str) -> bool:
     # Mac OS apple clang maybe named as gcc, need check compiler info.
     if sys.platform == "darwin":
@@ -221,7 +221,7 @@ def _is_clang(cpp_compiler: str) -> bool:
     return bool(re.search(r"(clang|clang\+\+)", cpp_compiler))
-@functools.lru_cache(None)
+@functools.cache
 def _is_gcc(cpp_compiler: str) -> bool:
     # Since "clang++" ends with "g++", the regex match below would validate on it.
     if _is_clang(cpp_compiler):
@@ -229,7 +229,7 @@ def _is_gcc(cpp_compiler: str) -> bool:
     return bool(re.search(r"(gcc|g\+\+|gnu-c\+\+)", cpp_compiler))
-@functools.lru_cache(None)
+@functools.cache
 def _is_msvc_cl(cpp_compiler: str) -> bool:
     if not _IS_WINDOWS:
         return False
@@ -247,7 +247,7 @@ def _is_msvc_cl(cpp_compiler: str) -> bool:
     return False
-@functools.lru_cache(None)
+@functools.cache
 def _is_intel_compiler(cpp_compiler: str) -> bool:
     def _check_minimal_version(compiler_version: TorchVersion) -> None:
         """
@@ -291,32 +291,32 @@ def _is_intel_compiler(cpp_compiler: str) -> bool:
     return False
-@functools.lru_cache(None)
+@functools.cache
 def is_gcc() -> bool:
     return _is_gcc(get_cpp_compiler())
-@functools.lru_cache(None)
+@functools.cache
 def is_clang() -> bool:
     return _is_clang(get_cpp_compiler())
-@functools.lru_cache(None)
+@functools.cache
 def is_intel_compiler() -> bool:
     return _is_intel_compiler(get_cpp_compiler())
-@functools.lru_cache(None)
+@functools.cache
 def is_apple_clang() -> bool:
     return _is_apple_clang(get_cpp_compiler())
-@functools.lru_cache(None)
+@functools.cache
 def is_msvc_cl() -> bool:
     return _is_msvc_cl(get_cpp_compiler())
-@functools.lru_cache(None)
+@functools.cache
 def get_compiler_version_info(compiler: str) -> str:
     env = os.environ.copy()
     env["LC_ALL"] = "C"  # Don't localize output
@@ -885,7 +885,7 @@ def _get_python_related_args() -> tuple[list[str], list[str]]:
     return python_include_dirs, python_lib_path
-@functools.lru_cache(None)
+@functools.cache
 def is_conda_llvm_openmp_installed() -> bool:
     try:
         command = "conda list llvm-openmp --json"
@@ -895,7 +895,7 @@ def is_conda_llvm_openmp_installed() -> bool:
         return False
-@functools.lru_cache(None)
+@functools.cache
 def homebrew_libomp() -> tuple[bool, str]:
     try:
         # check if `brew` is installed
@@ -916,7 +916,7 @@ def homebrew_libomp() -> tuple[bool, str]:
         return False, ""
-@functools.lru_cache(None)
+@functools.cache
 def perload_clang_libomp_win(cpp_compiler: str, omp_name: str) -> None:
     try:
         output = subprocess.check_output([cpp_compiler, "-print-file-name=bin"]).decode(
@@ -930,7 +930,7 @@ def perload_clang_libomp_win(cpp_compiler: str, omp_name: str) -> None:
         pass
-@functools.lru_cache(None)
+@functools.cache
 def perload_icx_libomp_win(cpp_compiler: str) -> None:
     def _load_icx_built_in_lib_by_name(cpp_compiler: str, lib_name: str) -> bool:
         try:

@@ -146,7 +146,7 @@ cdll.LoadLibrary("__lib_path__")
     def __bool__(self) -> bool:
         return self.__bool__impl(config.cpp.vec_isa_ok)
-    @functools.lru_cache(None)  # noqa: B019
+    @functools.cache  # noqa: B019
     def __bool__impl(self, vec_isa_ok) -> bool:
         if vec_isa_ok is not None:
             return vec_isa_ok
@@ -241,7 +241,7 @@ extern "C" void __amx_chk_kernel() {
 }
 """
-    @functools.lru_cache(None)  # noqa: B019
+    @functools.cache  # noqa: B019
     def __bool__(self) -> bool:
         if super().__bool__():
             if config.is_fbcode():
@@ -380,7 +380,7 @@ def get_isa_from_cpu_capability(
 # Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content
 # might have too much redundant content that is useless for ISA check. Hence,
 # we only cache some key isa information.
-@functools.lru_cache(None)
+@functools.cache
 def valid_vec_isa_list() -> list[VecISA]:
     isa_list: list[VecISA] = []
     if sys.platform == "darwin" and platform.processor() == "arm":

@@ -50,7 +50,7 @@ BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"])
 GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"]
-@functools.lru_cache(None)
+@functools.cache
 def has_dot() -> bool:
     return shutil.which("dot") is not None

@@ -856,7 +856,7 @@ def miopen_batch_norm(
     )
-@functools.lru_cache(None)
+@functools.cache
 def fast_random_decomps() -> dict[Any, Callable[..., Any]]:
     return {**decompositions, **extra_random_decomps}

@@ -29,7 +29,7 @@ DTypeArg = Union[DTypeVar, torch.types.Number, str, OpsValue]
 # So first decompose CSEVars -> tuple before calling this
-@functools.lru_cache(None)
+@functools.cache
 def get_promoted_dtype(
     *args: Sequence[tuple[torch.dtype, bool]],
     type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND] = None,

@@ -83,7 +83,7 @@ def recover_original_precision_folded_computation_ops(gm):
 _binary_ops = [aten.add.Tensor, aten.sub.Tensor, aten.mul.Tensor, aten.div.Tensor]
-@functools.lru_cache(None)
+@functools.cache
 def binary_folding_init():
     _conv_args = [Arg() for _ in range(9)]
     _addmm_args = [Arg() for _ in range(3)]

@@ -119,7 +119,7 @@ def register_binary_folding_pattern(pattern, extra_check=_return_true):
     )
-@functools.lru_cache(None)
+@functools.cache
 def addmm_patterns_init():
     device = next(
         (gpu for gpu in GPU_TYPES if getattr(torch, gpu).is_available()), "cpu"

@@ -956,7 +956,7 @@ def _get_sfdp_patterns():
         )
-@functools.lru_cache(None)
+@functools.cache
 def _sfdp_init():
     for key, register_replacement_kwargs in _get_sfdp_patterns():
         gen_register_replacement(key, **register_replacement_kwargs)

@@ -12,7 +12,7 @@ from ..pattern_matcher import fwd_only, register_replacement
 aten = torch.ops.aten
-@functools.lru_cache(None)
+@functools.cache
 def _misc_patterns_init():
     from .joint_graph import patterns as joint_graph_patterns
     from .post_grad import pass_patterns as post_grad_patterns_all

@@ -1398,7 +1398,7 @@ if torch._C._has_mkldnn:
                 user_node.replace_all_uses_with(node)
                 gm.graph.erase_node(user_node)
-    @functools.lru_cache(None)
+    @functools.cache
     def _mkldnn_fusion_init():
         # TODO: aarch64: enable op fusion for acl once it supports fused operators. Disabling it for now.
         # Otherwise even the matmul or innerproduct can not be accelerated with acl
@@ -1414,7 +1414,7 @@ if torch._C._has_mkldnn:
             _register_quantization_lowerings()
            _register_woq_lowerings()
-    @functools.lru_cache(None)
+    @functools.cache
     def _mkldnn_weight_pack_init():
         if torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available():
             _register_weight_pack_pass()

@@ -247,7 +247,7 @@ def is_mm_compute_bound(M: int, K: int, N: int, dtype: torch.dtype) -> bool:
     return arithmetic_intensity > machine_balance
-@functools.lru_cache(None)
+@functools.cache
 def get_pad_cache() -> torch._inductor.codecache.LocalCache:
     return torch._inductor.codecache.LocalCache()
@@ -851,7 +851,7 @@ def bmm_replace(mat1: Tensor, mat2: Tensor) -> Tensor:
     )
-@functools.lru_cache(None)
+@functools.cache
 def _pad_mm_init() -> None:
     from .joint_graph import patterns

@@ -3478,7 +3478,7 @@ def _register_qlinear_binary_fusion():
         )
-@functools.lru_cache(None)
+@functools.cache
 def _register_quantization_weight_pack_pass():
     # Step 1: Dequant promotion for int8-mixed-fp32/bf16
     _register_dequant_promotion()

@@ -511,7 +511,7 @@ scaled_mm_device_tma_template = TritonTemplate(
 # prevent duplication registration of extern functions
-@functools.lru_cache(None)
+@functools.cache
 def lazy_register_extern_choice(fn):
     return ExternKernelChoice(fn)
@@ -1175,7 +1175,7 @@ def tuned_scaled_mm(
     return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout)
-@functools.lru_cache(None)
+@functools.cache
 def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool:
     props = torch.cuda.get_device_properties(index or 0)
     return props.major <= 7

@@ -36,7 +36,7 @@ T = TypeVar("T")
 class InterpreterShim(torch.fx.Interpreter):
     @staticmethod
-    @functools.lru_cache(None)
+    @functools.cache
     def _dummy_gm():
         return torch.fx.symbolic_trace(identity)

@@ -1879,7 +1879,7 @@ def fallback_handler(kernel, add_to_fallback_set=True):
     return handler
-@functools.lru_cache(None)
+@functools.cache
 def _warn_complex_not_supported():
     warnings.warn(
         "Torchinductor does not support code generation for complex operators. Performance may be worse than eager."

@@ -1010,7 +1010,7 @@ class PatternPrettyPrinter:
         self.memoized_objs_pp: dict[PatternExpr, str] = {}
     @staticmethod
-    @functools.lru_cache(None)
+    @functools.cache
     def run(obj: PatternExpr, output_name: str = "output") -> str:
         """
         Serializes obj to python code with obj written out to `output_name`
@@ -2195,7 +2195,7 @@ def stable_topological_sort(graph: torch.fx.Graph) -> None:
 def init_once_fakemode(fn: Callable[..., Any]) -> Callable[[], Any]:
     """Wrapper around lazy init functions in fx_passes/"""
-    @functools.lru_cache(None)
+    @functools.cache
     @functools.wraps(fn)
     def lazy_init() -> Any:
         counters_ref = counters["inductor"].copy()

@@ -34,7 +34,7 @@ def _reload_python_module(
     return mod
-@functools.lru_cache(None)
+@functools.cache
 def _set_triton_ptxas_path() -> None:
     if os.environ.get("TRITON_PTXAS_PATH") is not None:
         return

@@ -136,7 +136,7 @@ class DeviceProperties(typing.NamedTuple):
     warp_size: Optional[int] = None
     @classmethod
-    @functools.lru_cache(None)
+    @functools.cache
     def create(cls, device) -> DeviceProperties:
         import torch
         from torch._dynamo.device_interface import get_interface_for_device

@@ -1124,7 +1124,7 @@ class TritonTemplateKernel(TritonKernel):
         ]
-@functools.lru_cache(None)
+@functools.cache
 def _jinja2_env():
     try:
         import jinja2
@@ -1726,7 +1726,7 @@ class ExternKernelChoice:
     def call_name(self):
         return f"extern_kernels.{self.name}"
-    @functools.lru_cache(None)  # noqa: B019
+    @functools.cache  # noqa: B019
     def hash_key(self):
         fn = self.to_callable()
         parts = [
@@ -1933,7 +1933,7 @@ class ExternKernelCaller(ChoiceCaller):
         return f"extern_{self.choice.name}"
-@functools.lru_cache(None)
+@functools.cache
 def get_mm_log_filename() -> Optional[str]:
     mm_file_name = os.environ.get("TORCHINDUCTOR_MM_LOGGING_FILE", None)
     if not mm_file_name:
@@ -2052,7 +2052,7 @@ class NoValidChoicesError(RuntimeError):
     pass
-@functools.lru_cache(None)
+@functools.cache
 def get_num_workers() -> int:
     if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ:
         return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"])
@@ -2194,7 +2194,7 @@ class AlgorithmSelectorCache(PersistentCache):
             # CUDATemplateCaller still needs to go through autotuning process to retrieve workspace size.
             return choices[0].output_node()
-        @functools.lru_cache(None)
+        @functools.cache
         def make_benchmark_fn():
             return self.make_benchmark_fn(choices, input_nodes, layout, input_gen_fns)
@@ -2506,7 +2506,7 @@ class AlgorithmSelectorCache(PersistentCache):
             future.add_done_callback(on_complete)
             futures[future] = c
-        @functools.lru_cache(None)
+        @functools.cache
         @restore_stdout_stderr()
         def wait_on_futures():
             log.debug("Waiting on futures")

@@ -91,7 +91,7 @@ T = TypeVar("T")
 # defines here before import torch._dynamo is for avoiding circular import
 # when get_gpu_type is imported from dynamo
-@functools.lru_cache(None)
+@functools.cache
 def get_gpu_type() -> str:
     avail_gpus = [x for x in GPU_TYPES if getattr(torch, x).is_available()]
     assert len(avail_gpus) <= 1
@@ -338,7 +338,7 @@ def do_bench_using_profiling(
     return res
-@functools.lru_cache(None)
+@functools.cache
 def has_torchvision_roi_align() -> bool:
     try:
         from torchvision.ops import roi_align  # noqa: F401
@@ -1384,7 +1384,7 @@ class DelayReplaceLine(DeferredLineBase):
         return DelayReplaceLine(self.key, self.value_fn, line)
-@functools.lru_cache(None)
+@functools.cache
 def is_big_gpu(index_or_device: Union[int, torch.device] = 0) -> bool:
     if isinstance(index_or_device, torch.device):
         device = index_or_device
@@ -1599,7 +1599,7 @@ def use_decompose_k_choice(m: _IntLike, n: _IntLike, k: _IntLike) -> bool:
     )
-@functools.lru_cache(None)
+@functools.cache
 def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]:
     # If k is a sympy expression, we can't do any splitting
     if isinstance(k, sympy.Expr) and not k.is_number:
@@ -1652,12 +1652,12 @@ def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]:
     return best_splits[:k_splits_limit]
-@functools.lru_cache(None)
+@functools.cache
 def _rocm_native_device_arch_name(device: str) -> str:
     return torch.cuda.get_device_properties(device).gcnArchName
-@functools.lru_cache(None)
+@functools.cache
 def try_import_ck_lib() -> tuple[
     Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], type[Any]
 ]:
@@ -2106,7 +2106,7 @@ def parallel_num_threads() -> int:
     return threads
-@functools.lru_cache(None)
+@functools.cache
 def get_backend_num_stages() -> int:
     from .runtime.triton_helpers import get_backend_options
@@ -2114,7 +2114,7 @@ def get_backend_num_stages() -> int:
     return options.get("num_stages", 2 if torch.version.hip else 3)
-@functools.lru_cache(None)
+@functools.cache
 def get_device_tflops(dtype: torch.dtype) -> int:
     from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops
@@ -2142,7 +2142,7 @@ def get_device_tflops(dtype: torch.dtype) -> int:
         return get_max_simd_tflops(torch.float32)
-@functools.lru_cache(None)
+@functools.cache
 def get_gpu_dram_gbps() -> int:
     from triton.testing import get_dram_gbps
@@ -2862,7 +2862,7 @@ def is_same_mkldnn_tensor(data: torch.Tensor, value: torch.Tensor) -> bool:
     )
-@functools.lru_cache(None)
+@functools.cache
 def boolean_ops() -> tuple[str, ...]:
     return (
         "isinf",
@@ -3051,7 +3051,7 @@ class TritonAttrsDescriptorVersion(enum.Enum):
     V4_DICT = 4  # a raw dict
-@functools.lru_cache(None)
+@functools.cache
 def get_triton_attrs_descriptor_version() -> TritonAttrsDescriptorVersion:
     if importlib.util.find_spec("triton") is None:
         return TritonAttrsDescriptorVersion.V0_NO_TRITON

@@ -1143,7 +1143,7 @@ class LazyTraceHandler(logging.StreamHandler):
         super().emit(record)
-@functools.lru_cache(None)
+@functools.cache
 def warning_once(logger_obj, *args, **kwargs) -> None:
     """
     This function is similar to `logger.warning()`, but will emit the warning with the same message only once

@@ -15,7 +15,7 @@ import torch.overrides
 from torch._prims_common import torch_function_passthrough
-@functools.lru_cache(None)
+@functools.cache
 def torch_to_refs_map():
     """
     Mapping of torch API functions to torch._refs functions.
@@ -70,7 +70,7 @@ def torch_to_refs_map():
     return r
-@functools.lru_cache(None)
+@functools.cache
 def all_prims():
     """
     Set of all prim functions, e.g., torch._prims.add in all_prims()

@@ -114,7 +114,7 @@ def contains_tensor_types(type):
     )
-@functools.lru_cache(None)
+@functools.cache
 def _is_tensor_constructor(func: OpOverload):
     assert isinstance(func, OpOverload)
     schema = func._schema
@@ -1077,7 +1077,7 @@ def fast_detach(fake_mode, x):
     return FakeTensor(fake_mode, out, x.device, real_tensor=x.real_tensor)
-@functools.lru_cache(None)
+@functools.cache
 def get_fast_op_impls():
     import torch._refs

@@ -233,7 +233,7 @@ def maybe_get_fake_mode(t: object) -> Optional[FakeTensorMode]:
     return None
-@functools.lru_cache(None)
+@functools.cache
 def get_schema_info(func: OpOverload) -> torch._C._SchemaInfo:
     return torch._C._SchemaInfo(func._schema)
@@ -243,7 +243,7 @@ def get_schema_info(func: OpOverload) -> torch._C._SchemaInfo:
 # torch/_decomp/decompositions.py.
 # decomps are used for aot autograd tracing so we would like to unify on their
 # implementation and add additional testing to them
-@functools.lru_cache(None)
+@functools.cache
 def torch_decomp_decompositions(func: OpOverload) -> bool:
     from torch._decomp import decomposition_table
@@ -511,7 +511,7 @@ class FakeTensorConverter:
         return out
-@functools.lru_cache(None)
+@functools.cache
 def init_gpu_context(device: torch.device) -> None:
     # Backward will error with cuda Fake Tensors if no cuda tensors have been initialized first
     if torch.cuda.is_available() or torch.xpu.is_available():

@@ -210,7 +210,7 @@ def is_fb_unit_test() -> bool:
     return False
-@functools.lru_cache(None)
+@functools.cache
 def max_clock_rate():
     if not torch.version.hip:
         from triton.testing import nvsmi

@@ -7379,7 +7379,7 @@ class ShapeEnv:
         # Don't track this one. (Because this cache is inside this function the
         # cache only lasts for the invocation of this function call)
-        @functools.lru_cache(None)
+        @functools.cache
         def compute_concrete_val() -> sympy.Basic:
             if hint is None:
                 # This is only ever called for expressions WITHOUT unbacked

@@ -11,7 +11,7 @@ if TYPE_CHECKING:
 # TODO: Remove after https://github.com/huggingface/safetensors/pull/318
-@functools.lru_cache(None)
+@functools.cache
 def has_safetensors_and_transformers():
     try:
         # safetensors is not an exporter requirement, but needed for some huggingface models

@@ -98,7 +98,7 @@ def _disable_user_warnings(
     return wrapper
-@functools.lru_cache(None)
+@functools.cache
 @_disable_user_warnings
 def get_ignored_functions() -> set[Callable]:
     """
@@ -378,7 +378,7 @@ def get_ignored_functions() -> set[Callable]:
     }
-@functools.lru_cache(None)
+@functools.cache
 def get_default_nowrap_functions() -> set[Callable]:
     """
     Return public functions that do not wrap in a subclass when invoked by
@@ -404,7 +404,7 @@ def get_default_nowrap_functions() -> set[Callable]:
     }
-@functools.lru_cache(None)
+@functools.cache
 @_disable_user_warnings
 def get_testing_overrides() -> dict[Callable, Callable]:
     """Return a dict containing dummy overrides for all overridable functions
@@ -1808,7 +1808,7 @@ has_torch_function_variadic = _add_docstr(
 )
-@functools.lru_cache(None)
+@functools.cache
 def _get_overridable_functions() -> tuple[
     dict[Any, list[Callable]], dict[Callable, str]
 ]:
@@ -1929,7 +1929,7 @@ def resolve_name(f):
     return _get_overridable_functions()[1].get(f)
-@functools.lru_cache(None)
+@functools.cache
 def _get_tensor_methods() -> set[Callable]:
     """Returns a set of the overridable methods on ``torch.Tensor``"""
     overridable_funcs = get_overridable_functions()

@@ -3,7 +3,7 @@ import functools
 from torch.utils._triton import has_triton
-@functools.lru_cache(None)
+@functools.cache
 def has_helion_package() -> bool:
     try:
         import helion  # type: ignore[import-untyped, import-not-found]  # noqa: F401
@@ -12,6 +12,6 @@ def has_helion_package() -> bool:
     return True
-@functools.lru_cache(None)
+@functools.cache
 def has_helion() -> bool:
     return has_helion_package() and has_triton()

@@ -51,7 +51,7 @@ log = logging.getLogger(__name__)
 # TODO: Dedupe this with SYMPY_INTERP
-@functools.lru_cache(None)
+@functools.cache
 def handlers():
     # TODO add CeilDiv (it doesn't appear in the index_expr)

@@ -3,7 +3,7 @@ import hashlib
 from typing import Any
-@functools.lru_cache(None)
+@functools.cache
 def has_triton_package() -> bool:
     try:
         from triton.compiler.compiler import triton_key
@@ -15,7 +15,7 @@ def has_triton_package() -> bool:
         return False
-@functools.lru_cache(None)
+@functools.cache
 def _device_supports_tma() -> bool:
     import torch
@@ -26,7 +26,7 @@ def _device_supports_tma() -> bool:
     )
-@functools.lru_cache(None)
+@functools.cache
 def has_triton_experimental_host_tma() -> bool:
     if has_triton_package():
         if _device_supports_tma():
@@ -43,7 +43,7 @@ def has_triton_experimental_host_tma() -> bool:
     return False
-@functools.lru_cache(None)
+@functools.cache
 def has_triton_tensor_descriptor_host_tma() -> bool:
     if has_triton_package():
         if _device_supports_tma():
@@ -59,12 +59,12 @@ def has_triton_tensor_descriptor_host_tma() -> bool:
     return False
-@functools.lru_cache(None)
+@functools.cache
 def has_triton_tma() -> bool:
     return has_triton_tensor_descriptor_host_tma() or has_triton_experimental_host_tma()
-@functools.lru_cache(None)
+@functools.cache
 def has_triton_tma_device() -> bool:
     if has_triton_package():
         import torch
@@ -87,7 +87,7 @@ def has_triton_tma_device() -> bool:
     return False
-@functools.lru_cache(None)
+@functools.cache
 def has_triton() -> bool:
     if not has_triton_package():
         return False
@@ -121,7 +121,7 @@ def has_triton() -> bool:
     return is_device_compatible_with_triton()
-@functools.lru_cache(None)
+@functools.cache
 def triton_backend() -> Any:
     from triton.compiler.compiler import make_backend
     from triton.runtime.driver import driver
@@ -130,7 +130,7 @@ def triton_backend() -> Any:
     return make_backend(target)
-@functools.lru_cache(None)
+@functools.cache
 def triton_hash_with_backend() -> str:
     from triton.compiler.compiler import triton_key
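
Because functools.cache returns the same wrapper object as lru_cache, code that introspects or resets these caches keeps working unchanged; one plausible consumer is the clear_on_fresh_inductor_cache decorator seen stacked on several of the converted functions. The wrapper still exposes cache_info(), cache_clear(), and __wrapped__. A quick standalone check, not taken from the PR:

    import functools

    @functools.cache
    def square(x: int) -> int:
        return x * x

    square(3)
    square(3)
    info = square.cache_info()   # CacheInfo(hits=1, misses=1, maxsize=None, currsize=1)
    assert info.hits == 1 and info.misses == 1
    square.cache_clear()         # resets the cache, exactly as with lru_cache
    assert square.cache_info().misses == 0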