[precompile] Implement PrecompileContext for recording precompile artifacts, integrate with CompilePackage (#154415)

This PR implements a basic interface and test for PrecompileContext, a special CacheArtifactManager designed specifically for precompile. The job of a PrecompileContext is to record everything precompile needs as torch is compiling, dump it all into bytes, and then stitch it back together into a cache of callables.
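Roughly, the intended flow looks like this — a minimal sketch mirroring the new unit test in this PR (`precompile_aot_autograd` is the artifact type this PR registers for bundled AOTAutograd entries):

```python
from torch._dynamo.precompile_context import PrecompileContext

# Assume a torch.compile'd function has already run, recording artifacts
# into PrecompileContext during compilation.
result = PrecompileContext.serialize()  # (global) dump all recorded artifacts to bytes
assert result is not None
serialized_bytes, cache_info = result

# Later, possibly in a fresh process: bytes -> cache artifacts grouped by type.
artifacts = PrecompileContext.deserialize(serialized_bytes)
assert artifacts is not None
for artifact in artifacts["precompile_aot_autograd"]:
    entry = artifact.after_deserialization()  # rebuild the usable cache entry
```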

## Why use CacheArtifactManager?
Precompile needs a way to record various serializable data as torch is compiling. CacheArtifactManager already does this well today, handling a lot of the serialization and cache bookkeeping, so we're reusing a bunch of that infrastructure directly.

## How is it different from CacheArtifactManager?
Unlike the regular CacheArtifactManager, PrecompileContext needs to be able to take the recorded artifacts and stitch them together after deserialization to create a single working callable.
Since PrecompileContext doesn't need cache keys, the "key" field of each PrecompileCacheArtifact can instead hold metadata describing how to stitch the individually compiled functions into a full callable. For example, on a given dynamo compile, if multiple functions are compiled (via graph breaks or recompiles), MegaCache would organize them like so:

![image](https://github.com/user-attachments/assets/49a0a75b-1e7f-4d96-8d81-6769fe5a53ca)

Whereas we'd visualize PrecompileContext's result like so:

![image](https://github.com/user-attachments/assets/fcc0dd4e-dfbf-4b13-9c08-2e99b373180b)
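Concretely, here's a hedged sketch of that key reuse, following DynamoStore's record/save paths from this diff (the backend id and backend object are made up for illustration):

```python
import pickle

from torch._dynamo.precompile_context import PrecompileContext

# During compile, an eager backend is recorded under its *backend id* rather
# than a content-addressed cache key (cf. DynamoStore.record_eager_backend).
backend_id = "__compiled_fn_1"  # hypothetical id for illustration
my_backend = {"compiled": "backend"}  # picklable stand-in for a real backend
PrecompileContext.record_artifact(
    "precompile_eager", key=backend_id, content=pickle.dumps(my_backend)
)

# Saving a package later looks each backend up by that same id.
artifact = PrecompileContext.serialize_artifact_by_key(backend_id)
```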

For now, we just handle eager mode; in the diff above this one in the stack, I'll hook up the other backend artifacts from PrecompileContext.

After this PR, precompile consists of three main interfaces (a combined usage sketch follows the list):

### CompilePackage
- Everything needed to run one torch.compile'd function (including graph breaks)
- `__init__(fn, cache_entry)`: Initializes the package with a DynamoCacheEntry
- `install(backends)`: Loads precompile artifacts into the function's dynamo state, given a dictionary of backends
- `cache_entry()`: Returns a serializable cache entry to save

### DynamoStore
- Responsible for tracking CompilePackages on disk (and/or in memory)
- `load_package(fn, path)`: Loads a package, given a torch-compiled function and a path to the cache artifacts
- `save_package(package, path)`: Saves a CompilePackage to a path; calls PrecompileContext to grab backend data
- `record_package(package)`: Records a package to PrecompileContext (for global serialization/deserialization)

### PrecompileContext
- Overarching context for serializing and deserializing precompile artifacts. Supports **global** and **local** setups.
- `serialize()`: (Global) serializes all artifacts in PrecompileContext into bytes
- `populate_caches(bytes)`: (Global) takes serialized bytes and puts them into DynamoStore (TODO)
- `serialize_artifact_by_key(key)`: (Local) serializes a single artifact by its cache key
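
Putting the three interfaces together, a sketch based on the updated test_package.py tests in this diff (the save path is an assumption, as is constructing a fresh CompilePackage without a cache entry):

```python
import torch

from torch._dynamo.package import CompilePackage, DynamoStore

def fn(x):
    return x + 1

# Compile once, recording dynamo code and eager backends as we go.
package = CompilePackage(fn)  # assumed: omitting cache_entry starts a fresh package
store = DynamoStore()
compiled_fn = torch._dynamo.optimize(backend="eager", package=package)(fn)
compiled_fn(torch.randn(3))

# Save: record each eager backend, then write the package and backends to disk.
for backend_id, backend in package.cached_backends.items():
    store.record_eager_backend(backend_id, backend)
store.save_package(package, "/tmp/my_package")  # hypothetical path

# Load, e.g. in a fresh process: rebuild the package and install its backends.
torch._dynamo.reset()
package, backends = store.load_package(fn, "/tmp/my_package")
compiled_fn = torch._dynamo.optimize(backend="eager", package=package)(fn)
package.install(backends)
```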

<img width="1455" alt="image" src="https://github.com/user-attachments/assets/99b61330-7607-4763-bdbc-85b366e82cdd" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154415
Approved by: https://github.com/zhxchen17
ghstack dependencies: #155118
James Wu 2025-06-12 14:07:25 -07:00, committed by PyTorch MergeBot
commit 3819584f12 (parent b2fc9cfea1)
7 changed files with 399 additions and 69 deletions

View File

@@ -1,7 +1,6 @@
# Owner(s): ["module: dynamo"]
import os
import pickle
import torch
import torch._dynamo.testing
@@ -9,60 +8,18 @@ import torch._inductor.config
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils.cpp_extension
from torch._dynamo.package import CompilePackage
from torch._dynamo.package import CompilePackage, DynamoStore
from torch._inductor.runtime.runtime_utils import cache_dir
class StorageForTesting:
def __init__(self, path: str):
self.path = path
self.backends = {}
def _write_pickle(self, data, *path: str):
with open(os.path.join(self.path, *path) + ".pickle", "wb") as f:
pickle.dump(data, f)
def write_dynamo(self, dynamo):
self._write_pickle(dynamo, "dynamo")
def write_backend(self, backend_id):
os.makedirs(os.path.join(self.path, backend_id), exist_ok=True)
self._write_pickle(self.backends[backend_id], backend_id, "fx_graph")
def _read_pickle(self, *path):
with open(os.path.join(self.path, *path) + ".pickle", "rb") as f:
return pickle.load(f)
def read_backend(self, backend_id):
return self._read_pickle(backend_id, "fx_graph")
def read_dynamo(self):
return self._read_pickle("dynamo")
def add_backend(self, backend_id, backend):
self.backends[backend_id] = backend
def save_package(self, dynamo_cache_entry):
self.write_dynamo(dynamo_cache_entry)
for backend_id in dynamo_cache_entry.backend_ids:
self.write_backend(backend_id)
def load_package(self):
dynamo = self.read_dynamo()
self.backends = {}
for backend_id in dynamo.backend_ids:
self.backends[backend_id] = self.read_backend(backend_id)
return dynamo
class TestPackage(torch._inductor.test_case.TestCase):
def storage(self):
def path(self):
path = os.path.join(cache_dir(), f"package_{self.id()}")
os.makedirs(path, exist_ok=True)
return StorageForTesting(path)
return path
def test_basic_fn(self):
storage = self.storage()
ctx = DynamoStore()
def fn(x):
return x + 1
@@ -74,8 +31,8 @@ class TestPackage(torch._inductor.test_case.TestCase):
compiled_fn = torch._dynamo.optimize(backend="eager", package=package)(fn)
expected = compiled_fn(*args)
for backend_id, backend in package.cached_backends.items():
storage.add_backend(backend_id, backend)
storage.save_package(package.save())
ctx.record_eager_backend(backend_id, backend)
ctx.save_package(package, self.path())
# Loading
torch._dynamo.reset()
@@ -86,13 +43,13 @@ class TestPackage(torch._inductor.test_case.TestCase):
):
compiled_fn(*args)
package = CompilePackage(fn, storage.load_package())
package, backends = ctx.load_package(fn, self.path())
compiled_fn = torch._dynamo.optimize(package=package)(fn)
package.install(storage.backends)
package.install(backends)
self.assertEqual(expected, compiled_fn(*args))
def test_graph_break_bomb(self):
storage = self.storage()
ctx = DynamoStore()
def fn(x, l, r):
if l > r:
@@ -121,8 +78,8 @@ class TestPackage(torch._inductor.test_case.TestCase):
for args in args_list:
compiled_fn(*args)
for backend_id, backend in package.cached_backends.items():
storage.add_backend(backend_id, backend)
storage.save_package(package.save())
ctx.record_eager_backend(backend_id, backend)
ctx.save_package(package, self.path())
# Loading
torch._dynamo.reset()
@@ -133,11 +90,11 @@ class TestPackage(torch._inductor.test_case.TestCase):
"Detected recompile when torch.compile stance is 'fail_on_recompile'",
):
compiled_fn(*args)
package = CompilePackage(fn, storage.load_package())
package, backends = ctx.load_package(fn, self.path())
compiled_fn = torch._dynamo.optimize(
backend="eager", package=package, guard_filter_fn=guard_filter_fn
)(fn)
package.install(storage.backends)
package.install(backends)
for args in args_list:
self.assertEqual(compiled_fn(*args), args[0].sum())
@@ -148,7 +105,7 @@ class TestPackage(torch._inductor.test_case.TestCase):
compiled_fn(torch.tensor(N), 0, N - 1)
def test_dynamic_shape(self):
storage = self.storage()
ctx = DynamoStore()
def fn(x):
return x + x.shape[0]
@@ -165,8 +122,8 @@ class TestPackage(torch._inductor.test_case.TestCase):
compiled_fn = torch._dynamo.optimize(backend="eager", package=package)(fn)
compiled_fn(*args)
for backend_id, backend in package.cached_backends.items():
storage.add_backend(backend_id, backend)
storage.save_package(package.save())
ctx.record_eager_backend(backend_id, backend)
ctx.save_package(package, self.path())
# Loading
torch._dynamo.reset()
@@ -177,9 +134,9 @@ class TestPackage(torch._inductor.test_case.TestCase):
):
compiled_fn(*args1)
package = CompilePackage(fn, storage.load_package())
package, backends = ctx.load_package(fn, self.path())
compiled_fn = torch._dynamo.optimize(package=package)(fn)
package.install(storage.backends)
package.install(backends)
self.assertEqual(expected1, compiled_fn(*args1))

View File

@@ -0,0 +1,105 @@
# Owner(s): ["module: dynamo"]
import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._functorch
from torch._dynamo.precompile_context import PrecompileContext
from torch._functorch import config as functorch_config
from torch._functorch._aot_autograd.autograd_cache import (
BundledAOTAutogradCacheArtifact,
BundledAOTAutogradCacheEntry,
)
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal.inductor_utils import requires_triton
@functorch_config.patch({"enable_autograd_cache": True})
@functorch_config.patch(
{"bundled_autograd_cache": True}
) # Requires BundledAOTAutogradCache for now
class PrecompileContextTests(InductorTestCase):
def setUp(self):
"""
Reset all counters and caches before each unit test
"""
super().setUp()
# Clear PrecompileContext cache artifacts
PrecompileContext.clear()
@requires_triton()
def test_basic(self):
"""
Test that after torch.compile, PrecompileContext._new_cache_artifacts_by_key has length 1
"""
def simple_function(x):
return x.sin() + x.cos()
compiled_fn = torch.compile(simple_function)
# Run the compiled function
x = torch.randn(10, device="cuda", requires_grad=True)
result = compiled_fn(x)
result.sum().backward()
# Check that PrecompileContext._new_cache_artifacts_by_key has length 1
self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 1)
self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0)
result = PrecompileContext.serialize()
assert result is not None
serialized, cache_info = result
self.assertEqual(len(cache_info.precompile_aot_autograd_artifacts), 1)
artifacts = PrecompileContext.deserialize(serialized)
assert artifacts is not None
deserialized = artifacts["precompile_aot_autograd"]
assert len(deserialized) == 1
entry = deserialized[0]
assert isinstance(entry, BundledAOTAutogradCacheArtifact)
entry = entry.after_deserialization()
assert isinstance(
entry,
BundledAOTAutogradCacheEntry,
)
# Now that we've serialized, there should be no new cache artifacts
self.assertEqual(
len(PrecompileContext._new_cache_artifacts["precompile_aot_autograd"]), 0
)
@requires_triton()
def test_serialize_by_key(self):
"""
Test that a single artifact can be serialized by its key after torch.compile
"""
def simple_function(x):
return x.sin() + x.cos()
compiled_fn = torch.compile(simple_function)
# Run the compiled function
x = torch.randn(10, device="cuda", requires_grad=True)
result = compiled_fn(x)
result.sum().backward()
# Check that PrecompileContext._new_cache_artifacts_by_key has length 1
# TODO: the key right now is the AOTAutogradCacheKey, but will be backend_id once
# we have torch._dynamo.package implemented
self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 1)
key = next(iter(PrecompileContext._new_cache_artifacts_by_key.keys()))
result = PrecompileContext.serialize_artifact_by_key(key)
assert isinstance(result, BundledAOTAutogradCacheArtifact)
self.assertEqual(result.key, key)
self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0)
result = PrecompileContext.serialize()
assert result is not None
_, cache_info = result
self.assertEqual(len(cache_info.precompile_aot_autograd_artifacts), 1)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()

View File

@@ -14,6 +14,7 @@ import functools
import hashlib
import importlib
import logging
import os
import pickle
import platform
import sys
@@ -23,6 +24,8 @@ from typing import Any, NewType, Optional
import torch
import torch._inductor.package
from torch._dynamo.precompile_context import PrecompileCacheArtifact, PrecompileContext
from torch.compiler._cache import CacheArtifactFactory
from .bytecode_transformation import get_code_keys
@@ -128,6 +131,16 @@ class _DynamoCacheEntry:
return {backend_id for code in self.codes for backend_id in code.backend_ids}
@CacheArtifactFactory.register
class _DynamoCacheArtifact(PrecompileCacheArtifact[_DynamoCacheEntry]):
@staticmethod
def type() -> str:
return "precompile_dynamo"
def after_deserialization(self) -> _DynamoCacheEntry:
return pickle.loads(self.content)
class CompilePackage:
"""
CompilePackage is considered a low level component and should not be directly exposed to
@@ -303,6 +316,10 @@ class CompilePackage:
fn = types.FunctionType(code, module.__dict__, function_name)
self._install_global(module, function_name, fn)
for backend_id in entry.backend_ids:
if backend_id not in backends:
raise RuntimeError(
f"Backend {backend_id} is not found in the given backends"
)
backend = backends[backend_id]
self._install_global(
module,
@@ -326,6 +343,69 @@ class CompilePackage:
SerializedCode.to_code_object(guarded_code.dynamo_code),
)
def save(self) -> _DynamoCacheEntry:
def cache_entry(self) -> _DynamoCacheEntry:
self.validate()
return _DynamoCacheEntry(codes=list(self._codes.values()))
@CacheArtifactFactory.register
class EagerCacheArtifact(PrecompileCacheArtifact[Any]):
@staticmethod
def type() -> str:
return "precompile_eager"
def after_deserialization(self) -> Any:
return pickle.loads(self.content)
class DynamoStore:
"""
A DynamoStore tracks active CompilePackages, and provides methods to store and retrieve them.
"""
def record_package(self, package: CompilePackage) -> None:
"""Records a package to PrecompileContext, so that it can be serialized later."""
cache_entry = package.cache_entry()
pickled_result = pickle.dumps(cache_entry)
PrecompileContext.record_artifact(
_DynamoCacheArtifact.type(), key=package.source_id, content=pickled_result
)
def record_eager_backend(self, backend_id: _BackendId, backend: Any) -> None:
"""Records eager fx graphs to PrecompileContext for testing purposes."""
pickled_result = pickle.dumps(backend)
PrecompileContext.record_artifact(
EagerCacheArtifact.type(), key=backend_id, content=pickled_result
)
def save_package(self, package: CompilePackage, path: str) -> None:
"""Saves a package to a given path. Grabs backends from PrecompileContext."""
backend_content = {}
cache_entry = package.cache_entry()
for backend_id in cache_entry.backend_ids:
backend_content[backend_id] = PrecompileContext.serialize_artifact_by_key(
backend_id
)
try:
with open(os.path.join(path, "dynamo"), "wb") as dynamo_path:
pickle.dump(cache_entry, dynamo_path)
with open(os.path.join(path, "backends"), "wb") as backend_path:
pickle.dump(backend_content, backend_path)
except Exception as e:
raise RuntimeError(f"Failed to save package to {path}: {e}") from e
def load_package(
self, fn: Any, path: str
) -> tuple[CompilePackage, dict[_BackendId, Any]]:
"""Loads a package from a given path and returns it plus a list of deserialized backends"""
try:
with open(os.path.join(path, "dynamo"), "rb") as dynamo_path:
cache_entry = pickle.load(dynamo_path)
with open(os.path.join(path, "backends"), "rb") as backend_path:
backend_content = pickle.load(backend_path)
except Exception as e:
raise RuntimeError(f"Failed to load package from path {path}: {e}") from e
for backend_id, backend in backend_content.items():
backend_content[backend_id] = backend.after_deserialization()
package = CompilePackage(fn, cache_entry)
return package, backend_content

View File

@@ -0,0 +1,146 @@
from abc import abstractmethod
from collections import defaultdict
from typing import Any, Generic, Optional, TypeVar
from typing_extensions import override
from torch.compiler._cache import (
_serialize_single_cache,
CacheArtifact,
CacheArtifactFactory,
CacheArtifactManager,
CacheArtifactsResult,
CacheInfo,
)
from torch.utils._appending_byte_serializer import AppendingByteSerializer
from torch.utils._ordered_set import OrderedSet
"""
Classes and implementations related to precompile
"""
T = TypeVar("T")
class PrecompileCacheArtifact(CacheArtifact, Generic[T]):
"""
Data for each cache artifact that will be serialized and deserialized by
PrecompileContext, rather than CacheArtifactManager.
T represents the deserialized type of the artifact, i.e. the return type of after_deserialization
PrecompileCacheArtifact is a frozen dataclass - you can add new serializable fields and metadata specific to your own artifacts
as needed, and use them in after_deserialization.
Example implementation:
class MyPrecompileCacheArtifact(PrecompileCacheArtifact[MySerializableType]):
my_field: int
def after_deserialization(self) -> MySerializableType:
result = pickle.loads(self.content)
# Do some extra work post deserialization
result.my_post_deserialization_function(self.my_field)
return result
"""
@override
def populate_cache(self) -> None:
raise RuntimeError("Precompile cache artifacts do not populate caches")
@override
def precompile_compatible(self) -> bool:
return True
@abstractmethod
def after_deserialization(self) -> T:
"""
Code to be run after reading raw byte contents from disk.
Generally converts self.content from raw bytes back into its original form.
"""
...
class PrecompileContext(CacheArtifactManager):
"""
PrecompileContext is a special CacheArtifactManager for handling precompilation
It uses the same interface as CacheArtifactManager, but handles deserialization differently: instead
of placing each artifact into respective caches, it will stitch all the cache artifacts for a single key
together and place it into a global Precompile Cache.
The following artifact types are supported by PrecompileContext:
- BundledAOTAutogradCacheArtifact
- CodeStateArtifact (from torch._dynamo.package once available)
"""
# Protected by the compile_lock
# _new_cache_artifacts_by_key organizes results by the key of each artifact.
# This allows us to implement serialize_by_key easily.
# On call to `serialize()`, all cache artifacts in _new_cache_artifacts_by_key
# are transferred to _new_cache_artifacts before serialization.
_new_cache_artifacts_by_key: dict[str, CacheArtifact] = {}
_new_cache_artifacts: CacheArtifactsResult = defaultdict(list)
# Keep a separate seen artifacts list to avoid unnecessary duplicates
# This list will not be cleared between serialize() calls
_seen_artifacts: OrderedSet[CacheArtifact] = OrderedSet()
# When serialize() is called, artifacts are transferred from _new_cache_artifacts to
# the internal data structure of the _serializer
# This allows us to only pay the cost of serialization if serialize() is called
_serializer: AppendingByteSerializer[tuple[str, list[CacheArtifact]]] = (
AppendingByteSerializer(serialize_fn=_serialize_single_cache)
)
_cache_info: CacheInfo = CacheInfo()
@classmethod
def clear(cls) -> None:
cls._new_cache_artifacts_by_key.clear()
super().clear()
@override
@classmethod
def record_artifact(
cls,
artifact_type: str,
key: str,
content: Any,
) -> None:
"""
Called from each caching operation to record the artifact in this
"mega" list
"""
artifact = CacheArtifactFactory.encode_create(artifact_type, key, content)
if artifact in cls._seen_artifacts:
return
cls._new_cache_artifacts_by_key[key] = artifact
cls._seen_artifacts.add(artifact)
@classmethod
def _save_artifacts_by_type(cls) -> None:
"""
We normally record artifacts by key, but serialization expects them to be organized
by artifact type. This function transfers artifacts from _new_cache_artifacts_by_key to _new_cache_artifacts
"""
for artifact in cls._new_cache_artifacts_by_key.values():
cls._new_cache_artifacts[artifact.__class__.type()].append(artifact)
cls._new_cache_artifacts_by_key.clear()
@classmethod
def serialize_artifact_by_key(cls, key: str) -> Optional[CacheArtifact]:
"""
Return the single artifact recorded under the given key, if any.
"""
return cls._new_cache_artifacts_by_key.get(key, None)
@classmethod
def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]:
cls._save_artifacts_by_type()
return super().serialize()
@staticmethod
def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo:
raise NotImplementedError("TODO")
@classmethod
def _ensure_cache_artifacts_registered(cls) -> None:
from torch._functorch._aot_autograd.autograd_cache import ( # noqa: F401
BundledAOTAutogradCacheArtifact,
)

View File

@@ -21,6 +21,7 @@ from typing import Any, Callable, Generic, Optional, TYPE_CHECKING, TypeVar, Uni
from typing_extensions import override
import torch
from torch._dynamo.precompile_context import PrecompileCacheArtifact, PrecompileContext
from torch._dynamo.trace_rules import torch_non_c_binding_in_graph_functions
from torch._dynamo.utils import (
chromium_event_log_active,
@@ -916,6 +917,21 @@ class AOTAutogradCacheArtifact(CacheArtifact):
return "aot_autograd"
@CacheArtifactFactory.register
class BundledAOTAutogradCacheArtifact(
PrecompileCacheArtifact[BundledAOTAutogradCacheEntry]
):
@override
@staticmethod
def type():
return "precompile_aot_autograd"
@override
def after_deserialization(self) -> BundledAOTAutogradCacheEntry:
entry = pickle.loads(self.content)
return entry
class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
"""
Caches the results of running AOTAutograd. This class mostly handles the save and load logic, whereas
@@ -1167,6 +1183,10 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
CacheArtifactManager.record_artifact(
AOTAutogradCacheArtifact.type(), key, pickled_content
)
if config.bundled_autograd_cache:
PrecompileContext.record_artifact(
BundledAOTAutogradCacheArtifact.type(), key, pickled_content
)
except Exception as e:
log.info("AOTAutograd cache unable to load compiled graph: %s", e)
if config.strict_autograd_cache:
@@ -1196,6 +1216,11 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
CacheArtifactManager.record_artifact(
AOTAutogradCacheArtifact.type(), key, content
)
if config.bundled_autograd_cache:
# TODO: the key here isn't correct
PrecompileContext.record_artifact(
BundledAOTAutogradCacheArtifact.type(), key, content
)
AOTAutogradCache._write_to_local_cache(key, content)
counters["aot_autograd"]["autograd_cache_saved"] += 1
except BypassAOTAutogradCache as e:

View File

@@ -473,4 +473,7 @@ def load_cache_artifacts(serialized_artifacts: bytes) -> Optional["CacheInfo"]:
"""
from ._cache import CacheArtifactManager, CacheInfo
return CacheArtifactManager.deserialize(serialized_artifacts)
artifacts = CacheArtifactManager.deserialize(serialized_artifacts)
if artifacts is not None:
return CacheArtifactManager.populate_caches(artifacts)
return None

View File

@@ -48,6 +48,9 @@ class CacheArtifact(ABC):
def populate_cache(self) -> None:
pass
def precompile_compatible(self) -> bool:
return False
@staticmethod
def type() -> str:
"""
@@ -128,6 +131,10 @@ class CacheInfo:
def pgo_artifacts(self) -> list[str]: # type: ignore[empty-body]
...
@property
def precompile_aot_autograd_artifacts(self) -> list[str]: # type: ignore[empty-body]
...
def add(self, artifact: CacheArtifact) -> None:
self.artifacts[artifact.type()].append(artifact.key)
@@ -159,6 +166,9 @@ def _deserialize_single_cache(
return artifact_type_key, artifacts
CacheArtifactsResult = dict[str, list[CacheArtifact]]
class CacheArtifactManager:
"""
Lightweight manager class for collecting and processing cache artifacts for
@@ -177,7 +187,7 @@ class CacheArtifactManager:
"""
# Protected by the compile_lock
_new_cache_artifacts: defaultdict[str, list[CacheArtifact]] = defaultdict(list)
_new_cache_artifacts: CacheArtifactsResult = defaultdict(list)
# Keep a separate seen artifacts list to avoid unnecessary duplicates
# This list will not be cleared between serialize() calls
_seen_artifacts: OrderedSet[CacheArtifact] = OrderedSet()
@@ -207,7 +217,7 @@ class CacheArtifactManager:
cls._new_cache_artifacts = defaultdict(list)
cls._seen_artifacts = OrderedSet()
cls._serializer = AppendingByteSerializer(serialize_fn=_serialize_single_cache)
cls._cache_info = CacheInfo()
cls._cache_info = cls._cache_info.__class__()
try:
yield
finally:
@@ -268,9 +278,9 @@ class CacheArtifactManager:
return None
@staticmethod
def deserialize(serialized_artifacts: bytes) -> Optional[CacheInfo]:
def deserialize(serialized_artifacts: bytes) -> Optional[CacheArtifactsResult]:
"""
Converts the portable format back into various filesystem caches
Converts the portable format back into CacheArtifacts
"""
try:
CacheArtifactManager._ensure_cache_artifacts_registered()
@@ -284,6 +294,10 @@ class CacheArtifactManager:
log.warning("Failed to un-pickle cache artifacts", exc_info=True)
return None
return artifacts
@staticmethod
def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo:
info = CacheInfo()
for artifact in chain(*artifacts.values()):
log.debug("writing: %s", artifact)
@@ -292,8 +306,8 @@ class CacheArtifactManager:
return info
@staticmethod
def _ensure_cache_artifacts_registered() -> None:
@classmethod
def _ensure_cache_artifacts_registered(cls) -> None:
"""When deserializing caches in fresh process, we need to ensure that all
cache artifacts are registered in the cache registry. This is done by
simply importing all the cache artifacts already wrapped with the register call.