Summary:

This diff does a big refactor of PrecompileContext to make it considerably simpler: instead of being a CacheArtifactManager that manages a bunch of bytes, it simply stores two things: dynamo cache entries and backend cache entries. When asked, it stitches them together into PrecompileCacheEntries, which are stored by DynamoCache.

This structure then allows us to register DynamoCache with the regular Megacache API, instead of having two separate, confusing APIs. It also lets us remove the autotune cache integration, since the Megacache API will automatically store autotune cache entries.

The intent here is that users who want to use caching precompile can simply call torch.compiler.save_cache_artifacts as before, just with `torch.dynamo.config.caching_precompile` set to True. They can also interact with PrecompileContext directly if they wish to load only precompile entries, using PrecompileContext.create_cache_entries(). Saving single entries with DynamoCache still works normally.

Test Plan: All existing unit tests pass.

Rollback Plan:

Differential Revision: D82380307

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162886
Approved by: https://github.com/zhxchen17
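Below is a minimal sketch of that user-facing workflow, assuming a build where caching precompile is available; `fn`, its input, and the save/restore flow shown are illustrative (note the flag lives at `torch._dynamo.config` in the source tree):

```python
import torch

# Assumption: a PyTorch build where caching precompile is available
torch._dynamo.config.caching_precompile = True

@torch.compile
def fn(x):
    return x * 2

fn(torch.randn(4))  # compile once so the caches are populated

# Megacache API: saves dynamo cache entries and backend artifacts together
result = torch.compiler.save_cache_artifacts()
if result is not None:
    artifact_bytes, cache_info = result
    # In a fresh process, restore with:
    # torch.compiler.load_cache_artifacts(artifact_bytes)
```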
import copy
import json
import logging
from abc import abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Generic, Optional, TypeVar

import torch
from torch._dynamo.package import (
    _BackendId,
    _DynamoCacheEntry,
    DynamoCache,
    PrecompileCacheEntry,
)


"""
Classes and implementations related to precompile
"""

T = TypeVar("T")

logger = logging.getLogger(__name__)


@dataclass
class BackendCacheArtifact(Generic[T]):
    """
    Represents a single serializable backend artifact from a dynamo backend.
    Each BackendCacheArtifact has a key associated with it along with some
    serializable content.

    Example implementation:

        class MyBackendCacheArtifact(BackendCacheArtifact[MySerializableType]):
            my_field: int

            def after_deserialization(self) -> MySerializableType:
                result = pickle.loads(self.content)
                # Do some extra work post deserialization
                result.my_post_deserialization_function(self.my_field)
                return result
    """

    key: str
    content: Any

    @abstractmethod
    def after_deserialization(self) -> T:
        """
        Code to be run after reading raw byte contents from disk.
        Generally converts self.content from raw bytes back into its original form.
        """
        ...

    def edit_contents(self, edit_fn: Callable[..., Any]) -> None:
        """
        Edit the contents of the artifact.
        """
        self.content = edit_fn(self.content)


class EagerCacheArtifact(BackendCacheArtifact[Any]):
    def after_deserialization(self) -> Any:
        return self.content

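# Illustrative usage of the artifact classes above (a sketch, not part of the
# module's API; the key string below is made up). EagerCacheArtifact simply
# round-trips its content, so recording and retrieving one looks like:
#
#     artifact = EagerCacheArtifact(key="example_backend", content=b"payload")
#     PrecompileContext.record_artifact(artifact)
#     restored = PrecompileContext.serialize_artifact_by_key("example_backend")
#     assert restored is not None
#     assert restored.after_deserialization() == b"payload"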


class BypassDynamoCacheEntry(Exception):
    pass


class PrecompileContext:
    """
    PrecompileContext is a special cache manager for handling precompilation.
    It uses an interface similar to CacheArtifactManager, but handles
    serialization differently: instead of placing each artifact into its
    respective cache, it stitches all the cache artifacts for a single key
    together and places the result into a global precompile cache (DynamoCache).

    PrecompileContext has two main portions: _dynamo_cache_entries and
    _backend_artifacts_by_key. On save, create_cache_entries() combines each
    dynamo cache entry with the backend artifacts it needs, producing
    PrecompileCacheEntries that save_to_dynamo_cache() writes out.

    The following artifact types are supported by PrecompileContext:
    - BundledAOTAutogradCacheArtifact
    """

    # Protected by the compile_lock.
    # _backend_artifacts_by_key organizes results by the key of each artifact.
    # Each object here must be serializable.
    _backend_artifacts_by_key: dict[str, BackendCacheArtifact[Any]] = {}

    # On a call to create_cache_entries(), each entry here is stitched together
    # with its backend artifacts into a PrecompileCacheEntry.
    _dynamo_cache_entries: dict[str, _DynamoCacheEntry] = {}

    @classmethod
    def clear(cls) -> None:
        cls._backend_artifacts_by_key.clear()
        cls._dynamo_cache_entries.clear()

    @classmethod
    def record_artifact(
        cls,
        artifact: BackendCacheArtifact[Any],
    ) -> None:
        """
        Records a backend artifact to be used with dynamo cache entries.
        """
        cls._backend_artifacts_by_key[artifact.key] = copy.deepcopy(artifact)

    @classmethod
    def record_dynamo_cache_entry(
        cls, cache_entry: _DynamoCacheEntry, key: str
    ) -> None:
        cls._dynamo_cache_entries[key] = cache_entry

    @classmethod
    def edit_artifact(cls, key: str, edit_fn: Callable[..., Any]) -> None:
        """
        Edit the content of an existing artifact.
        """
        assert key in cls._backend_artifacts_by_key, f"Key {key} not found in artifacts"
        artifact = cls._backend_artifacts_by_key[key]
        artifact.edit_contents(edit_fn)

    @classmethod
    def serialize_artifact_by_key(cls, key: str) -> Optional[BackendCacheArtifact[Any]]:
        """
        Return the backend cache artifact with the associated key.
        """
        return cls._backend_artifacts_by_key.get(key, None)

    @staticmethod
    def dump_debug_info(
        dynamo_entries: dict[str, _DynamoCacheEntry],
        backend_artifacts: dict[str, BackendCacheArtifact[Any]],
    ) -> dict[str, Any]:
        """
        Return a JSON-serializable debug dump of all entries in the precompile
        context. Called by create_cache_entries() before serialization, and
        when populating caches after deserialization.
        """
        # Collect debug information per cache-entry type
        debug_info: defaultdict[str, list[Any]] = defaultdict(list)
        for key, cache_entry in dynamo_entries.items():
            info = cache_entry.debug_info()
            info["key"] = key
            debug_info["dynamo"].append(info)

        for artifact in backend_artifacts.values():
            debug_info["backends"].append(artifact.key)

        return debug_info

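    # For reference, the mapping returned above is shaped roughly like this
    # (illustrative, not a stable schema):
    #   {"dynamo": [{"key": "<entry key>", ...per-entry debug_info()...}],
    #    "backends": ["<backend artifact key>", ...]}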
    @classmethod
    def save_to_dynamo_cache(cls) -> dict[str, Any]:
        precompile_cache_entries, debug_info = cls.create_cache_entries()
        for key, entry in precompile_cache_entries.items():
            DynamoCache.write(entry, key)
        return debug_info

    @classmethod
    def create_cache_entries(
        cls,
    ) -> tuple[dict[str, PrecompileCacheEntry], dict[str, Any]]:
        """
        Grabs all the cache entries in the precompile context and
        stitches them together into full PrecompileCacheEntries.
        """
        dynamo_entries = cls._dynamo_cache_entries
        backend_artifacts = cls._backend_artifacts_by_key

        num_entries = len(dynamo_entries)

        debug_info = PrecompileContext.dump_debug_info(
            dynamo_entries, backend_artifacts
        )
        debug_str = json.dumps(
            {
                "num_entries": num_entries,
                "artifacts": debug_info,
            },
        )
        torch._logging.trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "dynamo_cache_entries",
                "encoding": "json",
            },
            payload_fn=lambda: debug_str,
            expect_trace_id=False,
        )

        precompile_cache_entries = {}

        for key, cache_entry in dynamo_entries.items():
            try:
                backends = cache_entry.backend_ids
                backend_content: dict[_BackendId, BackendCacheArtifact[Any]] = {}
                for id_ in backends:
                    if id_ not in backend_artifacts:
                        # Skip the missing backend and log a bypass artifact
                        bypass_str = json.dumps(
                            {
                                "entry": cache_entry.debug_info(),
                                "key": key,
                            }
                        )
                        logger.warning(
                            "Backend %s not found for cache entry %s", id_, key
                        )
                        torch._logging.trace_structured(
                            "artifact",
                            metadata_fn=lambda: {
                                "name": "dynamo_cache_bypass",
                                "encoding": "json",
                            },
                            payload_fn=lambda: bypass_str,
                            expect_trace_id=False,
                        )
                        continue
                    artifact = backend_artifacts[id_]
                    assert isinstance(artifact, BackendCacheArtifact)
                    backend_content[id_] = artifact
                precompile_cache_entries[key] = PrecompileCacheEntry(
                    dynamo=cache_entry, backends=backend_content
                )
            except Exception as e:
                logger.warning("Failed to create cache entry %s: %s", key, str(e))
                data = json.dumps(
                    {
                        "key": key,
                        "error": str(e),
                    }
                )
                torch._logging.trace_structured(
                    "artifact",
                    metadata_fn=lambda: {
                        "name": "dynamo_cache_exception",
                        "encoding": "json",
                    },
                    payload_fn=lambda: data,
                )
                continue
        return precompile_cache_entries, debug_info
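

# End-to-end sketch of how the pieces above fit together (hypothetical names;
# in practice dynamo's package machinery calls these during compilation):
#
#     PrecompileContext.record_dynamo_cache_entry(dynamo_entry, key="some_key")
#     PrecompileContext.record_artifact(backend_artifact)
#     debug_info = PrecompileContext.save_to_dynamo_cache()  # stitch and write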