From d1947a87074c5db2568038878b1948ea3a33cc23 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Wed, 11 Jun 2025 08:59:26 -0700 Subject: [PATCH] Migrate from lru_cache to cache (#155613) Pull Request resolved: https://github.com/pytorch/pytorch/pull/155613 Approved by: https://github.com/ezyang ghstack dependencies: #155612 --- benchmarks/dynamo/common.py | 2 +- .../microbenchmarks/operator_inp_utils.py | 2 +- benchmarks/dynamo/runner.py | 4 +-- torch/_dynamo/backends/registry.py | 4 +-- torch/_dynamo/backends/tvm.py | 2 +- torch/_dynamo/debug_utils.py | 2 +- torch/_dynamo/eval_frame.py | 2 +- torch/_dynamo/guards.py | 4 +-- torch/_dynamo/logging.py | 2 +- torch/_dynamo/output_graph.py | 2 +- torch/_dynamo/symbolic_convert.py | 2 +- torch/_dynamo/trace_rules.py | 12 +++---- torch/_dynamo/utils.py | 4 +-- torch/_dynamo/variables/builder.py | 4 +-- torch/_dynamo/variables/builtin.py | 8 ++--- torch/_dynamo/variables/functions.py | 2 +- torch/_dynamo/variables/nn_module.py | 2 +- torch/_dynamo/variables/tensor.py | 2 +- torch/_dynamo/variables/torch.py | 6 ++-- torch/_dynamo/variables/torch_function.py | 2 +- torch/_dynamo/variables/user_defined.py | 8 ++--- .../_aot_autograd/autograd_cache.py | 2 +- torch/_functorch/partitioners.py | 2 +- torch/_inductor/autotune_process.py | 2 +- torch/_inductor/codecache.py | 20 ++++++------ torch/_inductor/codegen/common.py | 4 +-- torch/_inductor/codegen/cpp.py | 2 +- torch/_inductor/codegen/cpp_wrapper_cpu.py | 6 ++-- torch/_inductor/codegen/cuda/cuda_env.py | 2 +- torch/_inductor/codegen/cuda/cutlass_cache.py | 2 +- .../_inductor/codegen/cuda/cutlass_presets.py | 2 +- torch/_inductor/codegen/cuda/cutlass_utils.py | 4 +-- .../rocm/ck_tile_universal_gemm_template.py | 2 +- torch/_inductor/codegen/simd.py | 2 +- torch/_inductor/codegen/triton.py | 2 +- torch/_inductor/codegen/wrapper.py | 6 ++-- torch/_inductor/compile_fx.py | 4 +-- torch/_inductor/compile_fx_ext.py | 2 +- torch/_inductor/compiler_bisector.py | 2 +- torch/_inductor/cpp_builder.py | 32 +++++++++---------- torch/_inductor/cpu_vec_isa.py | 6 ++-- torch/_inductor/debug.py | 2 +- torch/_inductor/decomposition.py | 2 +- torch/_inductor/dtype_propagation.py | 2 +- torch/_inductor/fx_passes/binary_folding.py | 2 +- .../_inductor/fx_passes/freezing_patterns.py | 2 +- torch/_inductor/fx_passes/fuse_attention.py | 2 +- torch/_inductor/fx_passes/misc_patterns.py | 2 +- torch/_inductor/fx_passes/mkldnn_fusion.py | 4 +-- torch/_inductor/fx_passes/pad_mm.py | 4 +-- torch/_inductor/fx_passes/quantization.py | 2 +- torch/_inductor/kernel/mm.py | 4 +-- torch/_inductor/loop_body.py | 2 +- torch/_inductor/lowering.py | 2 +- torch/_inductor/pattern_matcher.py | 4 +-- torch/_inductor/runtime/compile_tasks.py | 2 +- torch/_inductor/runtime/hints.py | 2 +- torch/_inductor/select_algorithm.py | 12 +++---- torch/_inductor/utils.py | 22 ++++++------- torch/_logging/_internal.py | 2 +- torch/_prims/context.py | 4 +-- torch/_subclasses/fake_impls.py | 4 +-- torch/_subclasses/fake_tensor.py | 6 ++-- torch/_utils_internal.py | 2 +- torch/fx/experimental/symbolic_shapes.py | 2 +- torch/onnx/_internal/fx/patcher.py | 2 +- torch/overrides.py | 10 +++--- torch/utils/_helion.py | 4 +-- torch/utils/_sympy/interp.py | 2 +- torch/utils/_triton.py | 18 +++++------ 70 files changed, 157 insertions(+), 157 deletions(-) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 8138817337d..4918f57f3af 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -560,7 +560,7 @@ def nothing(f): 
return f -@functools.lru_cache(None) +@functools.cache def patch_torch_manual_seed(): """Make torch manual seed deterministic. Helps with accuracy testing.""" diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py b/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py index 36a212625f1..cad258ff227 100644 --- a/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py +++ b/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py @@ -135,7 +135,7 @@ def contains_tensor_types(type): ) -@functools.lru_cache(None) +@functools.cache def non_compute_operator(op): schema = op._schema diff --git a/benchmarks/dynamo/runner.py b/benchmarks/dynamo/runner.py index b27425b1283..08668fd4502 100755 --- a/benchmarks/dynamo/runner.py +++ b/benchmarks/dynamo/runner.py @@ -550,7 +550,7 @@ def build_summary(args): gh_fh.write(comment) -@functools.lru_cache(None) +@functools.cache def archive_data(archive_name): if archive_name is not None: prefix_match = re.search(r"\w+(?=_performance)", archive_name) @@ -570,7 +570,7 @@ def archive_data(archive_name): return day, prefix -@functools.lru_cache(None) +@functools.cache def default_archive_name(dtype): _, prefix = archive_data(None) return f"{prefix}_performance_{dtype}_{randint(100, 999)}" diff --git a/torch/_dynamo/backends/registry.py b/torch/_dynamo/backends/registry.py index 01381aa66b8..79376b0e460 100644 --- a/torch/_dynamo/backends/registry.py +++ b/torch/_dynamo/backends/registry.py @@ -154,7 +154,7 @@ def list_backends(exclude_tags=("debug", "experimental")) -> list[str]: return sorted(backends) -@functools.lru_cache(None) +@functools.cache def _lazy_import(): from .. import backends from ..utils import import_submodule @@ -168,7 +168,7 @@ def _lazy_import(): _discover_entrypoint_backends() -@functools.lru_cache(None) +@functools.cache def _discover_entrypoint_backends(): # importing here so it will pick up the mocked version in test_backends.py from importlib.metadata import entry_points diff --git a/torch/_dynamo/backends/tvm.py b/torch/_dynamo/backends/tvm.py index 3a5b239183f..ab0097e314c 100644 --- a/torch/_dynamo/backends/tvm.py +++ b/torch/_dynamo/backends/tvm.py @@ -201,7 +201,7 @@ def has_tvm(): return False -@functools.lru_cache(None) +@functools.cache def llvm_target(): if sys.platform == "linux": cpuinfo = open("/proc/cpuinfo").read() diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index 2e793bb4c7d..a23b58cedf2 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -247,7 +247,7 @@ class NNModuleToString: return model_str -@functools.lru_cache(None) # subprocess is expensive +@functools.cache # subprocess is expensive def _cuda_system_info_comment(): if not torch.cuda.is_available(): return "# torch.cuda.is_available()==False, no GPU info collected\n" diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 88d40c6dafa..771d1071f76 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -2013,7 +2013,7 @@ def _optimize_assert( class TorchPatcher: @staticmethod - @functools.lru_cache(None) + @functools.cache def patch(): # A better way to disable the following would be decorate the source # functions with @torch._disable_dynamo. 
However, this causes issues diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index b2c4342bb11..5849dedf32e 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -418,7 +418,7 @@ def from_numpy(a): # For user stack printing -@functools.lru_cache(None) +@functools.cache def uninteresting_files(): import torch._dynamo.external_utils import torch._dynamo.polyfills @@ -623,7 +623,7 @@ class GuardManagerType(enum.Enum): DICT_GUARD_MANAGER = 2 -@functools.lru_cache(None) +@functools.cache def code_framelocals_names_reversed_cached(code: types.CodeType): return list(reversed(code_framelocals_names(code))) diff --git a/torch/_dynamo/logging.py b/torch/_dynamo/logging.py index 2d67665f5e9..18febf1377c 100644 --- a/torch/_dynamo/logging.py +++ b/torch/_dynamo/logging.py @@ -33,7 +33,7 @@ def get_loggers() -> list[logging.Logger]: # get_step_logger should be lazily called (i.e. at runtime, not at module-load time) # so that step numbers are initialized properly. e.g.: -# @functools.lru_cache(None) +# @functools.cache # def _step_logger(): # return get_step_logger(logging.getLogger(...)) diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 21d88b7f1b0..8e089945068 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -210,7 +210,7 @@ class VariableTrackerCache: self.cache.clear() -@functools.lru_cache(None) +@functools.cache def _step_logger(): return torchdynamo_logging.get_step_logger(log) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index a0b2efb0a0b..eae8ee46e09 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -331,7 +331,7 @@ class TensorifyState: return len(cls.force_specializations) == 0 -@functools.lru_cache(None) +@functools.cache def _step_logger(): return torchdynamo_logging.get_step_logger(log) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 627b8b5e24c..a3b8bb7b976 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2896,7 +2896,7 @@ Generate the torch object - Dynamo tracing rule (the wrapping variable) map. """ -@functools.lru_cache(None) +@functools.cache def get_torch_obj_rule_map() -> dict[Any, type["VariableTracker"]]: d: dict[Any, type[VariableTracker]] = {} for m in torch_name_rule_map: @@ -2945,7 +2945,7 @@ Get all torch.Tensor methods which are allowed to be in graph functions. 
""" -@functools.lru_cache(None) +@functools.cache def get_tensor_method(): disallowed_tensor_methods = {"__new__", "_make_wrapper_subclass", "_make_subclass"} s = set() @@ -3467,7 +3467,7 @@ assert sorted(set(MOD_SKIPLIST)) == MOD_SKIPLIST MOD_SKIPLIST = set(MOD_SKIPLIST) -@functools.lru_cache(None) +@functools.cache def get_legacy_mod_inlinelist(): inlinelist = { _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/")) @@ -3476,7 +3476,7 @@ def get_legacy_mod_inlinelist(): return inlinelist -@functools.lru_cache(None) +@functools.cache def get_mod_inlinelist(): inlinelist = { _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/")) @@ -3485,7 +3485,7 @@ def get_mod_inlinelist(): return inlinelist -@functools.lru_cache(None) +@functools.cache def get_mod_skiplist(): skiplist = { _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/")) @@ -3739,7 +3739,7 @@ def is_torch_inline_allowed(filename): return any(filename.startswith(d) for d in get_mod_inlinelist()) -@functools.lru_cache(None) +@functools.cache def dynamo_dir(): import torch._dynamo diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 1f35afeb90b..9300a604849 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2369,7 +2369,7 @@ def is_safe_constant(v): ) -@functools.lru_cache(None) +@functools.cache def common_constants(): return { # We zero-one specialize shapes, so specialize these constants @@ -3111,7 +3111,7 @@ seen_code_map = ExactWeakKeyDictionary() # return same dir unless user changes config between calls -@functools.lru_cache(None) +@functools.cache def _get_debug_dir(root_dir): dir_name = ( "run_" diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index afa1bc08307..a37988e4e6f 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -475,7 +475,7 @@ class VariableBuilder: return cls._type_dispatch_impl(config.trace_numpy) @classmethod - @functools.lru_cache(None) + @functools.cache def _type_dispatch_impl(cls, trace_numpy): # NB: Careful not to close over self to avoid ref cycle from lru_cache entries = [ @@ -576,7 +576,7 @@ class VariableBuilder: return self.tx.output.side_effects.track_mutable(value, result) @classmethod - @functools.lru_cache(None) + @functools.cache def _id_dispatch( cls, ) -> dict[int, Callable[["VariableBuilder", Any], VariableTracker]]: diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 374072f4082..522c3a38554 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -169,7 +169,7 @@ class BuiltinVariable(VariableTracker): return cls(value, source=source) @staticmethod - @functools.lru_cache(None) + @functools.cache def _constant_fold_functions(): fns = { abs, @@ -239,7 +239,7 @@ class BuiltinVariable(VariableTracker): return self.fn in self._constant_fold_functions() @staticmethod - @functools.lru_cache(None) + @functools.cache def _fx_graph_functions(): fns = { operator.abs, @@ -285,7 +285,7 @@ class BuiltinVariable(VariableTracker): return fns @staticmethod - @functools.lru_cache(None) + @functools.cache def _binops() -> dict[ Callable[..., object], tuple[list[str], Callable[..., object]] ]: @@ -324,7 +324,7 @@ class BuiltinVariable(VariableTracker): return fns @staticmethod - @functools.lru_cache(None) + @functools.cache def _binop_handlers(): # Multiple dispatch mechanism defining custom binop behavior for certain type # combinations. 
Handlers are attempted in order, and will be used if the type checks diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 6ac803997ca..d569aa8e800 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -1805,7 +1805,7 @@ class PolyfilledFunctionVariable(VariableTracker): } @classmethod - @functools.lru_cache(None) + @functools.cache def _get_polyfill_handlers(cls) -> dict[Callable[..., Any], types.FunctionType]: return {} diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index 6f3f2383cf7..4a63d95c978 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -900,7 +900,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable): self.nn_module_stack_source = source @staticmethod - @functools.lru_cache(None) + @functools.cache def _nn_module_method_ids(): # Allow __setattr__ to fall through to base class handler supported = {torch.nn.Module.__setattr__, torch.nn.Module.__init__} diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 0cc771c61c3..39413c58963 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -1040,7 +1040,7 @@ class TensorVariable(VariableTracker): return wrap_fx_proxy(tx, proxy) @staticmethod - @functools.lru_cache(None) + @functools.cache def _warn_capture_scalar_outputs(): user_stack = torch._guards.TracingContext.extract_stack() user_stack_formatted = "".join(traceback.format_list(user_stack)) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index e44bd2a8e2a..7ac0c224978 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -170,7 +170,7 @@ constant_fold_functions_need_guards = dict.fromkeys(constant_fold_functions_need constant_fold_functions = dict.fromkeys(constant_fold_functions) -@functools.lru_cache(None) +@functools.cache def tracing_state_functions() -> dict[Callable[[], Any], Optional[bool]]: # Defined as a function to avoid circular import like torch.onnx return { @@ -197,7 +197,7 @@ dispatch_key_set_functions = { } -@functools.lru_cache(None) +@functools.cache def get_overridable_functions(): from itertools import chain @@ -432,7 +432,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable): return self.value @staticmethod - @functools.lru_cache(None) + @functools.cache def _get_handlers(): """Build a dict from function -> method to handle it so that we are O(1) in terms of the number of function with special handling.""" diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index ddacee127f1..26e3ee8aa0c 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -199,7 +199,7 @@ banned_attrs = [ ] -@functools.lru_cache(None) +@functools.cache def get_prev_stack_var_name(): from ..bytecode_transformation import unique_id diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index e27cfed6f6a..a446cf4aeda 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -149,7 +149,7 @@ class UserDefinedClassVariable(UserDefinedVariable): return f"{self.__class__.__name__}({self.value})" @staticmethod - @functools.lru_cache(None) + @functools.cache def _constant_fold_classes(): return { torch.device, @@ -159,7 +159,7 @@ class UserDefinedClassVariable(UserDefinedVariable): } @staticmethod - 
@functools.lru_cache(None) + @functools.cache def _in_graph_classes(): _in_graph_class_list = { torch.Tensor, @@ -177,7 +177,7 @@ class UserDefinedClassVariable(UserDefinedVariable): return set(tensortype_to_dtype.keys()) | _in_graph_class_list @staticmethod - @functools.lru_cache(None) + @functools.cache def supported_c_new_functions(): exceptions = [ getattr(builtins, name).__new__ @@ -843,7 +843,7 @@ class UserDefinedObjectVariable(UserDefinedVariable): ) @staticmethod - @functools.lru_cache(None) + @functools.cache def _supported_random_functions(): fns = { random.random, diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index 9f4adde6c93..fb51cf59871 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -1229,7 +1229,7 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]): remote_cache.put(key, cache_data) @staticmethod - @functools.lru_cache(None) + @functools.cache def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]: """ Attempts to load the remote cache, returns None on error. diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 60d125116a6..d3d2aaa34de 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -1052,7 +1052,7 @@ def _count_ops(graph: fx.Graph): log.info("%s", sorted(cnt.items(), key=operator.itemgetter(1), reverse=True)) -@functools.lru_cache(None) +@functools.cache def pointwise_ops(): ops = [] for attr_name in dir(torch.ops.aten): diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index d605252def3..2cd56300329 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -874,7 +874,7 @@ class CppBenchmarkRequest(CPUDeviceBenchmarkMixin, BenchmarkRequest): return f"{self.kernel_name=}" -@functools.lru_cache(None) +@functools.cache def get_tuning_process_pool() -> TuningProcessPool: pool = TuningProcessPool() atexit.register(pool.shutdown) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index ad4427e4ecf..723b86dd28c 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -176,7 +176,7 @@ def get_kernel_bin_format(device: str) -> str: return "" -@functools.lru_cache(None) +@functools.cache def get_global_cache_path_impl(global_cache_dir: str) -> Optional[Path]: return ( Path(os.path.join(global_cache_dir, CacheBase.get_system()["hash"])) @@ -187,7 +187,7 @@ def get_global_cache_path_impl(global_cache_dir: str) -> Optional[Path]: class CacheBase: @staticmethod - @functools.lru_cache(None) + @functools.cache def get_system() -> dict[str, Any]: try: from triton.compiler.compiler import triton_key @@ -226,7 +226,7 @@ class CacheBase: @staticmethod @clear_on_fresh_inductor_cache - @functools.lru_cache(None) + @functools.cache def get_local_cache_path() -> Path: return Path(os.path.join(cache_dir(), "cache", CacheBase.get_system()["hash"])) @@ -280,7 +280,7 @@ class LocalCache(CacheBase): class PersistentCache(CacheBase): - @functools.lru_cache(None) # noqa: B019 + @functools.cache # noqa: B019 def get_global_cache(self) -> dict[str, Any]: global_cache_path = self.get_global_cache_path() if global_cache_path is None or not global_cache_path.is_file(): @@ -1567,7 +1567,7 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]): pass -@functools.lru_cache(None) +@functools.cache def split_aot_inductor_output_path(path: str) -> tuple[str, str]: 
"""Returns the path where the AOT Inductor compiled kernels are stored.""" if path.endswith(".so"): @@ -2283,7 +2283,7 @@ _HEADER_DIR = os.path.join(default_cache_dir(), "precompiled_headers") _HEADER_LOCK_DIR = os.path.join(_HEADER_DIR, "locks") -@functools.lru_cache(None) +@functools.cache def _precompile_header( header: str, hashable_cmd_line: str, @@ -2969,7 +2969,7 @@ class HalideCodeCache(CppPythonBindingsCodeCache): return glue_code @classmethod - @functools.lru_cache(None) + @functools.cache def config_hash(cls) -> str: command_gen = CppBuilder( name="O", @@ -3013,7 +3013,7 @@ class HalideCodeCache(CppPythonBindingsCodeCache): raise RuntimeError(errmsg) @staticmethod - @functools.lru_cache(None) + @functools.cache def find_libautoschedule(name: str) -> str: sofile = f"libautoschedule_{name.lower()}.so" if "HALIDE_LIB" in os.environ: @@ -3026,7 +3026,7 @@ class HalideCodeCache(CppPythonBindingsCodeCache): return HalideCodeCache._search_for_file(sofile, errmsg) @staticmethod - @functools.lru_cache(None) + @functools.cache def find_header(name: str) -> str: if "HALIDE_INCLUDE" in os.environ: path = os.path.join(os.environ["HALIDE_INCLUDE"], name) @@ -3300,7 +3300,7 @@ class PyCodeCache: cls.modules_no_attr.clear() @classmethod - @functools.lru_cache(None) + @functools.cache def stack_frames_for_code( cls, path: str, lineno: int ) -> Optional[list[dict[str, Any]]]: diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 9ef5e24c2cb..a32cc755f95 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -441,7 +441,7 @@ def get_wrapper_codegen_for_device( return None -@functools.lru_cache(None) +@functools.cache def init_backend_registration() -> None: from .cpp import CppScheduling from .cpp_wrapper_cpu import CppWrapperCpu @@ -2223,7 +2223,7 @@ class OptimizationContext: ops_name: str = "" -@functools.lru_cache(None) +@functools.cache def jinja2_env() -> Any: try: import jinja2 diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 57a8bd188a8..5aa075c1016 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -86,7 +86,7 @@ from .cpp_utils import ( _IS_WINDOWS = sys.platform == "win32" -@functools.lru_cache(None) +@functools.cache def get_export_declaration(): return "__declspec(dllexport)" if _IS_WINDOWS else "" diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 5eeb3e6e764..fd6a34a134b 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -274,12 +274,12 @@ class CppWrapperCpu(PythonWrapperCodegen): ): code = self.prefix - @functools.lru_cache(None) + @functools.cache def sizeof(name): self.codegen_input_size_var_decl(code, name) return f"{name}_size" - @functools.lru_cache(None) + @functools.cache def strideof(name): self.codegen_input_stride_var_decl(code, name) return f"{name}_stride" @@ -1469,7 +1469,7 @@ class CppWrapperCpu(PythonWrapperCodegen): self.used_cached_memory_formats.add(memory_format_str) return f"cached_torch_memory_format_{memory_format_str}" - @functools.lru_cache(None) # noqa: B019 + @functools.cache # noqa: B019 def codegen_int_array_var( self, int_array: str, diff --git a/torch/_inductor/codegen/cuda/cuda_env.py b/torch/_inductor/codegen/cuda/cuda_env.py index 95be434e03b..27f51b52130 100644 --- a/torch/_inductor/codegen/cuda/cuda_env.py +++ b/torch/_inductor/codegen/cuda/cuda_env.py @@ -40,6 +40,6 @@ def get_cuda_version() -> 
Optional[str]: return None -@functools.lru_cache(None) +@functools.cache def nvcc_exist(nvcc_path: Optional[str] = "nvcc") -> bool: return nvcc_path is not None and shutil.which(nvcc_path) is not None diff --git a/torch/_inductor/codegen/cuda/cutlass_cache.py b/torch/_inductor/codegen/cuda/cutlass_cache.py index c87b1878bc1..5fa17779dec 100644 --- a/torch/_inductor/codegen/cuda/cutlass_cache.py +++ b/torch/_inductor/codegen/cuda/cutlass_cache.py @@ -48,7 +48,7 @@ def _generate_config_filename(request_key: str) -> str: @clear_on_fresh_inductor_cache -@functools.lru_cache(None) +@functools.cache def maybe_fetch_ops() -> Optional[list[Any]]: """ Fetch ops from databases. diff --git a/torch/_inductor/codegen/cuda/cutlass_presets.py b/torch/_inductor/codegen/cuda/cutlass_presets.py index bc97c22e247..f0888e2ef29 100644 --- a/torch/_inductor/codegen/cuda/cutlass_presets.py +++ b/torch/_inductor/codegen/cuda/cutlass_presets.py @@ -5,7 +5,7 @@ import torch from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch -@functools.lru_cache(None) +@functools.cache def gen_cutlass_presets() -> dict[int, dict[str, list[str]]]: """ Generate cutlass presets for the given CUDA arch. diff --git a/torch/_inductor/codegen/cuda/cutlass_utils.py b/torch/_inductor/codegen/cuda/cutlass_utils.py index 48a6dbe6d21..804487a2515 100644 --- a/torch/_inductor/codegen/cuda/cutlass_utils.py +++ b/torch/_inductor/codegen/cuda/cutlass_utils.py @@ -56,7 +56,7 @@ def _rename_cutlass_import(content: str, cutlass_modules: list[str]) -> str: return content -@functools.lru_cache(None) +@functools.cache def try_import_cutlass() -> bool: """ We want to support three ways of passing in CUTLASS: @@ -251,7 +251,7 @@ class CUTLASSArgs: @clear_on_fresh_inductor_cache -@functools.lru_cache(None) +@functools.cache def _gen_ops_cached(arch, version) -> dict[Any, Any]: # Note: Cache needs to be specific for cuda architecture and version diff --git a/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py b/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py index 466648be3f8..5862534ce6c 100644 --- a/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py +++ b/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py @@ -95,7 +95,7 @@ class CKTileGemmOperation: return asdict(self).items() -@functools.lru_cache(None) +@functools.cache def ops(): """ Generate the supported instance dataclasses diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 3d4e61953c5..8b07374faf0 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -416,7 +416,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]): self.code_hash: Optional[str] = None # define this in a closure to make cache local to object - @functools.lru_cache(None) + @functools.cache def simplify_indexing(index: sympy.Expr): index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges()) for tree in self.range_trees: diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 5fe5506dc55..f10a5f6d217 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1311,7 +1311,7 @@ class TritonKernelOverrides(TritonOverrides): self._setup_libdevice_routing() @classmethod - @functools.lru_cache(None) + @functools.cache def _setup_libdevice_routing(cls): """Set up routing to libdevice implementations for fp64 inputs.""" diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 
c724a8ee8da..4aa9d5bf221 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -915,7 +915,7 @@ class PythonWrapperCodegen(CodeGen): self.write_get_raw_stream ) - @functools.lru_cache(None) + @functools.cache def add_import_once(line: str) -> None: self.imports.writeline(line) if config.triton.autotune_at_compile_time: @@ -1625,12 +1625,12 @@ class PythonWrapperCodegen(CodeGen): ): code = self.prefix - @functools.lru_cache(None) + @functools.cache def sizeof(name): code.writeline(f"{name}_size = {name}.size()") return f"{name}_size" - @functools.lru_cache(None) + @functools.cache def strideof(name): code.writeline(f"{name}_stride = {name}.stride()") return f"{name}_stride" diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index d323846d758..353dc43b881 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -238,12 +238,12 @@ def record_original_output_strides(gm: GraphModule) -> None: output_node.meta["original_output_strides"] = output_strides -@functools.lru_cache(None) +@functools.cache def _step_logger() -> Callable[..., None]: return dynamo_logging.get_step_logger(log) -@functools.lru_cache(None) +@functools.cache def _warn_tf32_disabled() -> None: if ( torch.cuda.is_available() diff --git a/torch/_inductor/compile_fx_ext.py b/torch/_inductor/compile_fx_ext.py index f165aa8ab9f..7fd976a05ed 100644 --- a/torch/_inductor/compile_fx_ext.py +++ b/torch/_inductor/compile_fx_ext.py @@ -614,7 +614,7 @@ class _OutOfProcessFxCompile(_SerializedFxCompile): # And forward our collected logs. The cache is cleared when the outer # function exits. - @functools.lru_cache(None) + @functools.cache def getLogger(name: str) -> logging.Logger: return logging.getLogger(name) diff --git a/torch/_inductor/compiler_bisector.py b/torch/_inductor/compiler_bisector.py index eebff4b566c..5cec2020c9f 100644 --- a/torch/_inductor/compiler_bisector.py +++ b/torch/_inductor/compiler_bisector.py @@ -79,7 +79,7 @@ def reset_counters() -> None: call_counter_debug_info.clear() -@functools.lru_cache(None) +@functools.cache def get_env_val(env_str: str) -> Optional[str]: return os.environ.get(env_str, None) diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index ce7cc09eb6d..a634d85a4b8 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -127,7 +127,7 @@ def install_gcc_via_conda() -> str: return cxx_path -@functools.lru_cache(None) +@functools.cache def check_compiler_exist_windows(compiler: str) -> None: """ Check if compiler is ready, in case end user not activate MSVC environment. @@ -200,13 +200,13 @@ def convert_cubin_to_obj( return obj_file -@functools.lru_cache(None) +@functools.cache def _is_apple_clang(cpp_compiler: str) -> bool: version_string = subprocess.check_output([cpp_compiler, "--version"]).decode("utf8") return "Apple" in version_string.splitlines()[0] -@functools.lru_cache(None) +@functools.cache def _is_clang(cpp_compiler: str) -> bool: # Mac OS apple clang maybe named as gcc, need check compiler info. if sys.platform == "darwin": @@ -221,7 +221,7 @@ def _is_clang(cpp_compiler: str) -> bool: return bool(re.search(r"(clang|clang\+\+)", cpp_compiler)) -@functools.lru_cache(None) +@functools.cache def _is_gcc(cpp_compiler: str) -> bool: # Since "clang++" ends with "g++", the regex match below would validate on it. 
if _is_clang(cpp_compiler): @@ -229,7 +229,7 @@ def _is_gcc(cpp_compiler: str) -> bool: return bool(re.search(r"(gcc|g\+\+|gnu-c\+\+)", cpp_compiler)) -@functools.lru_cache(None) +@functools.cache def _is_msvc_cl(cpp_compiler: str) -> bool: if not _IS_WINDOWS: return False @@ -247,7 +247,7 @@ def _is_msvc_cl(cpp_compiler: str) -> bool: return False -@functools.lru_cache(None) +@functools.cache def _is_intel_compiler(cpp_compiler: str) -> bool: def _check_minimal_version(compiler_version: TorchVersion) -> None: """ @@ -291,32 +291,32 @@ def _is_intel_compiler(cpp_compiler: str) -> bool: return False -@functools.lru_cache(None) +@functools.cache def is_gcc() -> bool: return _is_gcc(get_cpp_compiler()) -@functools.lru_cache(None) +@functools.cache def is_clang() -> bool: return _is_clang(get_cpp_compiler()) -@functools.lru_cache(None) +@functools.cache def is_intel_compiler() -> bool: return _is_intel_compiler(get_cpp_compiler()) -@functools.lru_cache(None) +@functools.cache def is_apple_clang() -> bool: return _is_apple_clang(get_cpp_compiler()) -@functools.lru_cache(None) +@functools.cache def is_msvc_cl() -> bool: return _is_msvc_cl(get_cpp_compiler()) -@functools.lru_cache(None) +@functools.cache def get_compiler_version_info(compiler: str) -> str: env = os.environ.copy() env["LC_ALL"] = "C" # Don't localize output @@ -885,7 +885,7 @@ def _get_python_related_args() -> tuple[list[str], list[str]]: return python_include_dirs, python_lib_path -@functools.lru_cache(None) +@functools.cache def is_conda_llvm_openmp_installed() -> bool: try: command = "conda list llvm-openmp --json" @@ -895,7 +895,7 @@ def is_conda_llvm_openmp_installed() -> bool: return False -@functools.lru_cache(None) +@functools.cache def homebrew_libomp() -> tuple[bool, str]: try: # check if `brew` is installed @@ -916,7 +916,7 @@ def homebrew_libomp() -> tuple[bool, str]: return False, "" -@functools.lru_cache(None) +@functools.cache def perload_clang_libomp_win(cpp_compiler: str, omp_name: str) -> None: try: output = subprocess.check_output([cpp_compiler, "-print-file-name=bin"]).decode( @@ -930,7 +930,7 @@ def perload_clang_libomp_win(cpp_compiler: str, omp_name: str) -> None: pass -@functools.lru_cache(None) +@functools.cache def perload_icx_libomp_win(cpp_compiler: str) -> None: def _load_icx_built_in_lib_by_name(cpp_compiler: str, lib_name: str) -> bool: try: diff --git a/torch/_inductor/cpu_vec_isa.py b/torch/_inductor/cpu_vec_isa.py index 6ef28679776..fe759266533 100644 --- a/torch/_inductor/cpu_vec_isa.py +++ b/torch/_inductor/cpu_vec_isa.py @@ -146,7 +146,7 @@ cdll.LoadLibrary("__lib_path__") def __bool__(self) -> bool: return self.__bool__impl(config.cpp.vec_isa_ok) - @functools.lru_cache(None) # noqa: B019 + @functools.cache # noqa: B019 def __bool__impl(self, vec_isa_ok) -> bool: if vec_isa_ok is not None: return vec_isa_ok @@ -241,7 +241,7 @@ extern "C" void __amx_chk_kernel() { } """ - @functools.lru_cache(None) # noqa: B019 + @functools.cache # noqa: B019 def __bool__(self) -> bool: if super().__bool__(): if config.is_fbcode(): @@ -380,7 +380,7 @@ def get_isa_from_cpu_capability( # Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content # might have too much redundant content that is useless for ISA check. Hence, # we only cache some key isa information. 
-@functools.lru_cache(None) +@functools.cache def valid_vec_isa_list() -> list[VecISA]: isa_list: list[VecISA] = [] if sys.platform == "darwin" and platform.processor() == "arm": diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py index ee328e8b560..d6ca5ad241e 100644 --- a/torch/_inductor/debug.py +++ b/torch/_inductor/debug.py @@ -50,7 +50,7 @@ BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"]) GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"] -@functools.lru_cache(None) +@functools.cache def has_dot() -> bool: return shutil.which("dot") is not None diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index bc0ee2979aa..6c24d551c34 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -856,7 +856,7 @@ def miopen_batch_norm( ) -@functools.lru_cache(None) +@functools.cache def fast_random_decomps() -> dict[Any, Callable[..., Any]]: return {**decompositions, **extra_random_decomps} diff --git a/torch/_inductor/dtype_propagation.py b/torch/_inductor/dtype_propagation.py index 811ae9982d2..5f99d83e07e 100644 --- a/torch/_inductor/dtype_propagation.py +++ b/torch/_inductor/dtype_propagation.py @@ -29,7 +29,7 @@ DTypeArg = Union[DTypeVar, torch.types.Number, str, OpsValue] # So first decompose CSEVars -> tuple before calling this -@functools.lru_cache(None) +@functools.cache def get_promoted_dtype( *args: Sequence[tuple[torch.dtype, bool]], type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND] = None, diff --git a/torch/_inductor/fx_passes/binary_folding.py b/torch/_inductor/fx_passes/binary_folding.py index c64f1309319..d2ad3e1c8f9 100644 --- a/torch/_inductor/fx_passes/binary_folding.py +++ b/torch/_inductor/fx_passes/binary_folding.py @@ -83,7 +83,7 @@ def recover_original_precision_folded_computation_ops(gm): _binary_ops = [aten.add.Tensor, aten.sub.Tensor, aten.mul.Tensor, aten.div.Tensor] -@functools.lru_cache(None) +@functools.cache def binary_folding_init(): _conv_args = [Arg() for _ in range(9)] _addmm_args = [Arg() for _ in range(3)] diff --git a/torch/_inductor/fx_passes/freezing_patterns.py b/torch/_inductor/fx_passes/freezing_patterns.py index 8b6437fc258..4ca307825f1 100644 --- a/torch/_inductor/fx_passes/freezing_patterns.py +++ b/torch/_inductor/fx_passes/freezing_patterns.py @@ -119,7 +119,7 @@ def register_binary_folding_pattern(pattern, extra_check=_return_true): ) -@functools.lru_cache(None) +@functools.cache def addmm_patterns_init(): device = next( (gpu for gpu in GPU_TYPES if getattr(torch, gpu).is_available()), "cpu" diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py index 488228cb511..b9b2ea2c888 100644 --- a/torch/_inductor/fx_passes/fuse_attention.py +++ b/torch/_inductor/fx_passes/fuse_attention.py @@ -956,7 +956,7 @@ def _get_sfdp_patterns(): ) -@functools.lru_cache(None) +@functools.cache def _sfdp_init(): for key, register_replacement_kwargs in _get_sfdp_patterns(): gen_register_replacement(key, **register_replacement_kwargs) diff --git a/torch/_inductor/fx_passes/misc_patterns.py b/torch/_inductor/fx_passes/misc_patterns.py index b4e0f1f3502..d2c8068f130 100644 --- a/torch/_inductor/fx_passes/misc_patterns.py +++ b/torch/_inductor/fx_passes/misc_patterns.py @@ -12,7 +12,7 @@ from ..pattern_matcher import fwd_only, register_replacement aten = torch.ops.aten -@functools.lru_cache(None) +@functools.cache def _misc_patterns_init(): from .joint_graph import patterns as 
joint_graph_patterns from .post_grad import pass_patterns as post_grad_patterns_all diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index 9e69f96d27f..9e415e2ad0d 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -1398,7 +1398,7 @@ if torch._C._has_mkldnn: user_node.replace_all_uses_with(node) gm.graph.erase_node(user_node) - @functools.lru_cache(None) + @functools.cache def _mkldnn_fusion_init(): # TODO: aarch64: enable op fusion for acl once it supports fused operators. Disabling it for now. # Otherwise even the matmul or innerproduct can not be accelerated with acl @@ -1414,7 +1414,7 @@ if torch._C._has_mkldnn: _register_quantization_lowerings() _register_woq_lowerings() - @functools.lru_cache(None) + @functools.cache def _mkldnn_weight_pack_init(): if torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available(): _register_weight_pack_pass() diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index 655a0e44d24..10ca1c4dae9 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -247,7 +247,7 @@ def is_mm_compute_bound(M: int, K: int, N: int, dtype: torch.dtype) -> bool: return arithmetic_intensity > machine_balance -@functools.lru_cache(None) +@functools.cache def get_pad_cache() -> torch._inductor.codecache.LocalCache: return torch._inductor.codecache.LocalCache() @@ -851,7 +851,7 @@ def bmm_replace(mat1: Tensor, mat2: Tensor) -> Tensor: ) -@functools.lru_cache(None) +@functools.cache def _pad_mm_init() -> None: from .joint_graph import patterns diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index 88c5f8497ac..4fe32ebf909 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -3478,7 +3478,7 @@ def _register_qlinear_binary_fusion(): ) -@functools.lru_cache(None) +@functools.cache def _register_quantization_weight_pack_pass(): # Step 1: Dequant promotion for int8-mixed-fp32/bf16 _register_dequant_promotion() diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index 0ef8bfc1f74..aa227e97241 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -511,7 +511,7 @@ scaled_mm_device_tma_template = TritonTemplate( # prevent duplication registration of extern functions -@functools.lru_cache(None) +@functools.cache def lazy_register_extern_choice(fn): return ExternKernelChoice(fn) @@ -1175,7 +1175,7 @@ def tuned_scaled_mm( return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout) -@functools.lru_cache(None) +@functools.cache def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool: props = torch.cuda.get_device_properties(index or 0) return props.major <= 7 diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py index e3141dbc9c2..ffcf431c0cb 100644 --- a/torch/_inductor/loop_body.py +++ b/torch/_inductor/loop_body.py @@ -36,7 +36,7 @@ T = TypeVar("T") class InterpreterShim(torch.fx.Interpreter): @staticmethod - @functools.lru_cache(None) + @functools.cache def _dummy_gm(): return torch.fx.symbolic_trace(identity) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 97cdf22cf9d..9ff083a6a81 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -1879,7 +1879,7 @@ def fallback_handler(kernel, add_to_fallback_set=True): return handler -@functools.lru_cache(None) +@functools.cache 
def _warn_complex_not_supported(): warnings.warn( "Torchinductor does not support code generation for complex operators. Performance may be worse than eager." diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index 4a48420e151..5520da3a6fe 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -1010,7 +1010,7 @@ class PatternPrettyPrinter: self.memoized_objs_pp: dict[PatternExpr, str] = {} @staticmethod - @functools.lru_cache(None) + @functools.cache def run(obj: PatternExpr, output_name: str = "output") -> str: """ Serializes obj to python code with obj written out to `output_name` @@ -2195,7 +2195,7 @@ def stable_topological_sort(graph: torch.fx.Graph) -> None: def init_once_fakemode(fn: Callable[..., Any]) -> Callable[[], Any]: """Wrapper around lazy init functions in fx_passes/""" - @functools.lru_cache(None) + @functools.cache @functools.wraps(fn) def lazy_init() -> Any: counters_ref = counters["inductor"].copy() diff --git a/torch/_inductor/runtime/compile_tasks.py b/torch/_inductor/runtime/compile_tasks.py index 78be2e3787c..67140369faa 100644 --- a/torch/_inductor/runtime/compile_tasks.py +++ b/torch/_inductor/runtime/compile_tasks.py @@ -34,7 +34,7 @@ def _reload_python_module( return mod -@functools.lru_cache(None) +@functools.cache def _set_triton_ptxas_path() -> None: if os.environ.get("TRITON_PTXAS_PATH") is not None: return diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py index f224217db22..e559eaa1a31 100644 --- a/torch/_inductor/runtime/hints.py +++ b/torch/_inductor/runtime/hints.py @@ -136,7 +136,7 @@ class DeviceProperties(typing.NamedTuple): warp_size: Optional[int] = None @classmethod - @functools.lru_cache(None) + @functools.cache def create(cls, device) -> DeviceProperties: import torch from torch._dynamo.device_interface import get_interface_for_device diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index af5a08dad26..cde26c3ccdf 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -1124,7 +1124,7 @@ class TritonTemplateKernel(TritonKernel): ] -@functools.lru_cache(None) +@functools.cache def _jinja2_env(): try: import jinja2 @@ -1726,7 +1726,7 @@ class ExternKernelChoice: def call_name(self): return f"extern_kernels.{self.name}" - @functools.lru_cache(None) # noqa: B019 + @functools.cache # noqa: B019 def hash_key(self): fn = self.to_callable() parts = [ @@ -1933,7 +1933,7 @@ class ExternKernelCaller(ChoiceCaller): return f"extern_{self.choice.name}" -@functools.lru_cache(None) +@functools.cache def get_mm_log_filename() -> Optional[str]: mm_file_name = os.environ.get("TORCHINDUCTOR_MM_LOGGING_FILE", None) if not mm_file_name: @@ -2052,7 +2052,7 @@ class NoValidChoicesError(RuntimeError): pass -@functools.lru_cache(None) +@functools.cache def get_num_workers() -> int: if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ: return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"]) @@ -2194,7 +2194,7 @@ class AlgorithmSelectorCache(PersistentCache): # CUDATemplateCaller still needs to go through autotuning process to retrieve workspace size. 
return choices[0].output_node() - @functools.lru_cache(None) + @functools.cache def make_benchmark_fn(): return self.make_benchmark_fn(choices, input_nodes, layout, input_gen_fns) @@ -2506,7 +2506,7 @@ class AlgorithmSelectorCache(PersistentCache): future.add_done_callback(on_complete) futures[future] = c - @functools.lru_cache(None) + @functools.cache @restore_stdout_stderr() def wait_on_futures(): log.debug("Waiting on futures") diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 872c75480e9..372dd3b75e3 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -91,7 +91,7 @@ T = TypeVar("T") # defines here before import torch._dynamo is for avoiding circular import # when get_gpu_type is imported from dynamo -@functools.lru_cache(None) +@functools.cache def get_gpu_type() -> str: avail_gpus = [x for x in GPU_TYPES if getattr(torch, x).is_available()] assert len(avail_gpus) <= 1 @@ -338,7 +338,7 @@ def do_bench_using_profiling( return res -@functools.lru_cache(None) +@functools.cache def has_torchvision_roi_align() -> bool: try: from torchvision.ops import roi_align # noqa: F401 @@ -1384,7 +1384,7 @@ class DelayReplaceLine(DeferredLineBase): return DelayReplaceLine(self.key, self.value_fn, line) -@functools.lru_cache(None) +@functools.cache def is_big_gpu(index_or_device: Union[int, torch.device] = 0) -> bool: if isinstance(index_or_device, torch.device): device = index_or_device @@ -1599,7 +1599,7 @@ def use_decompose_k_choice(m: _IntLike, n: _IntLike, k: _IntLike) -> bool: ) -@functools.lru_cache(None) +@functools.cache def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]: # If k is a sympy expression, we can't do any splitting if isinstance(k, sympy.Expr) and not k.is_number: @@ -1652,12 +1652,12 @@ def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]: return best_splits[:k_splits_limit] -@functools.lru_cache(None) +@functools.cache def _rocm_native_device_arch_name(device: str) -> str: return torch.cuda.get_device_properties(device).gcnArchName -@functools.lru_cache(None) +@functools.cache def try_import_ck_lib() -> tuple[ Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], type[Any] ]: @@ -2106,7 +2106,7 @@ def parallel_num_threads() -> int: return threads -@functools.lru_cache(None) +@functools.cache def get_backend_num_stages() -> int: from .runtime.triton_helpers import get_backend_options @@ -2114,7 +2114,7 @@ def get_backend_num_stages() -> int: return options.get("num_stages", 2 if torch.version.hip else 3) -@functools.lru_cache(None) +@functools.cache def get_device_tflops(dtype: torch.dtype) -> int: from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops @@ -2142,7 +2142,7 @@ def get_device_tflops(dtype: torch.dtype) -> int: return get_max_simd_tflops(torch.float32) -@functools.lru_cache(None) +@functools.cache def get_gpu_dram_gbps() -> int: from triton.testing import get_dram_gbps @@ -2862,7 +2862,7 @@ def is_same_mkldnn_tensor(data: torch.Tensor, value: torch.Tensor) -> bool: ) -@functools.lru_cache(None) +@functools.cache def boolean_ops() -> tuple[str, ...]: return ( "isinf", @@ -3051,7 +3051,7 @@ class TritonAttrsDescriptorVersion(enum.Enum): V4_DICT = 4 # a raw dict -@functools.lru_cache(None) +@functools.cache def get_triton_attrs_descriptor_version() -> TritonAttrsDescriptorVersion: if importlib.util.find_spec("triton") is None: return TritonAttrsDescriptorVersion.V0_NO_TRITON diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index 
c6730f36ff9..3821218cefe 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -1143,7 +1143,7 @@ class LazyTraceHandler(logging.StreamHandler): super().emit(record) -@functools.lru_cache(None) +@functools.cache def warning_once(logger_obj, *args, **kwargs) -> None: """ This function is similar to `logger.warning()`, but will emit the warning with the same message only once diff --git a/torch/_prims/context.py b/torch/_prims/context.py index 36cb40e7916..97e6a274732 100644 --- a/torch/_prims/context.py +++ b/torch/_prims/context.py @@ -15,7 +15,7 @@ import torch.overrides from torch._prims_common import torch_function_passthrough -@functools.lru_cache(None) +@functools.cache def torch_to_refs_map(): """ Mapping of torch API functions to torch._refs functions. @@ -70,7 +70,7 @@ def torch_to_refs_map(): return r -@functools.lru_cache(None) +@functools.cache def all_prims(): """ Set of all prim functions, e.g., torch._prims.add in all_prims() diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index eaa64ae97b0..7f514ef9976 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -114,7 +114,7 @@ def contains_tensor_types(type): ) -@functools.lru_cache(None) +@functools.cache def _is_tensor_constructor(func: OpOverload): assert isinstance(func, OpOverload) schema = func._schema @@ -1077,7 +1077,7 @@ def fast_detach(fake_mode, x): return FakeTensor(fake_mode, out, x.device, real_tensor=x.real_tensor) -@functools.lru_cache(None) +@functools.cache def get_fast_op_impls(): import torch._refs diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 41e353c4679..b9007fe937f 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -233,7 +233,7 @@ def maybe_get_fake_mode(t: object) -> Optional[FakeTensorMode]: return None -@functools.lru_cache(None) +@functools.cache def get_schema_info(func: OpOverload) -> torch._C._SchemaInfo: return torch._C._SchemaInfo(func._schema) @@ -243,7 +243,7 @@ def get_schema_info(func: OpOverload) -> torch._C._SchemaInfo: # torch/_decomp/decompositions.py. # decomps are used for aot autograd tracing so we would like to unify on their # implementation and add additional testing to them -@functools.lru_cache(None) +@functools.cache def torch_decomp_decompositions(func: OpOverload) -> bool: from torch._decomp import decomposition_table @@ -511,7 +511,7 @@ class FakeTensorConverter: return out -@functools.lru_cache(None) +@functools.cache def init_gpu_context(device: torch.device) -> None: # Backward will error with cuda Fake Tensors if no cuda tensors have been initialized first if torch.cuda.is_available() or torch.xpu.is_available(): diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 89fbd678728..b5da1941450 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -210,7 +210,7 @@ def is_fb_unit_test() -> bool: return False -@functools.lru_cache(None) +@functools.cache def max_clock_rate(): if not torch.version.hip: from triton.testing import nvsmi diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index cd1f99233e1..d7dfdf43b12 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -7379,7 +7379,7 @@ class ShapeEnv: # Don't track this one. 
(Because this cache is inside this function the # cache only lasts for the invocation of this function call) - @functools.lru_cache(None) + @functools.cache def compute_concrete_val() -> sympy.Basic: if hint is None: # This is only ever called for expressions WITHOUT unbacked diff --git a/torch/onnx/_internal/fx/patcher.py b/torch/onnx/_internal/fx/patcher.py index f8a52efda2c..6c9724e9f5a 100644 --- a/torch/onnx/_internal/fx/patcher.py +++ b/torch/onnx/_internal/fx/patcher.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: # TODO: Remove after https://github.com/huggingface/safetensors/pull/318 -@functools.lru_cache(None) +@functools.cache def has_safetensors_and_transformers(): try: # safetensors is not an exporter requirement, but needed for some huggingface models diff --git a/torch/overrides.py b/torch/overrides.py index 67e079d07db..f4edecd664b 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -98,7 +98,7 @@ def _disable_user_warnings( return wrapper -@functools.lru_cache(None) +@functools.cache @_disable_user_warnings def get_ignored_functions() -> set[Callable]: """ @@ -378,7 +378,7 @@ def get_ignored_functions() -> set[Callable]: } -@functools.lru_cache(None) +@functools.cache def get_default_nowrap_functions() -> set[Callable]: """ Return public functions that do not wrap in a subclass when invoked by @@ -404,7 +404,7 @@ def get_default_nowrap_functions() -> set[Callable]: } -@functools.lru_cache(None) +@functools.cache @_disable_user_warnings def get_testing_overrides() -> dict[Callable, Callable]: """Return a dict containing dummy overrides for all overridable functions @@ -1808,7 +1808,7 @@ has_torch_function_variadic = _add_docstr( ) -@functools.lru_cache(None) +@functools.cache def _get_overridable_functions() -> tuple[ dict[Any, list[Callable]], dict[Callable, str] ]: @@ -1929,7 +1929,7 @@ def resolve_name(f): return _get_overridable_functions()[1].get(f) -@functools.lru_cache(None) +@functools.cache def _get_tensor_methods() -> set[Callable]: """Returns a set of the overridable methods on ``torch.Tensor``""" overridable_funcs = get_overridable_functions() diff --git a/torch/utils/_helion.py b/torch/utils/_helion.py index 624d1f81c8f..6d30832cf3f 100644 --- a/torch/utils/_helion.py +++ b/torch/utils/_helion.py @@ -3,7 +3,7 @@ import functools from torch.utils._triton import has_triton -@functools.lru_cache(None) +@functools.cache def has_helion_package() -> bool: try: import helion # type: ignore[import-untyped, import-not-found] # noqa: F401 @@ -12,6 +12,6 @@ def has_helion_package() -> bool: return True -@functools.lru_cache(None) +@functools.cache def has_helion() -> bool: return has_helion_package() and has_triton() diff --git a/torch/utils/_sympy/interp.py b/torch/utils/_sympy/interp.py index 396d1d46d28..3b020b5fabb 100644 --- a/torch/utils/_sympy/interp.py +++ b/torch/utils/_sympy/interp.py @@ -51,7 +51,7 @@ log = logging.getLogger(__name__) # TODO: Dedupe this with SYMPY_INTERP -@functools.lru_cache(None) +@functools.cache def handlers(): # TODO add CeilDiv (it doesn't appear in the index_expr) diff --git a/torch/utils/_triton.py b/torch/utils/_triton.py index 374afebe823..798e0d6dee2 100644 --- a/torch/utils/_triton.py +++ b/torch/utils/_triton.py @@ -3,7 +3,7 @@ import hashlib from typing import Any -@functools.lru_cache(None) +@functools.cache def has_triton_package() -> bool: try: from triton.compiler.compiler import triton_key @@ -15,7 +15,7 @@ def has_triton_package() -> bool: return False -@functools.lru_cache(None) +@functools.cache def _device_supports_tma() 
-> bool: import torch @@ -26,7 +26,7 @@ def _device_supports_tma() -> bool: ) -@functools.lru_cache(None) +@functools.cache def has_triton_experimental_host_tma() -> bool: if has_triton_package(): if _device_supports_tma(): @@ -43,7 +43,7 @@ def has_triton_experimental_host_tma() -> bool: return False -@functools.lru_cache(None) +@functools.cache def has_triton_tensor_descriptor_host_tma() -> bool: if has_triton_package(): if _device_supports_tma(): @@ -59,12 +59,12 @@ def has_triton_tensor_descriptor_host_tma() -> bool: return False -@functools.lru_cache(None) +@functools.cache def has_triton_tma() -> bool: return has_triton_tensor_descriptor_host_tma() or has_triton_experimental_host_tma() -@functools.lru_cache(None) +@functools.cache def has_triton_tma_device() -> bool: if has_triton_package(): import torch @@ -87,7 +87,7 @@ def has_triton_tma_device() -> bool: return False -@functools.lru_cache(None) +@functools.cache def has_triton() -> bool: if not has_triton_package(): return False @@ -121,7 +121,7 @@ def has_triton() -> bool: return is_device_compatible_with_triton() -@functools.lru_cache(None) +@functools.cache def triton_backend() -> Any: from triton.compiler.compiler import make_backend from triton.runtime.driver import driver @@ -130,7 +130,7 @@ def triton_backend() -> Any: return make_backend(target) -@functools.lru_cache(None) +@functools.cache def triton_hash_with_backend() -> str: from triton.compiler.compiler import triton_key
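
The change above is mechanical: functools.cache, available since Python 3.9, is documented as equivalent to functools.lru_cache(maxsize=None), i.e. an unbounded memoizing cache with the same cache_clear()/cache_info() helpers, so runtime behavior is unchanged and only the spelling gets shorter. A minimal sketch of the equivalence, assuming nothing beyond the standard library (the function names below are illustrative and do not come from the patch):

import functools

# Pre-3.9 spelling: unbounded memoization via lru_cache with no size limit.
@functools.lru_cache(None)
def square_old(x):
    return x * x

# The spelling this patch migrates to; behaves identically.
@functools.cache
def square_new(x):
    return x * x

assert square_old(3) == square_new(3) == 9

# Both wrappers expose the same cache-management API.
square_old.cache_clear()
square_new.cache_clear()

The existing "# noqa: B019" comments are kept where the decorator sits on a bound method, since the lint caveat (the cache keeping self alive) applies equally to functools.cache.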