import copy
import json
import logging
from abc import abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Generic, Optional, TypeVar

import torch
from torch._dynamo.package import (
    _BackendId,
    _DynamoCacheEntry,
    DynamoCache,
    PrecompileCacheEntry,
)


"""
Classes and implementations related to precompile
"""

T = TypeVar("T")

logger = logging.getLogger(__name__)


@dataclass
class BackendCacheArtifact(Generic[T]):
    """
    Represents a single serializable backend artifact from a dynamo backend.
    Each BackendCacheArtifact has a key associated with it along with some
    serializable content.

    Example implementation:

    class MyBackendCacheArtifact(BackendCacheArtifact[MySerializableType]):
        my_field: int

        def after_deserialization(self) -> MySerializableType:
            result = pickle.loads(self.content)
            # Do some extra work post deserialization
            result.my_post_deserialization_function(self.my_field)
            return result
    """

    key: str
    content: Any

    @abstractmethod
    def after_deserialization(self) -> T:
        """
        Code to be run after reading raw byte contents from disk.
        Generally converts self.content from raw bytes back into its
        original form.
        """
        ...

    def edit_contents(self, edit_fn: Callable[..., Any]) -> None:
        """
        Edit the contents of the artifact.
        """
        self.content = edit_fn(self.content)


class EagerCacheArtifact(BackendCacheArtifact[Any]):
    def after_deserialization(self) -> Any:
        return self.content


class BypassDynamoCacheEntry(Exception):
    pass
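# A minimal usage sketch (illustrative only, not part of this module's API):
# EagerCacheArtifact stores content that is already usable as-is, so
# after_deserialization() is the identity, while edit_contents() rewrites the
# stored payload in place. The key and payload below are made-up values.
#
#     art = EagerCacheArtifact(key="example_backend", content=b"payload")
#     art.edit_contents(lambda c: c + b"-edited")
#     assert art.after_deserialization() == b"payload-edited"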
class PrecompileContext:
    """
    PrecompileContext is a special CacheArtifactManager for handling
    precompilation. It uses the same interface as CacheArtifactManager, but
    handles deserialization differently: instead of placing each artifact into
    its respective cache, it stitches all the cache artifacts for a single key
    together and places the result into a global precompile cache.

    PrecompileContext has two main portions: _dynamo_cache_entries and
    _backend_artifacts_by_key. On save, create_cache_entries() stitches each
    dynamo cache entry together with the backend artifacts needed to load it.

    The following artifact types are supported by PrecompileContext:
    - BundledAOTAutogradCacheArtifact
    """

    # Protected by the compile_lock.
    # _backend_artifacts_by_key organizes results by the key of each artifact.
    # Each object here must be serializable.
    _backend_artifacts_by_key: dict[_BackendId, BackendCacheArtifact[Any]] = {}

    # On call to create_cache_entries(), each entry here is stitched together
    # with the backend artifacts it references into a PrecompileCacheEntry.
    _dynamo_cache_entries: dict[str, _DynamoCacheEntry] = {}

    @classmethod
    def clear(cls) -> None:
        cls._backend_artifacts_by_key.clear()
        cls._dynamo_cache_entries.clear()

    @classmethod
    def record_artifact(
        cls,
        artifact: BackendCacheArtifact[Any],
    ) -> None:
        """
        Records a backend artifact to be used with dynamo cache entries.
        """
        cls._backend_artifacts_by_key[_BackendId(artifact.key)] = copy.deepcopy(
            artifact
        )

    @classmethod
    def record_dynamo_cache_entry(
        cls, cache_entry: _DynamoCacheEntry, key: str
    ) -> None:
        cls._dynamo_cache_entries[key] = cache_entry

    @classmethod
    def edit_artifact(cls, key: str, edit_fn: Callable[..., Any]) -> None:
        """
        Edit the content of an existing artifact.
        """
        assert key in cls._backend_artifacts_by_key, (
            f"Key {key} not found in artifacts"
        )
        artifact = cls._backend_artifacts_by_key[_BackendId(key)]
        artifact.edit_contents(edit_fn)

    @classmethod
    def serialize_artifact_by_key(
        cls, key: str
    ) -> Optional[BackendCacheArtifact[Any]]:
        """
        Return the backend cache artifact with the associated key.
        """
        return cls._backend_artifacts_by_key.get(_BackendId(key), None)

    @staticmethod
    def dump_debug_info(
        dynamo_entries: dict[str, _DynamoCacheEntry],
        backend_artifacts: dict[_BackendId, BackendCacheArtifact[Any]],
    ) -> dict[str, Any]:
        """
        Return a JSON-serializable debug dump of all entries in the precompile
        context. Called before serialization when saving, and in
        populate_caches after deserialization.
        """
        debug_info: defaultdict[str, list[Any]] = defaultdict(list)
        for key, cache_entry in dynamo_entries.items():
            info = cache_entry.debug_info()
            info["key"] = key
            debug_info["dynamo"].append(info)

        for artifact in backend_artifacts.values():
            debug_info["backends"].append(artifact.key)

        return debug_info

    @classmethod
    def save_to_dynamo_cache(cls) -> dict[str, Any]:
        precompile_cache_entries, debug_info = cls.create_cache_entries()
        for key, entry in precompile_cache_entries.items():
            DynamoCache.write(entry, key)
        return debug_info

    @classmethod
    def create_cache_entries(
        cls,
    ) -> tuple[dict[str, PrecompileCacheEntry], dict[str, Any]]:
        """
        Grabs all the cache entries in the precompile context and stitches
        them together into full PrecompileCacheEntries.
        """
        dynamo_entries = cls._dynamo_cache_entries
        backend_artifacts = cls._backend_artifacts_by_key
        num_entries = len(dynamo_entries)
        debug_info = PrecompileContext.dump_debug_info(
            dynamo_entries, backend_artifacts
        )
        debug_str = json.dumps(
            {
                "num_entries": num_entries,
                "artifacts": debug_info,
            },
        )
        torch._logging.trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "dynamo_cache_entries",
                "encoding": "json",
            },
            payload_fn=lambda: debug_str,
            expect_trace_id=False,
        )
        precompile_cache_entries = {}

        for key, cache_entry in dynamo_entries.items():
            try:
                result = PrecompileCacheEntry.from_cache_entry(
                    cache_entry, backend_artifacts
                )
                if result is not None:
                    precompile_cache_entries[key] = result
            except Exception as e:
                logger.warning("Failed to create cache entry %s: %s", key, str(e))
                data = json.dumps(
                    {
                        "key": key,
                        "error": str(e),
                    }
                )
                torch._logging.trace_structured(
                    "artifact",
                    metadata_fn=lambda: {
                        "name": "dynamo_cache_exception",
                        "encoding": "json",
                    },
                    # Bind the loop-local payload via a default argument so
                    # each iteration logs its own data.
                    payload_fn=lambda data=data: data,
                )
                continue

        return precompile_cache_entries, debug_info
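# End-to-end usage sketch (assumptions: `cache_entry` stands in for a real
# _DynamoCacheEntry produced by dynamo, and "example_backend"/"example_key"
# are made-up identifiers; in practice dynamo wires these calls up itself
# rather than callers invoking the classmethods directly):
#
#     PrecompileContext.record_artifact(
#         EagerCacheArtifact(key="example_backend", content=b"payload")
#     )
#     PrecompileContext.record_dynamo_cache_entry(cache_entry, key="example_key")
#     # Stitch entries and artifacts together and write them via DynamoCache.
#     debug_info = PrecompileContext.save_to_dynamo_cache()
#     PrecompileContext.clear()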