Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Migrate from lru_cache to cache (#155613)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155613
Approved by: https://github.com/ezyang
ghstack dependencies: #155612
This commit is contained in:
parent f80a61adf5
commit d1947a8707
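For context, `functools.cache` (available since Python 3.9) is documented as an alias for `functools.lru_cache(maxsize=None)`, so every hunk below is a purely cosmetic one-line decorator change with identical unbounded-cache behavior. A minimal sketch of the equivalence; the `fib_old`/`fib_new` functions are illustrative only and not part of this patch:

import functools

@functools.lru_cache(None)  # old spelling: unbounded LRU cache
def fib_old(n: int) -> int:
    return n if n < 2 else fib_old(n - 1) + fib_old(n - 2)

@functools.cache  # new spelling: alias for lru_cache(maxsize=None)
def fib_new(n: int) -> int:
    return n if n < 2 else fib_new(n - 1) + fib_new(n - 2)

assert fib_old(20) == fib_new(20) == 6765
assert fib_old.cache_info().maxsize is None  # both wrappers are unbounded
assert fib_new.cache_info().maxsize is None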
@@ -560,7 +560,7 @@ def nothing(f):
 return f
-@functools.lru_cache(None)
+@functools.cache
 def patch_torch_manual_seed():
 """Make torch manual seed deterministic. Helps with accuracy testing."""

@@ -135,7 +135,7 @@ def contains_tensor_types(type):
 )
-@functools.lru_cache(None)
+@functools.cache
 def non_compute_operator(op):
 schema = op._schema

@@ -550,7 +550,7 @@ def build_summary(args):
 gh_fh.write(comment)
-@functools.lru_cache(None)
+@functools.cache
 def archive_data(archive_name):
 if archive_name is not None:
 prefix_match = re.search(r"\w+(?=_performance)", archive_name)

@@ -570,7 +570,7 @@ def archive_data(archive_name):
 return day, prefix
-@functools.lru_cache(None)
+@functools.cache
 def default_archive_name(dtype):
 _, prefix = archive_data(None)
 return f"{prefix}_performance_{dtype}_{randint(100, 999)}"

@@ -154,7 +154,7 @@ def list_backends(exclude_tags=("debug", "experimental")) -> list[str]:
 return sorted(backends)
-@functools.lru_cache(None)
+@functools.cache
 def _lazy_import():
 from .. import backends
 from ..utils import import_submodule

@@ -168,7 +168,7 @@ def _lazy_import():
 _discover_entrypoint_backends()
-@functools.lru_cache(None)
+@functools.cache
 def _discover_entrypoint_backends():
 # importing here so it will pick up the mocked version in test_backends.py
 from importlib.metadata import entry_points

@@ -201,7 +201,7 @@ def has_tvm():
 return False
-@functools.lru_cache(None)
+@functools.cache
 def llvm_target():
 if sys.platform == "linux":
 cpuinfo = open("/proc/cpuinfo").read()

@@ -247,7 +247,7 @@ class NNModuleToString:
 return model_str
-@functools.lru_cache(None) # subprocess is expensive
+@functools.cache # subprocess is expensive
 def _cuda_system_info_comment():
 if not torch.cuda.is_available():
 return "# torch.cuda.is_available()==False, no GPU info collected\n"

@@ -2013,7 +2013,7 @@ def _optimize_assert(
 class TorchPatcher:
 @staticmethod
-@functools.lru_cache(None)
+@functools.cache
 def patch():
 # A better way to disable the following would be decorate the source
 # functions with @torch._disable_dynamo. However, this causes issues

@@ -418,7 +418,7 @@ def from_numpy(a):
 # For user stack printing
-@functools.lru_cache(None)
+@functools.cache
 def uninteresting_files():
 import torch._dynamo.external_utils
 import torch._dynamo.polyfills
@@ -623,7 +623,7 @@ class GuardManagerType(enum.Enum):
 DICT_GUARD_MANAGER = 2
-@functools.lru_cache(None)
+@functools.cache
 def code_framelocals_names_reversed_cached(code: types.CodeType):
 return list(reversed(code_framelocals_names(code)))

@@ -33,7 +33,7 @@ def get_loggers() -> list[logging.Logger]:
 # get_step_logger should be lazily called (i.e. at runtime, not at module-load time)
 # so that step numbers are initialized properly. e.g.:
-# @functools.lru_cache(None)
+# @functools.cache
 # def _step_logger():
 # return get_step_logger(logging.getLogger(...))

@@ -210,7 +210,7 @@ class VariableTrackerCache:
 self.cache.clear()
-@functools.lru_cache(None)
+@functools.cache
 def _step_logger():
 return torchdynamo_logging.get_step_logger(log)

@@ -331,7 +331,7 @@ class TensorifyState:
 return len(cls.force_specializations) == 0
-@functools.lru_cache(None)
+@functools.cache
 def _step_logger():
 return torchdynamo_logging.get_step_logger(log)

@@ -2896,7 +2896,7 @@ Generate the torch object - Dynamo tracing rule (the wrapping variable) map.
 """
-@functools.lru_cache(None)
+@functools.cache
 def get_torch_obj_rule_map() -> dict[Any, type["VariableTracker"]]:
 d: dict[Any, type[VariableTracker]] = {}
 for m in torch_name_rule_map:

@@ -2945,7 +2945,7 @@ Get all torch.Tensor methods which are allowed to be in graph functions.
 """
-@functools.lru_cache(None)
+@functools.cache
 def get_tensor_method():
 disallowed_tensor_methods = {"__new__", "_make_wrapper_subclass", "_make_subclass"}
 s = set()

@@ -3467,7 +3467,7 @@ assert sorted(set(MOD_SKIPLIST)) == MOD_SKIPLIST
 MOD_SKIPLIST = set(MOD_SKIPLIST)
-@functools.lru_cache(None)
+@functools.cache
 def get_legacy_mod_inlinelist():
 inlinelist = {
 _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/"))

@@ -3476,7 +3476,7 @@ def get_legacy_mod_inlinelist():
 return inlinelist
-@functools.lru_cache(None)
+@functools.cache
 def get_mod_inlinelist():
 inlinelist = {
 _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/"))

@@ -3485,7 +3485,7 @@ def get_mod_inlinelist():
 return inlinelist
-@functools.lru_cache(None)
+@functools.cache
 def get_mod_skiplist():
 skiplist = {
 _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/"))

@@ -3739,7 +3739,7 @@ def is_torch_inline_allowed(filename):
 return any(filename.startswith(d) for d in get_mod_inlinelist())
-@functools.lru_cache(None)
+@functools.cache
 def dynamo_dir():
 import torch._dynamo
@@ -2369,7 +2369,7 @@ def is_safe_constant(v):
 )
-@functools.lru_cache(None)
+@functools.cache
 def common_constants():
 return {
 # We zero-one specialize shapes, so specialize these constants

@@ -3111,7 +3111,7 @@ seen_code_map = ExactWeakKeyDictionary()
 # return same dir unless user changes config between calls
-@functools.lru_cache(None)
+@functools.cache
 def _get_debug_dir(root_dir):
 dir_name = (
 "run_"

@@ -475,7 +475,7 @@ class VariableBuilder:
 return cls._type_dispatch_impl(config.trace_numpy)
 @classmethod
-@functools.lru_cache(None)
+@functools.cache
 def _type_dispatch_impl(cls, trace_numpy):
 # NB: Careful not to close over self to avoid ref cycle from lru_cache
 entries = [

@@ -576,7 +576,7 @@ class VariableBuilder:
 return self.tx.output.side_effects.track_mutable(value, result)
 @classmethod
-@functools.lru_cache(None)
+@functools.cache
 def _id_dispatch(
 cls,
 ) -> dict[int, Callable[["VariableBuilder", Any], VariableTracker]]:

@@ -169,7 +169,7 @@ class BuiltinVariable(VariableTracker):
 return cls(value, source=source)
 @staticmethod
-@functools.lru_cache(None)
+@functools.cache
 def _constant_fold_functions():
 fns = {
 abs,

@@ -239,7 +239,7 @@ class BuiltinVariable(VariableTracker):
 return self.fn in self._constant_fold_functions()
 @staticmethod
-@functools.lru_cache(None)
+@functools.cache
 def _fx_graph_functions():
 fns = {
 operator.abs,

@@ -285,7 +285,7 @@ class BuiltinVariable(VariableTracker):
 return fns
 @staticmethod
-@functools.lru_cache(None)
+@functools.cache
 def _binops() -> dict[
 Callable[..., object], tuple[list[str], Callable[..., object]]
 ]:

@@ -324,7 +324,7 @@ class BuiltinVariable(VariableTracker):
 return fns
 @staticmethod
-@functools.lru_cache(None)
+@functools.cache
 def _binop_handlers():
 # Multiple dispatch mechanism defining custom binop behavior for certain type
 # combinations. Handlers are attempted in order, and will be used if the type checks

@@ -1805,7 +1805,7 @@ class PolyfilledFunctionVariable(VariableTracker):
 }
 @classmethod
-@functools.lru_cache(None)
+@functools.cache
 def _get_polyfill_handlers(cls) -> dict[Callable[..., Any], types.FunctionType]:
 return {}

@@ -900,7 +900,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
 self.nn_module_stack_source = source
 @staticmethod
-@functools.lru_cache(None)
+@functools.cache
 def _nn_module_method_ids():
 # Allow __setattr__ to fall through to base class handler
 supported = {torch.nn.Module.__setattr__, torch.nn.Module.__init__}
@@ -1040,7 +1040,7 @@ class TensorVariable(VariableTracker):
 return wrap_fx_proxy(tx, proxy)
 @staticmethod
-@functools.lru_cache(None)
+@functools.cache
 def _warn_capture_scalar_outputs():
 user_stack = torch._guards.TracingContext.extract_stack()
 user_stack_formatted = "".join(traceback.format_list(user_stack))

@@ -170,7 +170,7 @@ constant_fold_functions_need_guards = dict.fromkeys(constant_fold_functions_need
 constant_fold_functions = dict.fromkeys(constant_fold_functions)
-@functools.lru_cache(None)
+@functools.cache
 def tracing_state_functions() -> dict[Callable[[], Any], Optional[bool]]:
 # Defined as a function to avoid circular import like torch.onnx
 return {

@@ -197,7 +197,7 @@ dispatch_key_set_functions = {
 }
-@functools.lru_cache(None)
+@functools.cache
 def get_overridable_functions():
 from itertools import chain

@@ -432,7 +432,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
 return self.value
 @staticmethod
-@functools.lru_cache(None)
+@functools.cache
 def _get_handlers():
 """Build a dict from function -> method to handle it so that we are O(1)
 in terms of the number of function with special handling."""

@@ -199,7 +199,7 @@ banned_attrs = [
 ]
-@functools.lru_cache(None)
+@functools.cache
 def get_prev_stack_var_name():
 from ..bytecode_transformation import unique_id

@@ -149,7 +149,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
 return f"{self.__class__.__name__}({self.value})"
 @staticmethod
-@functools.lru_cache(None)
+@functools.cache
 def _constant_fold_classes():
 return {
 torch.device,

@@ -159,7 +159,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
 }
 @staticmethod
-@functools.lru_cache(None)
+@functools.cache
 def _in_graph_classes():
 _in_graph_class_list = {
 torch.Tensor,

@@ -177,7 +177,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
 return set(tensortype_to_dtype.keys()) | _in_graph_class_list
 @staticmethod
-@functools.lru_cache(None)
+@functools.cache
 def supported_c_new_functions():
 exceptions = [
 getattr(builtins, name).__new__

@@ -843,7 +843,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
 )
 @staticmethod
-@functools.lru_cache(None)
+@functools.cache
 def _supported_random_functions():
 fns = {
 random.random,

@@ -1229,7 +1229,7 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
 remote_cache.put(key, cache_data)
 @staticmethod
-@functools.lru_cache(None)
+@functools.cache
 def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]:
 """
 Attempts to load the remote cache, returns None on error.
@@ -1052,7 +1052,7 @@ def _count_ops(graph: fx.Graph):
 log.info("%s", sorted(cnt.items(), key=operator.itemgetter(1), reverse=True))
-@functools.lru_cache(None)
+@functools.cache
 def pointwise_ops():
 ops = []
 for attr_name in dir(torch.ops.aten):

@@ -874,7 +874,7 @@ class CppBenchmarkRequest(CPUDeviceBenchmarkMixin, BenchmarkRequest):
 return f"{self.kernel_name=}"
-@functools.lru_cache(None)
+@functools.cache
 def get_tuning_process_pool() -> TuningProcessPool:
 pool = TuningProcessPool()
 atexit.register(pool.shutdown)

@@ -176,7 +176,7 @@ def get_kernel_bin_format(device: str) -> str:
 return ""
-@functools.lru_cache(None)
+@functools.cache
 def get_global_cache_path_impl(global_cache_dir: str) -> Optional[Path]:
 return (
 Path(os.path.join(global_cache_dir, CacheBase.get_system()["hash"]))

@@ -187,7 +187,7 @@ def get_global_cache_path_impl(global_cache_dir: str) -> Optional[Path]:
 class CacheBase:
 @staticmethod
-@functools.lru_cache(None)
+@functools.cache
 def get_system() -> dict[str, Any]:
 try:
 from triton.compiler.compiler import triton_key

@@ -226,7 +226,7 @@ class CacheBase:
 @staticmethod
 @clear_on_fresh_inductor_cache
-@functools.lru_cache(None)
+@functools.cache
 def get_local_cache_path() -> Path:
 return Path(os.path.join(cache_dir(), "cache", CacheBase.get_system()["hash"]))

@@ -280,7 +280,7 @@ class LocalCache(CacheBase):
 class PersistentCache(CacheBase):
-@functools.lru_cache(None) # noqa: B019
+@functools.cache # noqa: B019
 def get_global_cache(self) -> dict[str, Any]:
 global_cache_path = self.get_global_cache_path()
 if global_cache_path is None or not global_cache_path.is_file():

@@ -1567,7 +1567,7 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
 pass
-@functools.lru_cache(None)
+@functools.cache
 def split_aot_inductor_output_path(path: str) -> tuple[str, str]:
 """Returns the path where the AOT Inductor compiled kernels are stored."""
 if path.endswith(".so"):

@@ -2283,7 +2283,7 @@ _HEADER_DIR = os.path.join(default_cache_dir(), "precompiled_headers")
 _HEADER_LOCK_DIR = os.path.join(_HEADER_DIR, "locks")
-@functools.lru_cache(None)
+@functools.cache
 def _precompile_header(
 header: str,
 hashable_cmd_line: str,

@@ -2969,7 +2969,7 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
 return glue_code
 @classmethod
-@functools.lru_cache(None)
+@functools.cache
 def config_hash(cls) -> str:
 command_gen = CppBuilder(
 name="O",

@@ -3013,7 +3013,7 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
 raise RuntimeError(errmsg)
 @staticmethod
-@functools.lru_cache(None)
+@functools.cache
 def find_libautoschedule(name: str) -> str:
 sofile = f"libautoschedule_{name.lower()}.so"
 if "HALIDE_LIB" in os.environ:
@@ -3026,7 +3026,7 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
 return HalideCodeCache._search_for_file(sofile, errmsg)
 @staticmethod
-@functools.lru_cache(None)
+@functools.cache
 def find_header(name: str) -> str:
 if "HALIDE_INCLUDE" in os.environ:
 path = os.path.join(os.environ["HALIDE_INCLUDE"], name)

@@ -3300,7 +3300,7 @@ class PyCodeCache:
 cls.modules_no_attr.clear()
 @classmethod
-@functools.lru_cache(None)
+@functools.cache
 def stack_frames_for_code(
 cls, path: str, lineno: int
 ) -> Optional[list[dict[str, Any]]]:

@@ -441,7 +441,7 @@ def get_wrapper_codegen_for_device(
 return None
-@functools.lru_cache(None)
+@functools.cache
 def init_backend_registration() -> None:
 from .cpp import CppScheduling
 from .cpp_wrapper_cpu import CppWrapperCpu

@@ -2223,7 +2223,7 @@ class OptimizationContext:
 ops_name: str = ""
-@functools.lru_cache(None)
+@functools.cache
 def jinja2_env() -> Any:
 try:
 import jinja2

@@ -86,7 +86,7 @@ from .cpp_utils import (
 _IS_WINDOWS = sys.platform == "win32"
-@functools.lru_cache(None)
+@functools.cache
 def get_export_declaration():
 return "__declspec(dllexport)" if _IS_WINDOWS else ""

@@ -274,12 +274,12 @@ class CppWrapperCpu(PythonWrapperCodegen):
 ):
 code = self.prefix
-@functools.lru_cache(None)
+@functools.cache
 def sizeof(name):
 self.codegen_input_size_var_decl(code, name)
 return f"{name}_size"
-@functools.lru_cache(None)
+@functools.cache
 def strideof(name):
 self.codegen_input_stride_var_decl(code, name)
 return f"{name}_stride"

@@ -1469,7 +1469,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
 self.used_cached_memory_formats.add(memory_format_str)
 return f"cached_torch_memory_format_{memory_format_str}"
-@functools.lru_cache(None) # noqa: B019
+@functools.cache # noqa: B019
 def codegen_int_array_var(
 self,
 int_array: str,

@@ -40,6 +40,6 @@ def get_cuda_version() -> Optional[str]:
 return None
-@functools.lru_cache(None)
+@functools.cache
 def nvcc_exist(nvcc_path: Optional[str] = "nvcc") -> bool:
 return nvcc_path is not None and shutil.which(nvcc_path) is not None

@@ -48,7 +48,7 @@ def _generate_config_filename(request_key: str) -> str:
 @clear_on_fresh_inductor_cache
-@functools.lru_cache(None)
+@functools.cache
 def maybe_fetch_ops() -> Optional[list[Any]]:
 """
 Fetch ops from databases.

@@ -5,7 +5,7 @@ import torch
 from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch
-@functools.lru_cache(None)
+@functools.cache
 def gen_cutlass_presets() -> dict[int, dict[str, list[str]]]:
 """
 Generate cutlass presets for the given CUDA arch.
@@ -56,7 +56,7 @@ def _rename_cutlass_import(content: str, cutlass_modules: list[str]) -> str:
 return content
-@functools.lru_cache(None)
+@functools.cache
 def try_import_cutlass() -> bool:
 """
 We want to support three ways of passing in CUTLASS:

@@ -251,7 +251,7 @@ class CUTLASSArgs:
 @clear_on_fresh_inductor_cache
-@functools.lru_cache(None)
+@functools.cache
 def _gen_ops_cached(arch, version) -> dict[Any, Any]:
 # Note: Cache needs to be specific for cuda architecture and version

@@ -95,7 +95,7 @@ class CKTileGemmOperation:
 return asdict(self).items()
-@functools.lru_cache(None)
+@functools.cache
 def ops():
 """
 Generate the supported instance dataclasses

@@ -416,7 +416,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
 self.code_hash: Optional[str] = None
 # define this in a closure to make cache local to object
-@functools.lru_cache(None)
+@functools.cache
 def simplify_indexing(index: sympy.Expr):
 index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges())
 for tree in self.range_trees:

@@ -1311,7 +1311,7 @@ class TritonKernelOverrides(TritonOverrides):
 self._setup_libdevice_routing()
 @classmethod
-@functools.lru_cache(None)
+@functools.cache
 def _setup_libdevice_routing(cls):
 """Set up routing to libdevice implementations for fp64 inputs."""

@@ -915,7 +915,7 @@ class PythonWrapperCodegen(CodeGen):
 self.write_get_raw_stream
 )
-@functools.lru_cache(None)
+@functools.cache
 def add_import_once(line: str) -> None:
 self.imports.writeline(line)
 if config.triton.autotune_at_compile_time:

@@ -1625,12 +1625,12 @@ class PythonWrapperCodegen(CodeGen):
 ):
 code = self.prefix
-@functools.lru_cache(None)
+@functools.cache
 def sizeof(name):
 code.writeline(f"{name}_size = {name}.size()")
 return f"{name}_size"
-@functools.lru_cache(None)
+@functools.cache
 def strideof(name):
 code.writeline(f"{name}_stride = {name}.stride()")
 return f"{name}_stride"

@@ -238,12 +238,12 @@ def record_original_output_strides(gm: GraphModule) -> None:
 output_node.meta["original_output_strides"] = output_strides
-@functools.lru_cache(None)
+@functools.cache
 def _step_logger() -> Callable[..., None]:
 return dynamo_logging.get_step_logger(log)
-@functools.lru_cache(None)
+@functools.cache
 def _warn_tf32_disabled() -> None:
 if (
 torch.cuda.is_available()

@@ -614,7 +614,7 @@ class _OutOfProcessFxCompile(_SerializedFxCompile):
 # And forward our collected logs. The cache is cleared when the outer
 # function exits.
-@functools.lru_cache(None)
+@functools.cache
 def getLogger(name: str) -> logging.Logger:
 return logging.getLogger(name)

@@ -79,7 +79,7 @@ def reset_counters() -> None:
 call_counter_debug_info.clear()
-@functools.lru_cache(None)
+@functools.cache
 def get_env_val(env_str: str) -> Optional[str]:
 return os.environ.get(env_str, None)
@@ -127,7 +127,7 @@ def install_gcc_via_conda() -> str:
 return cxx_path
-@functools.lru_cache(None)
+@functools.cache
 def check_compiler_exist_windows(compiler: str) -> None:
 """
 Check if compiler is ready, in case end user not activate MSVC environment.

@@ -200,13 +200,13 @@ def convert_cubin_to_obj(
 return obj_file
-@functools.lru_cache(None)
+@functools.cache
 def _is_apple_clang(cpp_compiler: str) -> bool:
 version_string = subprocess.check_output([cpp_compiler, "--version"]).decode("utf8")
 return "Apple" in version_string.splitlines()[0]
-@functools.lru_cache(None)
+@functools.cache
 def _is_clang(cpp_compiler: str) -> bool:
 # Mac OS apple clang maybe named as gcc, need check compiler info.
 if sys.platform == "darwin":

@@ -221,7 +221,7 @@ def _is_clang(cpp_compiler: str) -> bool:
 return bool(re.search(r"(clang|clang\+\+)", cpp_compiler))
-@functools.lru_cache(None)
+@functools.cache
 def _is_gcc(cpp_compiler: str) -> bool:
 # Since "clang++" ends with "g++", the regex match below would validate on it.
 if _is_clang(cpp_compiler):

@@ -229,7 +229,7 @@ def _is_gcc(cpp_compiler: str) -> bool:
 return bool(re.search(r"(gcc|g\+\+|gnu-c\+\+)", cpp_compiler))
-@functools.lru_cache(None)
+@functools.cache
 def _is_msvc_cl(cpp_compiler: str) -> bool:
 if not _IS_WINDOWS:
 return False

@@ -247,7 +247,7 @@ def _is_msvc_cl(cpp_compiler: str) -> bool:
 return False
-@functools.lru_cache(None)
+@functools.cache
 def _is_intel_compiler(cpp_compiler: str) -> bool:
 def _check_minimal_version(compiler_version: TorchVersion) -> None:
 """

@@ -291,32 +291,32 @@ def _is_intel_compiler(cpp_compiler: str) -> bool:
 return False
-@functools.lru_cache(None)
+@functools.cache
 def is_gcc() -> bool:
 return _is_gcc(get_cpp_compiler())
-@functools.lru_cache(None)
+@functools.cache
 def is_clang() -> bool:
 return _is_clang(get_cpp_compiler())
-@functools.lru_cache(None)
+@functools.cache
 def is_intel_compiler() -> bool:
 return _is_intel_compiler(get_cpp_compiler())
-@functools.lru_cache(None)
+@functools.cache
 def is_apple_clang() -> bool:
 return _is_apple_clang(get_cpp_compiler())
-@functools.lru_cache(None)
+@functools.cache
 def is_msvc_cl() -> bool:
 return _is_msvc_cl(get_cpp_compiler())
-@functools.lru_cache(None)
+@functools.cache
 def get_compiler_version_info(compiler: str) -> str:
 env = os.environ.copy()
 env["LC_ALL"] = "C" # Don't localize output

@@ -885,7 +885,7 @@ def _get_python_related_args() -> tuple[list[str], list[str]]:
 return python_include_dirs, python_lib_path
-@functools.lru_cache(None)
+@functools.cache
 def is_conda_llvm_openmp_installed() -> bool:
 try:
 command = "conda list llvm-openmp --json"

@@ -895,7 +895,7 @@ def is_conda_llvm_openmp_installed() -> bool:
 return False
-@functools.lru_cache(None)
+@functools.cache
 def homebrew_libomp() -> tuple[bool, str]:
 try:
 # check if `brew` is installed

@@ -916,7 +916,7 @@ def homebrew_libomp() -> tuple[bool, str]:
 return False, ""
-@functools.lru_cache(None)
+@functools.cache
 def perload_clang_libomp_win(cpp_compiler: str, omp_name: str) -> None:
 try:
 output = subprocess.check_output([cpp_compiler, "-print-file-name=bin"]).decode(

@@ -930,7 +930,7 @@ def perload_clang_libomp_win(cpp_compiler: str, omp_name: str) -> None:
 pass
-@functools.lru_cache(None)
+@functools.cache
 def perload_icx_libomp_win(cpp_compiler: str) -> None:
 def _load_icx_built_in_lib_by_name(cpp_compiler: str, lib_name: str) -> bool:
 try:
@@ -146,7 +146,7 @@ cdll.LoadLibrary("__lib_path__")
 def __bool__(self) -> bool:
 return self.__bool__impl(config.cpp.vec_isa_ok)
-@functools.lru_cache(None) # noqa: B019
+@functools.cache # noqa: B019
 def __bool__impl(self, vec_isa_ok) -> bool:
 if vec_isa_ok is not None:
 return vec_isa_ok

@@ -241,7 +241,7 @@ extern "C" void __amx_chk_kernel() {
 }
 """
-@functools.lru_cache(None) # noqa: B019
+@functools.cache # noqa: B019
 def __bool__(self) -> bool:
 if super().__bool__():
 if config.is_fbcode():

@@ -380,7 +380,7 @@ def get_isa_from_cpu_capability(
 # Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content
 # might have too much redundant content that is useless for ISA check. Hence,
 # we only cache some key isa information.
-@functools.lru_cache(None)
+@functools.cache
 def valid_vec_isa_list() -> list[VecISA]:
 isa_list: list[VecISA] = []
 if sys.platform == "darwin" and platform.processor() == "arm":

@@ -50,7 +50,7 @@ BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"])
 GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"]
-@functools.lru_cache(None)
+@functools.cache
 def has_dot() -> bool:
 return shutil.which("dot") is not None

@@ -856,7 +856,7 @@ def miopen_batch_norm(
 )
-@functools.lru_cache(None)
+@functools.cache
 def fast_random_decomps() -> dict[Any, Callable[..., Any]]:
 return {**decompositions, **extra_random_decomps}

@@ -29,7 +29,7 @@ DTypeArg = Union[DTypeVar, torch.types.Number, str, OpsValue]
 # So first decompose CSEVars -> tuple before calling this
-@functools.lru_cache(None)
+@functools.cache
 def get_promoted_dtype(
 *args: Sequence[tuple[torch.dtype, bool]],
 type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND] = None,

@@ -83,7 +83,7 @@ def recover_original_precision_folded_computation_ops(gm):
 _binary_ops = [aten.add.Tensor, aten.sub.Tensor, aten.mul.Tensor, aten.div.Tensor]
-@functools.lru_cache(None)
+@functools.cache
 def binary_folding_init():
 _conv_args = [Arg() for _ in range(9)]
 _addmm_args = [Arg() for _ in range(3)]

@@ -119,7 +119,7 @@ def register_binary_folding_pattern(pattern, extra_check=_return_true):
 )
-@functools.lru_cache(None)
+@functools.cache
 def addmm_patterns_init():
 device = next(
 (gpu for gpu in GPU_TYPES if getattr(torch, gpu).is_available()), "cpu"

@@ -956,7 +956,7 @@ def _get_sfdp_patterns():
 )
-@functools.lru_cache(None)
+@functools.cache
 def _sfdp_init():
 for key, register_replacement_kwargs in _get_sfdp_patterns():
 gen_register_replacement(key, **register_replacement_kwargs)

@@ -12,7 +12,7 @@ from ..pattern_matcher import fwd_only, register_replacement
 aten = torch.ops.aten
-@functools.lru_cache(None)
+@functools.cache
 def _misc_patterns_init():
 from .joint_graph import patterns as joint_graph_patterns
 from .post_grad import pass_patterns as post_grad_patterns_all
@@ -1398,7 +1398,7 @@ if torch._C._has_mkldnn:
 user_node.replace_all_uses_with(node)
 gm.graph.erase_node(user_node)
-@functools.lru_cache(None)
+@functools.cache
 def _mkldnn_fusion_init():
 # TODO: aarch64: enable op fusion for acl once it supports fused operators. Disabling it for now.
 # Otherwise even the matmul or innerproduct can not be accelerated with acl

@@ -1414,7 +1414,7 @@ if torch._C._has_mkldnn:
 _register_quantization_lowerings()
 _register_woq_lowerings()
-@functools.lru_cache(None)
+@functools.cache
 def _mkldnn_weight_pack_init():
 if torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available():
 _register_weight_pack_pass()

@@ -247,7 +247,7 @@ def is_mm_compute_bound(M: int, K: int, N: int, dtype: torch.dtype) -> bool:
 return arithmetic_intensity > machine_balance
-@functools.lru_cache(None)
+@functools.cache
 def get_pad_cache() -> torch._inductor.codecache.LocalCache:
 return torch._inductor.codecache.LocalCache()

@@ -851,7 +851,7 @@ def bmm_replace(mat1: Tensor, mat2: Tensor) -> Tensor:
 )
-@functools.lru_cache(None)
+@functools.cache
 def _pad_mm_init() -> None:
 from .joint_graph import patterns

@@ -3478,7 +3478,7 @@ def _register_qlinear_binary_fusion():
 )
-@functools.lru_cache(None)
+@functools.cache
 def _register_quantization_weight_pack_pass():
 # Step 1: Dequant promotion for int8-mixed-fp32/bf16
 _register_dequant_promotion()

@@ -511,7 +511,7 @@ scaled_mm_device_tma_template = TritonTemplate(
 # prevent duplication registration of extern functions
-@functools.lru_cache(None)
+@functools.cache
 def lazy_register_extern_choice(fn):
 return ExternKernelChoice(fn)

@@ -1175,7 +1175,7 @@ def tuned_scaled_mm(
 return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout)
-@functools.lru_cache(None)
+@functools.cache
 def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool:
 props = torch.cuda.get_device_properties(index or 0)
 return props.major <= 7

@@ -36,7 +36,7 @@ T = TypeVar("T")
 class InterpreterShim(torch.fx.Interpreter):
 @staticmethod
-@functools.lru_cache(None)
+@functools.cache
 def _dummy_gm():
 return torch.fx.symbolic_trace(identity)

@@ -1879,7 +1879,7 @@ def fallback_handler(kernel, add_to_fallback_set=True):
 return handler
-@functools.lru_cache(None)
+@functools.cache
 def _warn_complex_not_supported():
 warnings.warn(
 "Torchinductor does not support code generation for complex operators. Performance may be worse than eager."

@@ -1010,7 +1010,7 @@ class PatternPrettyPrinter:
 self.memoized_objs_pp: dict[PatternExpr, str] = {}
 @staticmethod
-@functools.lru_cache(None)
+@functools.cache
 def run(obj: PatternExpr, output_name: str = "output") -> str:
 """
 Serializes obj to python code with obj written out to `output_name`
@@ -2195,7 +2195,7 @@ def stable_topological_sort(graph: torch.fx.Graph) -> None:
 def init_once_fakemode(fn: Callable[..., Any]) -> Callable[[], Any]:
 """Wrapper around lazy init functions in fx_passes/"""
-@functools.lru_cache(None)
+@functools.cache
 @functools.wraps(fn)
 def lazy_init() -> Any:
 counters_ref = counters["inductor"].copy()

@@ -34,7 +34,7 @@ def _reload_python_module(
 return mod
-@functools.lru_cache(None)
+@functools.cache
 def _set_triton_ptxas_path() -> None:
 if os.environ.get("TRITON_PTXAS_PATH") is not None:
 return

@@ -136,7 +136,7 @@ class DeviceProperties(typing.NamedTuple):
 warp_size: Optional[int] = None
 @classmethod
-@functools.lru_cache(None)
+@functools.cache
 def create(cls, device) -> DeviceProperties:
 import torch
 from torch._dynamo.device_interface import get_interface_for_device

@@ -1124,7 +1124,7 @@ class TritonTemplateKernel(TritonKernel):
 ]
-@functools.lru_cache(None)
+@functools.cache
 def _jinja2_env():
 try:
 import jinja2

@@ -1726,7 +1726,7 @@ class ExternKernelChoice:
 def call_name(self):
 return f"extern_kernels.{self.name}"
-@functools.lru_cache(None) # noqa: B019
+@functools.cache # noqa: B019
 def hash_key(self):
 fn = self.to_callable()
 parts = [

@@ -1933,7 +1933,7 @@ class ExternKernelCaller(ChoiceCaller):
 return f"extern_{self.choice.name}"
-@functools.lru_cache(None)
+@functools.cache
 def get_mm_log_filename() -> Optional[str]:
 mm_file_name = os.environ.get("TORCHINDUCTOR_MM_LOGGING_FILE", None)
 if not mm_file_name:

@@ -2052,7 +2052,7 @@ class NoValidChoicesError(RuntimeError):
 pass
-@functools.lru_cache(None)
+@functools.cache
 def get_num_workers() -> int:
 if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ:
 return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"])

@@ -2194,7 +2194,7 @@ class AlgorithmSelectorCache(PersistentCache):
 # CUDATemplateCaller still needs to go through autotuning process to retrieve workspace size.
 return choices[0].output_node()
-@functools.lru_cache(None)
+@functools.cache
 def make_benchmark_fn():
 return self.make_benchmark_fn(choices, input_nodes, layout, input_gen_fns)

@@ -2506,7 +2506,7 @@ class AlgorithmSelectorCache(PersistentCache):
 future.add_done_callback(on_complete)
 futures[future] = c
-@functools.lru_cache(None)
+@functools.cache
 @restore_stdout_stderr()
 def wait_on_futures():
 log.debug("Waiting on futures")

@@ -91,7 +91,7 @@ T = TypeVar("T")
 # defines here before import torch._dynamo is for avoiding circular import
 # when get_gpu_type is imported from dynamo
-@functools.lru_cache(None)
+@functools.cache
 def get_gpu_type() -> str:
 avail_gpus = [x for x in GPU_TYPES if getattr(torch, x).is_available()]
 assert len(avail_gpus) <= 1
@@ -338,7 +338,7 @@ def do_bench_using_profiling(
 return res
-@functools.lru_cache(None)
+@functools.cache
 def has_torchvision_roi_align() -> bool:
 try:
 from torchvision.ops import roi_align # noqa: F401

@@ -1384,7 +1384,7 @@ class DelayReplaceLine(DeferredLineBase):
 return DelayReplaceLine(self.key, self.value_fn, line)
-@functools.lru_cache(None)
+@functools.cache
 def is_big_gpu(index_or_device: Union[int, torch.device] = 0) -> bool:
 if isinstance(index_or_device, torch.device):
 device = index_or_device

@@ -1599,7 +1599,7 @@ def use_decompose_k_choice(m: _IntLike, n: _IntLike, k: _IntLike) -> bool:
 )
-@functools.lru_cache(None)
+@functools.cache
 def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]:
 # If k is a sympy expression, we can't do any splitting
 if isinstance(k, sympy.Expr) and not k.is_number:

@@ -1652,12 +1652,12 @@ def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]:
 return best_splits[:k_splits_limit]
-@functools.lru_cache(None)
+@functools.cache
 def _rocm_native_device_arch_name(device: str) -> str:
 return torch.cuda.get_device_properties(device).gcnArchName
-@functools.lru_cache(None)
+@functools.cache
 def try_import_ck_lib() -> tuple[
 Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], type[Any]
 ]:

@@ -2106,7 +2106,7 @@ def parallel_num_threads() -> int:
 return threads
-@functools.lru_cache(None)
+@functools.cache
 def get_backend_num_stages() -> int:
 from .runtime.triton_helpers import get_backend_options

@@ -2114,7 +2114,7 @@ def get_backend_num_stages() -> int:
 return options.get("num_stages", 2 if torch.version.hip else 3)
-@functools.lru_cache(None)
+@functools.cache
 def get_device_tflops(dtype: torch.dtype) -> int:
 from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops

@@ -2142,7 +2142,7 @@ def get_device_tflops(dtype: torch.dtype) -> int:
 return get_max_simd_tflops(torch.float32)
-@functools.lru_cache(None)
+@functools.cache
 def get_gpu_dram_gbps() -> int:
 from triton.testing import get_dram_gbps

@@ -2862,7 +2862,7 @@ def is_same_mkldnn_tensor(data: torch.Tensor, value: torch.Tensor) -> bool:
 )
-@functools.lru_cache(None)
+@functools.cache
 def boolean_ops() -> tuple[str, ...]:
 return (
 "isinf",

@@ -3051,7 +3051,7 @@ class TritonAttrsDescriptorVersion(enum.Enum):
 V4_DICT = 4 # a raw dict
-@functools.lru_cache(None)
+@functools.cache
 def get_triton_attrs_descriptor_version() -> TritonAttrsDescriptorVersion:
 if importlib.util.find_spec("triton") is None:
 return TritonAttrsDescriptorVersion.V0_NO_TRITON

@@ -1143,7 +1143,7 @@ class LazyTraceHandler(logging.StreamHandler):
 super().emit(record)
-@functools.lru_cache(None)
+@functools.cache
 def warning_once(logger_obj, *args, **kwargs) -> None:
 """
 This function is similar to `logger.warning()`, but will emit the warning with the same message only once
@@ -15,7 +15,7 @@ import torch.overrides
 from torch._prims_common import torch_function_passthrough
-@functools.lru_cache(None)
+@functools.cache
 def torch_to_refs_map():
 """
 Mapping of torch API functions to torch._refs functions.

@@ -70,7 +70,7 @@ def torch_to_refs_map():
 return r
-@functools.lru_cache(None)
+@functools.cache
 def all_prims():
 """
 Set of all prim functions, e.g., torch._prims.add in all_prims()

@@ -114,7 +114,7 @@ def contains_tensor_types(type):
 )
-@functools.lru_cache(None)
+@functools.cache
 def _is_tensor_constructor(func: OpOverload):
 assert isinstance(func, OpOverload)
 schema = func._schema

@@ -1077,7 +1077,7 @@ def fast_detach(fake_mode, x):
 return FakeTensor(fake_mode, out, x.device, real_tensor=x.real_tensor)
-@functools.lru_cache(None)
+@functools.cache
 def get_fast_op_impls():
 import torch._refs

@@ -233,7 +233,7 @@ def maybe_get_fake_mode(t: object) -> Optional[FakeTensorMode]:
 return None
-@functools.lru_cache(None)
+@functools.cache
 def get_schema_info(func: OpOverload) -> torch._C._SchemaInfo:
 return torch._C._SchemaInfo(func._schema)

@@ -243,7 +243,7 @@ def get_schema_info(func: OpOverload) -> torch._C._SchemaInfo:
 # torch/_decomp/decompositions.py.
 # decomps are used for aot autograd tracing so we would like to unify on their
 # implementation and add additional testing to them
-@functools.lru_cache(None)
+@functools.cache
 def torch_decomp_decompositions(func: OpOverload) -> bool:
 from torch._decomp import decomposition_table

@@ -511,7 +511,7 @@ class FakeTensorConverter:
 return out
-@functools.lru_cache(None)
+@functools.cache
 def init_gpu_context(device: torch.device) -> None:
 # Backward will error with cuda Fake Tensors if no cuda tensors have been initialized first
 if torch.cuda.is_available() or torch.xpu.is_available():

@@ -210,7 +210,7 @@ def is_fb_unit_test() -> bool:
 return False
-@functools.lru_cache(None)
+@functools.cache
 def max_clock_rate():
 if not torch.version.hip:
 from triton.testing import nvsmi

@@ -7379,7 +7379,7 @@ class ShapeEnv:
 # Don't track this one. (Because this cache is inside this function the
 # cache only lasts for the invocation of this function call)
-@functools.lru_cache(None)
+@functools.cache
 def compute_concrete_val() -> sympy.Basic:
 if hint is None:
 # This is only ever called for expressions WITHOUT unbacked

@@ -11,7 +11,7 @@ if TYPE_CHECKING:
 # TODO: Remove after https://github.com/huggingface/safetensors/pull/318
-@functools.lru_cache(None)
+@functools.cache
 def has_safetensors_and_transformers():
 try:
 # safetensors is not an exporter requirement, but needed for some huggingface models
@@ -98,7 +98,7 @@ def _disable_user_warnings(
 return wrapper
-@functools.lru_cache(None)
+@functools.cache
 @_disable_user_warnings
 def get_ignored_functions() -> set[Callable]:
 """

@@ -378,7 +378,7 @@ def get_ignored_functions() -> set[Callable]:
 }
-@functools.lru_cache(None)
+@functools.cache
 def get_default_nowrap_functions() -> set[Callable]:
 """
 Return public functions that do not wrap in a subclass when invoked by

@@ -404,7 +404,7 @@ def get_default_nowrap_functions() -> set[Callable]:
 }
-@functools.lru_cache(None)
+@functools.cache
 @_disable_user_warnings
 def get_testing_overrides() -> dict[Callable, Callable]:
 """Return a dict containing dummy overrides for all overridable functions

@@ -1808,7 +1808,7 @@ has_torch_function_variadic = _add_docstr(
 )
-@functools.lru_cache(None)
+@functools.cache
 def _get_overridable_functions() -> tuple[
 dict[Any, list[Callable]], dict[Callable, str]
 ]:

@@ -1929,7 +1929,7 @@ def resolve_name(f):
 return _get_overridable_functions()[1].get(f)
-@functools.lru_cache(None)
+@functools.cache
 def _get_tensor_methods() -> set[Callable]:
 """Returns a set of the overridable methods on ``torch.Tensor``"""
 overridable_funcs = get_overridable_functions()

@@ -3,7 +3,7 @@ import functools
 from torch.utils._triton import has_triton
-@functools.lru_cache(None)
+@functools.cache
 def has_helion_package() -> bool:
 try:
 import helion # type: ignore[import-untyped, import-not-found] # noqa: F401

@@ -12,6 +12,6 @@ def has_helion_package() -> bool:
 return True
-@functools.lru_cache(None)
+@functools.cache
 def has_helion() -> bool:
 return has_helion_package() and has_triton()

@@ -51,7 +51,7 @@ log = logging.getLogger(__name__)
 # TODO: Dedupe this with SYMPY_INTERP
-@functools.lru_cache(None)
+@functools.cache
 def handlers():
 # TODO add CeilDiv (it doesn't appear in the index_expr)

@@ -3,7 +3,7 @@ import hashlib
 from typing import Any
-@functools.lru_cache(None)
+@functools.cache
 def has_triton_package() -> bool:
 try:
 from triton.compiler.compiler import triton_key

@@ -15,7 +15,7 @@ def has_triton_package() -> bool:
 return False
-@functools.lru_cache(None)
+@functools.cache
 def _device_supports_tma() -> bool:
 import torch
@@ -26,7 +26,7 @@ def _device_supports_tma() -> bool:
 )
-@functools.lru_cache(None)
+@functools.cache
 def has_triton_experimental_host_tma() -> bool:
 if has_triton_package():
 if _device_supports_tma():

@@ -43,7 +43,7 @@ def has_triton_experimental_host_tma() -> bool:
 return False
-@functools.lru_cache(None)
+@functools.cache
 def has_triton_tensor_descriptor_host_tma() -> bool:
 if has_triton_package():
 if _device_supports_tma():

@@ -59,12 +59,12 @@ def has_triton_tensor_descriptor_host_tma() -> bool:
 return False
-@functools.lru_cache(None)
+@functools.cache
 def has_triton_tma() -> bool:
 return has_triton_tensor_descriptor_host_tma() or has_triton_experimental_host_tma()
-@functools.lru_cache(None)
+@functools.cache
 def has_triton_tma_device() -> bool:
 if has_triton_package():
 import torch

@@ -87,7 +87,7 @@ def has_triton_tma_device() -> bool:
 return False
-@functools.lru_cache(None)
+@functools.cache
 def has_triton() -> bool:
 if not has_triton_package():
 return False

@@ -121,7 +121,7 @@ def has_triton() -> bool:
 return is_device_compatible_with_triton()
-@functools.lru_cache(None)
+@functools.cache
 def triton_backend() -> Any:
 from triton.compiler.compiler import make_backend
 from triton.runtime.driver import driver

@@ -130,7 +130,7 @@ def triton_backend() -> Any:
 return make_backend(target)
-@functools.lru_cache(None)
+@functools.cache
 def triton_hash_with_backend() -> str:
 from triton.compiler.compiler import triton_key