pytorch/torch/_dynamo/precompile_context.py
James Wu bfe9e60ffb Simplify PrecompileContext to no longer be a CacheArtifactManager (#162886)
Summary:
This diff is a large refactor of PrecompileContext that makes it considerably simpler. Instead of being a CacheArtifactManager that manages a bunch of raw bytes, it simply stores two things: dynamo cache entries and backend cache entries. When asked, it stitches them together into PrecompileCacheEntries, which are stored by DynamoCache.

This structure then allows us to register DynamoCache with the regular Megacache API, instead of having two separate, confusing APIs. It also lets us remove the autotune cache integration, since the Megacache API will automatically store autotune cache entries.

The intent here is that users who want caching precompile will simply be able to use torch.compiler.save_cache_artifacts as before, just with `torch._dynamo.config.caching_precompile` set to True. They can also interact with PrecompileContext directly if they wish to work with only the precompile entries, using PrecompileContext.create_cache_entries().

Saving single entries and such with DynamoCache still works normally.
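
As a rough sketch of the intended flow (the toy `fn` and its inputs below are placeholders, not part of this diff):

    import torch
    from torch._dynamo.precompile_context import PrecompileContext

    torch._dynamo.config.caching_precompile = True

    @torch.compile
    def fn(x):
        return x + 1

    fn(torch.randn(3))  # populates dynamo + backend entries in PrecompileContext

    # Regular Megacache path: saves precompile entries along with everything else
    artifacts = torch.compiler.save_cache_artifacts()

    # Or stitch together just the precompile entries
    entries, debug_info = PrecompileContext.create_cache_entries()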

Test Plan:
All existing unit tests pass.

Rollback Plan:

Differential Revision: D82380307

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162886
Approved by: https://github.com/zhxchen17
2025-09-20 01:24:37 +00:00


import copy
import json
import logging
from abc import abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Generic, Optional, TypeVar

import torch
from torch._dynamo.package import (
    _BackendId,
    _DynamoCacheEntry,
    DynamoCache,
    PrecompileCacheEntry,
)


"""
Classes and implementations related to precompile
"""

T = TypeVar("T")

logger = logging.getLogger(__name__)


@dataclass
class BackendCacheArtifact(Generic[T]):
    """
    Represents a single serializable backend artifact from a dynamo backend.
    Each BackendCacheArtifact has a key associated with it along with some
    serializable content.

    Example implementation:

    class MyBackendCacheArtifact(BackendCacheArtifact[MySerializableType]):
        my_field: int

        def after_deserialization(self) -> MySerializableType:
            result = pickle.loads(self.content)
            # Do some extra work post deserialization
            result.my_post_deserialization_function(self.my_field)
            return result
    """

    key: str
    content: Any

    @abstractmethod
    def after_deserialization(self) -> T:
        """
        Code to be run after reading raw byte contents from disk.
        Generally converts self.content from raw bytes back into its original form.
        """
        ...

    def edit_contents(self, edit_fn: Callable[..., Any]) -> None:
        """
        Edit the contents of the artifact.
        """
        self.content = edit_fn(self.content)


class EagerCacheArtifact(BackendCacheArtifact[Any]):
    def after_deserialization(self) -> Any:
        return self.content


class BypassDynamoCacheEntry(Exception):
    pass


class PrecompileContext:
    """
    PrecompileContext is the in-memory context for caching precompile. Unlike
    CacheArtifactManager, it does not manage raw bytes; it simply records two
    kinds of objects: dynamo cache entries and backend cache artifacts. When
    asked, it stitches all of the artifacts for a single key together into a
    PrecompileCacheEntry and places it into the global DynamoCache.

    PrecompileContext has two main portions: _dynamo_cache_entries and
    _backend_artifacts_by_key. On save, create_cache_entries() pairs each dynamo
    cache entry with the backend artifacts (keyed by backend id) that it needs.

    The following artifact types are supported by PrecompileContext:
     - BundledAOTAutogradCacheArtifact
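
    Example usage (a sketch; assumes `entry`, `key`, and `artifact` were produced
    by dynamo and a backend while caching precompile was enabled):

        PrecompileContext.record_dynamo_cache_entry(entry, key)
        PrecompileContext.record_artifact(artifact)
        PrecompileContext.save_to_dynamo_cache()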
"""
# Protected by the compile_lock
# _backend_artifacts_by_key organizes results by the key of each artifact.
# Each object here must be serializable
_backend_artifacts_by_key: dict[str, BackendCacheArtifact[Any]] = {}
# On call to `serialize()`, all cache artifacts in _dynamo_cache_entries are converted
# into DynamoCacheArtifacts and added to _new_cache_artifacts for serialization
_dynamo_cache_entries: dict[str, _DynamoCacheEntry] = {}
    @classmethod
    def clear(cls) -> None:
        cls._backend_artifacts_by_key.clear()
        cls._dynamo_cache_entries.clear()

    @classmethod
    def record_artifact(
        cls,
        artifact: BackendCacheArtifact[Any],
    ) -> None:
        """
        Records a backend artifact to be used with dynamo cache entries
        """
        # Deep copy so that later mutations by the caller do not affect the
        # recorded artifact.
        cls._backend_artifacts_by_key[artifact.key] = copy.deepcopy(artifact)

    @classmethod
    def record_dynamo_cache_entry(
        cls, cache_entry: _DynamoCacheEntry, key: str
    ) -> None:
        cls._dynamo_cache_entries[key] = cache_entry

    @classmethod
    def edit_artifact(cls, key: str, edit_fn: Callable[..., Any]) -> None:
        """
        Edit the content of an existing artifact
        """
        assert key in cls._backend_artifacts_by_key, f"Key {key} not found in artifacts"
        artifact = cls._backend_artifacts_by_key[key]
        artifact.edit_contents(edit_fn)

    @classmethod
    def serialize_artifact_by_key(cls, key: str) -> Optional[BackendCacheArtifact[Any]]:
        """
        Return the backend cache artifact with the associated key
        """
        return cls._backend_artifacts_by_key.get(key, None)

    @staticmethod
    def dump_debug_info(
        dynamo_entries: dict[str, _DynamoCacheEntry],
        backend_artifacts: dict[str, BackendCacheArtifact[Any]],
    ) -> dict[str, Any]:
        """
        Return a JSON serializable debug dump of all entries in the precompile context.
        Called by create_cache_entries() before the cache entries are written out.
        """
        # Print debug information
        debug_info: defaultdict[str, list[Any]] = defaultdict(list)
        for key, cache_entry in dynamo_entries.items():
            info = cache_entry.debug_info()
            info["key"] = key
            debug_info["dynamo"].append(info)

        for artifact in backend_artifacts.values():
            debug_info["backends"].append(artifact.key)

        return debug_info

    @classmethod
    def save_to_dynamo_cache(cls) -> dict[str, Any]:
        precompile_cache_entries, debug_info = cls.create_cache_entries()
        for key, entry in precompile_cache_entries.items():
            DynamoCache.write(entry, key)
        return debug_info

    @classmethod
    def create_cache_entries(
        cls,
    ) -> tuple[dict[str, PrecompileCacheEntry], dict[str, Any]]:
        """
        Grabs all the cache entries in the precompile context and
        stitches them together into full PrecompileCacheEntries.
        """
        dynamo_entries = cls._dynamo_cache_entries
        backend_artifacts = cls._backend_artifacts_by_key
        num_entries = len(dynamo_entries)

        debug_info = PrecompileContext.dump_debug_info(
            dynamo_entries, backend_artifacts
        )
        debug_str = json.dumps(
            {
                "num_entries": num_entries,
                "artifacts": debug_info,
            },
        )
        torch._logging.trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "dynamo_cache_entries",
                "encoding": "json",
            },
            payload_fn=lambda: debug_str,
            expect_trace_id=False,
        )

        precompile_cache_entries = {}
        for key, cache_entry in dynamo_entries.items():
            try:
                # Collect the backend artifacts this dynamo entry depends on.
                backends = cache_entry.backend_ids
                backend_content: dict[_BackendId, BackendCacheArtifact[Any]] = {}
                for id_ in backends:
                    if id_ not in backend_artifacts:
                        debug_str = json.dumps(
                            {
                                "entry": cache_entry.debug_info(),
                                "key": key,
                            }
                        )
                        logger.warning("Backend %s not found for key %s", id_, key)
                        torch._logging.trace_structured(
                            "artifact",
                            metadata_fn=lambda: {
                                "name": "dynamo_cache_bypass",
                                "encoding": "json",
                            },
                            payload_fn=lambda: debug_str,
                            expect_trace_id=False,
                        )
                        continue
                    artifact = backend_artifacts[id_]
                    assert isinstance(artifact, BackendCacheArtifact)
                    backend_content[id_] = artifact
                precompile_cache_entries[key] = PrecompileCacheEntry(
                    dynamo=cache_entry, backends=backend_content
                )
            except Exception as e:
                logger.warning("Failed to create cache entry %s: %s", key, str(e))
                data = json.dumps(
                    {
                        "key": key,
                        "error": str(e),
                    }
                )
                torch._logging.trace_structured(
                    "artifact",
                    metadata_fn=lambda: {
                        "name": "dynamo_cache_exception",
                        "encoding": "json",
                    },
                    payload_fn=lambda: data,
                )
                continue

        return precompile_cache_entries, debug_info