from __future__ import annotations

import pickle
from abc import ABC, abstractmethod
from ast import literal_eval
from functools import cached_property
from hashlib import sha256
from os import getenv
from pathlib import Path
from tempfile import gettempdir
from threading import Lock
from typing import Any, Generic, TYPE_CHECKING, TypeVar

from typing_extensions import assert_never, override, Self

from torch.utils._filelock import FileLock

if TYPE_CHECKING:
    from concurrent.futures import Future, ThreadPoolExecutor


# TypeVars can't be recursive, so generic types that fall within
# Key or Value can't be bound properly; for example, Key should
# only take tuples of other Key types: tuple[Key, ...]. This is
# a known shortcoming of torch's typing.
Key = TypeVar("Key", str, int, tuple[Any, ...])
Value = TypeVar("Value", str, int, tuple[Any, ...], bytes, dict[Any, Any], list[Any])


class CacheError(ValueError):
    """
    Exception raised for errors encountered during cache operations.
    """


class Cache(ABC, Generic[Key, Value]):
    """
    Abstract base class for cache implementations.

    Provides the interface for cache operations.
    """

    @abstractmethod
    def get(self: Self, key: Key) -> Value | None:
        """
        Retrieve a value from the cache.

        Args:
            key (Key): The key to look up.

        Returns:
            Value | None: The cached value if present, else None.
        """

    @abstractmethod
    def insert(self: Self, key: Key, value: Value) -> bool:
        """
        Insert a value into the cache.

        Args:
            key (Key): The key to insert.
            value (Value): The value to associate with the key.

        Returns:
            bool: True if the value was inserted, False if the key already exists.
        """


class InMemoryCache(Cache[Key, Value]):
    """
    In-memory cache implementation using a dictionary and thread lock.
    """

    def __init__(self: Self) -> None:
        """
        Initialize an empty in-memory cache.
        """
        self._cache: dict[Key, Value] = {}
        self._lock: Lock = Lock()

    def get(self: Self, key: Key) -> Value | None:
        """
        Retrieve a value from the cache.

        Args:
            key (Key): The key to look up.

        Returns:
            Value | None: The cached value if present, else None.
        """
        with self._lock:
            if (value := self._cache.get(key)) is not None:
                return value
            return None

    def insert(self: Self, key: Key, value: Value) -> bool:
        """
        Insert a value into the cache.

        Args:
            key (Key): The key to insert.
            value (Value): The value to associate with the key.

        Returns:
            bool: True if the value was inserted, False if the key already exists.
        """
        with self._lock:
            if key in self._cache:
                # no overwrites for insert!
                return False
            self._cache[key] = value
            return True

    @classmethod
    def from_env_var(cls, env_var: str) -> Self:
        """
        Create an in-memory cache from an environment variable.

        Args:
            env_var (str): Name of the environment variable containing cache data.

        Returns:
            InMemoryCache: An instance populated from the environment variable.

        Raises:
            CacheError: If the environment variable is malformed or contains
                invalid data.
        """
        cache = cls()

        if (env_val := getenv(env_var)) is None:
            # env_var doesn't exist = empty cache
            return cache

        for kv_pair in env_val.split(";"):
            # ignore whitespace prefix/suffix
            kv_pair = kv_pair.strip()
            if not kv_pair:
                # kv_pair could be '' if env_val is '' or has a ';' suffix
                continue

            try:
                # keys and values should be comma separated
                key_bytes_repr, value_bytes_repr = kv_pair.split(",", 1)
            except ValueError as err:
                raise CacheError(
                    f"Malformed kv_pair {kv_pair!r} from env_var {env_var!r}, "
                    "likely missing comma separator."
                ) from err

            # ignore whitespace prefix/suffix, again
            key_bytes_repr, value_bytes_repr = (
                key_bytes_repr.strip(),
                value_bytes_repr.strip(),
            )

            try:
                # check that key_bytes_repr is an actual, legitimate encoding
                key_bytes = literal_eval(key_bytes_repr)
            except (ValueError, SyntaxError) as err:
                raise CacheError(
                    f"Malformed key_bytes_repr {key_bytes_repr!r} in kv_pair "
                    f"{kv_pair!r}, encoding is invalid."
                ) from err

            try:
                # check that value_bytes_repr is an actual, legitimate encoding
                value_bytes = literal_eval(value_bytes_repr)
            except (ValueError, SyntaxError) as err:
                raise CacheError(
                    f"Malformed value_bytes_repr {value_bytes_repr!r} in kv_pair "
                    f"{kv_pair!r}, encoding is invalid."
                ) from err

            try:
                key = pickle.loads(key_bytes)
            except pickle.UnpicklingError as err:
                raise CacheError(
                    f"Malformed key_bytes_repr {key_bytes_repr!r} in kv_pair "
                    f"{kv_pair!r}, not un-pickle-able."
                ) from err

            try:
                value = pickle.loads(value_bytes)
            except pickle.UnpicklingError as err:
                raise CacheError(
                    f"Malformed value_bytes_repr {value_bytes_repr!r} in kv_pair "
                    f"{kv_pair!r}, not un-pickle-able."
                ) from err

            # true duplicates, i.e. multiple occurrences of the same key => value
            # mapping, are ok and treated as a no-op; key duplicates with differing
            # values, i.e. key => value_1 and key => value_2 where value_1 != value_2,
            # are not okay since we don't allow overwriting cached values
            if (not cache.insert(key, value)) and (cache.get(key) != value):
                raise CacheError(
                    f"Multiple values for key {key!r} found, got "
                    f"{cache.get(key)!r} and {value!r}."
                )

        return cache
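
    # A minimal sketch of the env var wire format consumed above (illustrative
    # only): each kv_pair is "repr(pickled_key),repr(pickled_value)", pairs are
    # separated by ";", and literal_eval plus pickle.loads recovers the objects.
    # "MY_CACHE" is a hypothetical env var name used only for this example:
    #
    #     import os
    #     import pickle
    #
    #     os.environ["MY_CACHE"] = f"{pickle.dumps('k')!r},{pickle.dumps(42)!r}"
    #     cache = InMemoryCache.from_env_var("MY_CACHE")
    #     assert cache.get("k") == 42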

    @classmethod
    def from_file_path(cls, fpath: Path) -> Self:
        """
        Create an in-memory cache from a file path.

        Args:
            fpath (Path): Path to the file containing pickled cache data.

        Returns:
            InMemoryCache: An instance populated from the file.

        Raises:
            CacheError: If the file is not a valid pickled dictionary.
        """
        cache = cls()

        if not fpath.is_file():
            # fpath doesn't exist = empty cache
            return cache

        try:
            with open(fpath, "rb") as fp:
                cache._cache = pickle.load(fp)
        except pickle.UnpicklingError as err:
            raise CacheError(
                f"Failed to create cache from file path {fpath}, "
                "file contents are un-pickle-able."
            ) from err

        if not isinstance(cache._cache, dict):
            raise CacheError(
                f"Failed to create cache from file path {fpath}, "
                "file contents are not a pickled dict[Key, Value]."
            )

        return cache


class AsyncCache(Cache[Key, Value]):
    """
    Asynchronous cache implementation using ThreadPoolExecutor.
    """

    def get_async(
        self: Self, key: Key, executor: ThreadPoolExecutor
    ) -> Future[Value | None]:
        """
        Retrieve a value from the cache asynchronously.

        Args:
            key (Key): The key to look up.
            executor (ThreadPoolExecutor): Executor for async execution.

        Returns:
            Future[Value | None]: Future for the cached value or None.
        """
        return executor.submit(self.get, key)

    def insert_async(
        self: Self, key: Key, value: Value, executor: ThreadPoolExecutor
    ) -> Future[bool]:
        """
        Insert a value into the cache asynchronously.

        Args:
            key (Key): The key to insert.
            value (Value): The value to associate with the key.
            executor (ThreadPoolExecutor): Executor for async execution.

        Returns:
            Future[bool]: Future for the result of insertion.
        """
        return executor.submit(self.insert, key, value)
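

# A minimal usage sketch for the async wrappers (illustrative only; AsyncCache
# leaves get/insert abstract, so this uses the OnDiskCache subclass defined
# below, assumes the key is not already cached, and "my_cache" is a
# hypothetical directory name):
#
#     from concurrent.futures import ThreadPoolExecutor
#
#     cache: OnDiskCache[str, int] = OnDiskCache("my_cache")
#     with ThreadPoolExecutor(max_workers=2) as executor:
#         assert cache.insert_async("key", 42, executor).result()
#         assert cache.get_async("key", executor).result() == 42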
""" version: int = 0 def __init__(self: Self, name: str | None = None) -> None: """ Initialize an on-disk cache instance. Args: name (str | None, optional): The name of the cache directory. If None, defaults to "on_disk_cache". """ self.name = name or "on_disk_cache" @cached_property def base_dir(self: Self) -> Path: """ Get the base directory for the cache. Returns: Path: The base directory path for storing cache files. """ return Path(gettempdir()) / "cache" / self.name def _fpath_from_key(self: Self, key: Key) -> Path: """ Get the file path for a given key. Args: key (Key): The key to convert to a file path. Returns: Path: The file path for the key. Raises: CacheError: If the key is not pickle-able. """ try: return self.base_dir / sha256(pickle.dumps(key)).hexdigest()[:32] except (AttributeError, pickle.PicklingError) as err: raise CacheError( f"Failed to get fpath for key {key!r}, key is not pickle-able." ) from err # pyrefly: ignore [bad-argument-type] assert_never(key) def _flock_from_fpath(self: Self, fpath: Path) -> FileLock: """ Get a file lock for a given file path. Args: fpath (Path): The file path. Returns: FileLock: The file lock for the path. """ # fpath.name is a hex digest, meaning there are 16^4 potential values # for fpath.name[:4]; this is more than enough unique locks to not # cause additional overhead from shared locks and it also saves our # cache dir from becoming 50 percent locks # pyrefly: ignore [bad-return] return FileLock(str(fpath.parent / "locks" / fpath.name[:4]) + ".lock") @property def version_prefix(self: Self) -> bytes: """ Get the version prefix for the cache. Returns: bytes: The version prefix as bytes, derived from the cache version string. """ return sha256(str(OnDiskCache.version).encode()).digest()[:4] @override def get(self: Self, key: Key) -> Value | None: """ Retrieve a value from the cache. Args: key (Key): The key to look up. Returns: Value | None: The cached value if present and version matches, else None. Raises: CacheError: If the value is corrupted or cannot be unpickled. Side Effects: Removes stale cache files if the version prefix does not match. """ fpath = self._fpath_from_key(key) flock = self._flock_from_fpath(fpath) with flock: if not fpath.is_file(): return None value_bytes = None prefix_length = len(self.version_prefix) with open(fpath, "rb") as fp: if fp.read(prefix_length) == self.version_prefix: value_bytes = fp.read() if value_bytes is None: # version_prefix did not match, so we can't read the stale # cached value; we should also remove the stale cached value, # so that key can be re-cached by the newer version fpath.unlink() return None try: value = pickle.loads(value_bytes) except pickle.UnpicklingError as err: raise CacheError( f"Failed to get key {key!r}, value is potentially corrupted (value is not un-pickle-able)." ) from err return value @override def insert(self: Self, key: Key, value: Value) -> bool: """ Insert a value into the cache. Args: key (Key): The key to insert. value (Value): The value to associate with the key. Returns: bool: True if the value was inserted, False if the key already exists. Raises: CacheError: If the value is not pickle-able. Side Effects: Creates the cache directory if it does not exist. 
""" fpath = self._fpath_from_key(key) flock = self._flock_from_fpath(fpath) fpath.parent.mkdir(parents=True, exist_ok=True) try: # "x" mode is exclusive creation, meaning the file will be created # iff the file does not already exist (atomic w/o overwrite); use # flock for added atomicity guarantee and to prevent partial writes with flock as _, open(fpath, "xb") as fp: fp.write(self.version_prefix) pickle.dump(value, fp) except pickle.PicklingError as err: raise CacheError( f"Failed to insert key {key!r} with value {value!r}, value is not pickle-able." ) from err except FileExistsError: return False return True class InductorOnDiskCache(OnDiskCache[Key, Value]): """ Inductor-specific on-disk cache implementation. Uses a custom base directory for Inductor cache files. """ def __init__(self: Self) -> None: """ Initialize an inductor on-disk cache instance. Sets the cache directory name to "inductor_on_disk_cache". """ super().__init__("inductor_on_disk_cache") @cached_property def base_dir(self: Self) -> Path: """ Get the base directory for the Inductor cache. Returns: Path: The base directory path for Inductor cache files. """ from torch._inductor.runtime.runtime_utils import default_cache_dir return Path(default_cache_dir(), "cache", self.name)