While using save_cache_artifacts on internal workloads, we noticed that repeatedly calling this function after every batch is incredibly expensive. This PR significantly speeds up the call by opting out of pickle and redesigning the serialization algorithm. Essentially, we want to be able to call serialize many times without re-incurring the full serialization cost from scratch each time. Pull Request resolved: https://github.com/pytorch/pytorch/pull/148227 Approved by: https://github.com/jamesjwu ghstack dependencies: #148226
212 lines
7.4 KiB
Python
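The core idea behind the speedup: rather than re-pickling the entire artifact list on every save, newly recorded artifacts are appended to a byte stream that persists across serialize() calls, so each call only pays for what is new. A minimal sketch of that pattern (illustrative only, with simplified names; the real implementation is AppendingByteSerializer in torch.utils._appending_byte_serializer):

class AppendOnlySerializer:
    """Sketch: amortized serialization; encode each item once, reuse old bytes."""

    def __init__(self, serialize_fn):
        self._serialize_fn = serialize_fn  # encodes one item into a bytearray
        self._buffer = bytearray()  # persists across to_bytes() calls

    def extend(self, items):
        # Pay the encoding cost only for items recorded since the last call.
        for item in items:
            self._serialize_fn(self._buffer, item)

    def to_bytes(self) -> bytes:
        # Previously encoded items are returned as-is, with no re-encoding.
        return bytes(self._buffer)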
import copy
import dataclasses
import logging
import os
from enum import Enum
from typing import Optional, Union

from torch._inductor.remote_cache import JsonDataTy, RemoteCacheJsonSerde
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.utils._appending_byte_serializer import (
    AppendingByteSerializer,
    BytesReader,
    BytesWriter,
)
from torch.utils._ordered_set import OrderedSet


log = logging.getLogger(__name__)


class CacheArtifactType(Enum):
    """
    Type of cache
    """

    INDUCTOR = 0
    AUTOTUNE = 1
    AOT_AUTOGRAD = 2
    PGO = 3


@dataclasses.dataclass(frozen=True)
class CacheArtifact:
    """
    Data for each cache artifact that will be serialized and deserialized
    """

    type: CacheArtifactType
    key: str
    content: bytes = dataclasses.field(repr=False)  # Do not display potentially binary content

    @staticmethod
    def serialize(writer: BytesWriter, cls: "CacheArtifact") -> None:
        writer.write_uint64(cls.type.value)
        writer.write_str(cls.key)
        writer.write_bytes(cls.content)

    @staticmethod
    def deserialize(reader: BytesReader) -> "CacheArtifact":
        # Read fields in the same order serialize() wrote them; avoid shadowing
        # the builtin `type`.
        artifact_type = reader.read_uint64()
        key = reader.read_str()
        content = reader.read_bytes()
        return CacheArtifact(CacheArtifactType(artifact_type), key, content)
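
    # Example round trip (illustrative; assumes the BytesWriter() and
    # BytesReader(data) constructors from torch.utils._appending_byte_serializer):
    #
    #   writer = BytesWriter()
    #   CacheArtifact.serialize(
    #       writer, CacheArtifact(CacheArtifactType.INDUCTOR, "key", b"payload")
    #   )
    #   artifact = CacheArtifact.deserialize(BytesReader(writer.to_bytes()))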


@dataclasses.dataclass
class CacheInfo:
    """
    Return value of serialization and deserialization, for the purpose of
    instrumentation
    """

    inductor_artifacts: list[str] = dataclasses.field(default_factory=list)
    autotune_artifacts: list[str] = dataclasses.field(default_factory=list)
    aot_autograd_artifacts: list[str] = dataclasses.field(default_factory=list)
    pgo_artifacts: list[str] = dataclasses.field(default_factory=list)

    def add(self, artifact: CacheArtifact) -> None:
        if artifact.type == CacheArtifactType.INDUCTOR:
            self.inductor_artifacts.append(artifact.key)
        elif artifact.type == CacheArtifactType.AUTOTUNE:
            self.autotune_artifacts.append(artifact.key)
        elif artifact.type == CacheArtifactType.AOT_AUTOGRAD:
            self.aot_autograd_artifacts.append(artifact.key)
        elif artifact.type == CacheArtifactType.PGO:
            self.pgo_artifacts.append(artifact.key)
        else:
            log.warning("Unsupported artifact type %s", artifact.type)

    def clear(self) -> None:
        self.inductor_artifacts.clear()
        self.autotune_artifacts.clear()
        self.aot_autograd_artifacts.clear()
        self.pgo_artifacts.clear()


class CacheArtifactManager:
    """
    Lightweight manager class for collecting and processing cache artifacts for
    hot loading

    Intended lifecycle:
    - Execute code via torch.compile, which calls
      CacheArtifactManager.record_artifact on each cache artifact
    - Call CacheArtifactManager.serialize to convert all the cache artifacts
      to a portable format
    - Call CacheArtifactManager.deserialize to hot load the cache artifacts on
      a potentially different process

    NOTE: There are no FB/FC guarantees; cache artifacts will not be used
    unless the code version matches.
    """

    # Protected by the compile_lock
    _new_cache_artifacts: list[CacheArtifact] = []
    # Keep a separate set of seen artifacts to avoid unnecessary duplicates.
    # This set is not cleared between serialize() calls.
    _seen_artifacts: OrderedSet[CacheArtifact] = OrderedSet()
    # When serialize() is called, artifacts are transferred from
    # _new_cache_artifacts to the internal data structure of _serializer.
    # This allows us to only pay the cost of serialization if serialize() is called.
    _serializer: AppendingByteSerializer[CacheArtifact] = AppendingByteSerializer(
        serialize_fn=CacheArtifact.serialize
    )
    _cache_info: CacheInfo = CacheInfo()

    @classmethod
    def clear(cls) -> None:
        cls._new_cache_artifacts.clear()
        cls._seen_artifacts.clear()
        cls._serializer.clear()
        cls._cache_info.clear()

    @classmethod
    def record_artifact(
        cls,
        artifact_type: CacheArtifactType,
        key: str,
        content: Union[bytes, JsonDataTy],
    ) -> None:
        """
        Called from each caching operation to record the artifact in this
        "mega" list
        """
        if artifact_type == CacheArtifactType.AUTOTUNE:
            # Autotune artifacts arrive as JSON data; encode them to bytes.
            assert not isinstance(content, bytes)
            serde = RemoteCacheJsonSerde()
            content = serde.encode(content)
        assert isinstance(content, bytes)
        artifact = CacheArtifact(artifact_type, key, content)
        if artifact in cls._seen_artifacts:
            return
        log.debug("Recording %s", str(artifact))
        cls._new_cache_artifacts.append(artifact)
        cls._seen_artifacts.add(artifact)

    @classmethod
    def need_serialize(cls) -> bool:
        """
        Have we seen new artifacts since the last serialize() call?
        """
        return len(cls._new_cache_artifacts) != 0
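
    # Callers that save after every batch can use need_serialize() to skip
    # no-op saves:
    #
    #   if CacheArtifactManager.need_serialize():
    #       result = CacheArtifactManager.serialize()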

    @classmethod
    def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]:
        """
        Converts the "mega" list into a portable format
        """
        for artifact in cls._new_cache_artifacts:
            log.debug("saving: %s", artifact)
            cls._cache_info.add(artifact)
        try:
            # We deep copy cls._cache_info since later compilations
            # can keep adding to cache_info
            info = copy.deepcopy(cls._cache_info)
            cls._serializer.extend(cls._new_cache_artifacts)
            artifact_bytes = cls._serializer.to_bytes()
            cls._new_cache_artifacts.clear()
            return artifact_bytes, info
        except Exception:
            log.warning("Failed to serialize cache artifacts", exc_info=True)
        return None
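
    # Because _serializer accumulates across calls, each serialize() call only
    # encodes artifacts recorded since the previous call, while to_bytes()
    # still returns the full stream. Illustrative sequence:
    #
    #   CacheArtifactManager.record_artifact(CacheArtifactType.PGO, "k1", b"v1")
    #   first = CacheArtifactManager.serialize()   # encodes k1
    #   CacheArtifactManager.record_artifact(CacheArtifactType.PGO, "k2", b"v2")
    #   second = CacheArtifactManager.serialize()  # encodes only k2; bytes cover k1 and k2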

    @staticmethod
    def deserialize(serialized_artifacts: bytes) -> Optional[CacheInfo]:
        """
        Converts the portable format back into the various filesystem caches
        """
        try:
            artifacts = AppendingByteSerializer.to_list(
                serialized_artifacts, deserialize_fn=CacheArtifact.deserialize
            )
        except Exception:
            log.warning("Failed to deserialize cache artifacts", exc_info=True)
            return None

        from torch._dynamo.pgo import write_local_impl
        from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
        from torch._inductor.codecache import FxGraphCache
        from torch._inductor.runtime.autotune_cache import _LocalAutotuneCacheBackend

        autotune_cache = _LocalAutotuneCacheBackend()

        info = CacheInfo()
        for artifact in artifacts:
            log.debug("writing: %s", artifact)
            info.add(artifact)

            if artifact.type == CacheArtifactType.INDUCTOR:
                FxGraphCache._write_to_local_cache(artifact.key, artifact.content)
            elif artifact.type == CacheArtifactType.AUTOTUNE:
                key = os.path.join(cache_dir(), artifact.key)
                autotune_cache._put(key, artifact.content)
            elif artifact.type == CacheArtifactType.AOT_AUTOGRAD:
                AOTAutogradCache._write_to_local_cache(artifact.key, artifact.content)
            elif artifact.type == CacheArtifactType.PGO:
                meta = write_local_impl(artifact.key, artifact.content)
                assert meta is not None
            else:
                log.warning("Unsupported artifact type %s", artifact.type)
        return info
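
For context, the save_cache_artifacts entry point mentioned in the commit message drives this manager from user code. A sketch of the per-batch pattern the PR optimizes (assuming the public wrappers torch.compiler.save_cache_artifacts and torch.compiler.load_cache_artifacts, which delegate to CacheArtifactManager; `model` and `batches` are placeholders):

import torch

compiled = torch.compile(model)
for batch in batches:
    compiled(batch)
    # Cheap when no new artifacts were recorded since the last call.
    result = torch.compiler.save_cache_artifacts()
    if result is not None:
        artifact_bytes, cache_info = result

# Later, potentially on a different process, hot load the caches:
torch.compiler.load_cache_artifacts(artifact_bytes)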