Megacache integration (#163533)

This diff adds megacache integration for DynamoCache.

Because DynamoCache requires lazy serialization (it can only be serialized once all relevant backends have been compiled and we are ready to save), the DynamoCache save is performed only on a call to `torch.compiler.save_cache_artifacts`.
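For orientation, here is a minimal usage sketch of the flow this diff enables. It mirrors the new test added below and assumes a build where `torch._dynamo.config.caching_precompile` and the megacache APIs (`torch.compiler.save_cache_artifacts` / `torch.compiler.load_cache_artifacts`) are available; the function `fn` and tensor shapes are illustrative only.

```python
import torch

torch._dynamo.config.caching_precompile = True


def fn(x, y):
    return x.sin() @ y


compiled_fn = torch.compile(fn)
compiled_fn(torch.randn(16, 16), torch.randn(16, 16))

# DynamoCache is serialized lazily, so its entries are only captured here.
artifacts = torch.compiler.save_cache_artifacts()
assert artifacts is not None
artifact_bytes, cache_info = artifacts
print(cache_info.precompile_artifacts)  # one key per compiled function

# Later, e.g. in a fresh process, replay everything that was saved:
torch.compiler.load_cache_artifacts(artifact_bytes)
```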

Differential Revision: [D82735763](https://our.internmc.facebook.com/intern/diff/D82735763/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163533
Approved by: https://github.com/oulgen, https://github.com/zhxchen17
Authored by James Wu on 2025-10-15 10:42:44 -07:00; committed by PyTorch MergeBot
parent 53f9ae0e50
commit b54e466fd0
4 changed files with 165 additions and 9 deletions

View File

@@ -16,6 +16,8 @@ from unittest import mock
 import torch
 from torch._dynamo import reset
+from torch._dynamo.package import DynamoCache
+from torch._dynamo.precompile_context import PrecompileContext
 from torch._dynamo.utils import counters
 from torch._functorch import config as functorch_config
 from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
@@ -243,8 +245,12 @@ class TestFxGraphCache(TestCase):
     def setUp(self):
         super().setUp()
         counters.clear()
+        DynamoCache.clear()
+        PrecompileContext.clear()
+        AOTAutogradCache.clear()
         PatchCaches.setUp()
         CacheArtifactManager.clear()
+        torch._dynamo.reset()
 
     def tearDown(self):
         super().tearDown()
@@ -252,6 +258,8 @@ class TestFxGraphCache(TestCase):
     def reset(self):
         AOTAutogradCache.clear()
+        DynamoCache.clear()
+        PrecompileContext.clear()
         PyCodeCache.cache_clear(purge=True)
         torch._dynamo.reset()
         clear_caches()
@@ -595,6 +603,109 @@ class TestFxGraphCache(TestCase):
         self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
         self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)
 
+    @requires_triton()
+    @config.patch(
+        {
+            "fx_graph_cache": True,
+            "fx_graph_remote_cache": False,
+            "autotune_local_cache": True,
+        }
+    )
+    @torch._dynamo.config.patch(
+        {
+            "caching_precompile": True,
+        }
+    )
+    @parametrize("dynamic", (False, True))
+    @parametrize("device", (GPU_TYPE, "cpu"))
+    @parametrize("dtype", (torch.float32, torch.bfloat16))
+    def test_cache_hot_load_caching_precompile(self, device, dtype, dynamic):
+        """
+        Verify that we can populate and hot load functions from the cache.
+        """
+        if device == GPU_TYPE and not HAS_GPU:
+            raise unittest.SkipTest(f"requires {GPU_TYPE}")
+        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
+            raise unittest.SkipTest("requires SM80 or later")
+
+        def fn(x, y):
+            return x.sin() @ y
+
+        a = torch.rand(100, 100, dtype=dtype, device=device, requires_grad=True)
+        b = torch.rand(100, 100, dtype=dtype, device=device, requires_grad=True)
+
+        # Record artifacts
+        with fresh_cache():
+            compiled_fn = torch.compile(fn, dynamic=dynamic)
+
+            # A first call should miss in the cache.
+            eager_result = fn(a, b)
+            compiled_result = compiled_fn(a, b)
+            compiled_result.sum().backward()
+            self.assertEqual(eager_result, compiled_result)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+            self.assertEqual(counters["dynamo_cache"]["dynamo_cache_miss"], 1)
+            self.assertEqual(counters["dynamo_cache"]["dynamo_cache_hit"], 0)
+
+            artifacts = torch.compiler.save_cache_artifacts()
+            self.assertIsNotNone(artifacts)
+
+            artifact_bytes, cache_info = artifacts
+            autotune_expect = 2 if device == GPU_TYPE else 0
+
+            self.assertEqual(len(cache_info.inductor_artifacts), 2)
+            self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect)
+            self.assertEqual(len(cache_info.aot_autograd_artifacts), 1)
+            self.assertEqual(len(cache_info.pgo_artifacts), 0)
+            self.assertEqual(len(cache_info.precompile_artifacts), 1)
+
+        self.reset()
+        # Clean triton kernels
+        shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
+
+        # We did not load anything so dont hit yet
+        with fresh_cache():
+            eager_result = fn(a, b)
+            # With caching precompile, we have to re torch.compile the function
+            # to trigger cache lookup
+            compiled_fn = torch.compile(fn, dynamic=dynamic)
+            compiled_result = compiled_fn(a, b)
+            compiled_result.sum().backward()
+            self.assertEqual(eager_result, compiled_result)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+            self.assertEqual(counters["dynamo_cache"]["dynamo_cache_miss"], 2)
+            self.assertEqual(counters["dynamo_cache"]["dynamo_cache_hit"], 0)
+
+        self.reset()
+        # Clean triton kernels
+        shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
+
+        # Hot load and hit
+        with fresh_cache(), torch.compiler.set_stance("fail_on_recompile"):
+            cache_info = torch.compiler.load_cache_artifacts(artifact_bytes)
+
+            self.assertEqual(len(cache_info.inductor_artifacts), 2)
+            self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect)
+            self.assertEqual(len(cache_info.aot_autograd_artifacts), 1)
+            self.assertEqual(len(cache_info.pgo_artifacts), 0)
+            self.assertEqual(len(cache_info.precompile_artifacts), 1)
+
+            # With caching precompile, we have to re torch.compile the function
+            # to trigger cache lookup
+            compiled_fn = torch.compile(fn, dynamic=dynamic)
+            eager_result = fn(a, b)
+            compiled_result = compiled_fn(a, b)
+            compiled_result.sum().backward()
+            self.assertEqual(eager_result, compiled_result)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+            self.assertEqual(counters["dynamo_cache"]["dynamo_cache_miss"], 2)
+            self.assertEqual(counters["dynamo_cache"]["dynamo_cache_hit"], 1)
+
     @config.patch(
         {
             "fx_graph_cache": True,

View File

@@ -34,7 +34,7 @@ from torch._dynamo.exc import PackageError
 from torch._dynamo.graph_utils import _graph_device_type
 
 from .bytecode_transformation import get_code_keys
-from .utils import dynamo_timed, increment_frame
+from .utils import counters, dynamo_timed, increment_frame
 
 
 logger = logging.getLogger(__name__)
@@ -433,6 +433,23 @@ class _DynamoCacheEntry:
         }
 
 
+from torch.compiler._cache import (
+    CacheArtifact,
+    CacheArtifactFactory,
+    CacheArtifactManager,
+)
+
+
+@CacheArtifactFactory.register
+class PrecompileCacheArtifact(CacheArtifact):
+    def populate_cache(self) -> None:
+        DynamoCache._write_to_local_cache(self.content, self.key)
+
+    @staticmethod
+    def type() -> str:
+        return "precompile"
+
+
 @dataclasses.dataclass
 class PrecompileCacheEntry:
     """
@@ -1026,14 +1043,17 @@ class DiskDynamoStore(DynamoStore):
         Args:
             path_prefix: Prefix directory for where to put CompilePackages on disk
         """
-        self.path_prefix = path_prefix
+        self._path_prefix = path_prefix
+
+    def path_prefix(self) -> str:
+        return self._path_prefix
 
     def clear(self) -> None:
         """
         Clear all CompilePackages from disk.
         """
-        if self.path_prefix:
-            shutil.rmtree(self.path_prefix, ignore_errors=True)
+        if self.path_prefix():
+            shutil.rmtree(self.path_prefix(), ignore_errors=True)
 
     def write(
         self,
@@ -1043,12 +1063,21 @@ class DiskDynamoStore(DynamoStore):
         """
         Write dynamo cache entry and backends to disk.
         """
+        try:
+            pickled_content: bytes = pickle.dumps(entry)
+            CacheArtifactManager.record_artifact(
+                PrecompileCacheArtifact.type(), path, pickled_content
+            )
+            self._write_to_local_cache(pickled_content, path)
+        except Exception as e:
+            raise RuntimeError(f"Failed to save package to {path}: {e}") from e
+
+    def _write_to_local_cache(self, pickled_content: bytes, path: str) -> None:
         from torch._inductor.codecache import write_atomic
 
-        path = os.path.join(self.path_prefix, path) if self.path_prefix else path
+        path = os.path.join(self.path_prefix(), path) if self.path_prefix() else path
         try:
             os.makedirs(path, exist_ok=True)
-            pickled_content: bytes = pickle.dumps(entry)
             write_atomic(os.path.join(path, "entry"), pickled_content)
         except Exception as e:
             raise RuntimeError(f"Failed to save package to {path}: {e}") from e
@@ -1057,7 +1086,7 @@ class DiskDynamoStore(DynamoStore):
         """
         Read dynamo cache entry and backends from disk.
         """
-        path = os.path.join(self.path_prefix, path) if self.path_prefix else path
+        path = os.path.join(self.path_prefix(), path) if self.path_prefix() else path
         try:
             with open(os.path.join(path, "entry"), "rb") as f:
                 pickled_content = f.read()
@@ -1087,15 +1116,18 @@ class DiskDynamoCache(DiskDynamoStore):
         """
         key = CompilePackage.source_id_from_fn(fn)
         logger.info("Loading CompilePackage for %s", key)
-        path = os.path.join(self.path_prefix, key)
+        path = os.path.join(self.path_prefix(), key)
         if os.path.exists(path):
             try:
                 result = super().load_cache_entry(key)
+                counters["dynamo_cache"]["dynamo_cache_hit"] += 1
                 return result
             except Exception as e:
+                counters["dynamo_cache"]["dynamo_cache_error"] += 1
                 logger.warning("Failed to load package from path %s: %s", path, str(e))
                 return None
         logger.info("No package found for %s", key)
+        counters["dynamo_cache"]["dynamo_cache_miss"] += 1
         return None
 
     def load_and_install_package(
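The new `counters["dynamo_cache"]` entries give a quick way to check whether a warm start actually hit the on-disk cache. A hedged sketch, mirroring the assertions in the new test above:

```python
import torch
from torch._dynamo.utils import counters

torch._dynamo.config.caching_precompile = True


def fn(x):
    return x.sin()


torch.compile(fn)(torch.randn(4))

# On a cold start this shows a dynamo_cache_miss; after
# torch.compiler.load_cache_artifacts() in a fresh process it shows a hit.
print(dict(counters["dynamo_cache"]))
```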
@@ -1112,6 +1144,9 @@ class DiskDynamoCache(DiskDynamoStore):
         package.install(results.backends)
         return package
 
+    def path_prefix(self) -> str:
+        return os.path.join(cache_dir(), "dynamo")
+
 
 def cache_dir() -> str:
     from torch._inductor.runtime.cache_dir_utils import cache_dir
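With `path_prefix()` now derived from the Inductor cache directory, DynamoCache entries land in a predictable location on disk. A small sketch to print it, using only the import shown in the hunk above:

```python
import os

from torch._inductor.runtime.cache_dir_utils import cache_dir

# DiskDynamoCache.path_prefix() resolves to <inductor cache dir>/dynamo.
print(os.path.join(cache_dir(), "dynamo"))
```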

View File

@@ -501,7 +501,12 @@ def save_cache_artifacts() -> Optional[tuple[bytes, "CacheInfo"]]:
     - Execute torch.compile
     - Call torch.compiler.save_cache_artifacts()
     """
-    from ._cache import CacheArtifactManager, CacheInfo
+    from ._cache import CacheArtifactManager
+
+    if torch._dynamo.config.caching_precompile:
+        from torch._dynamo.precompile_context import PrecompileContext
+
+        PrecompileContext.save_to_dynamo_cache()
 
     return CacheArtifactManager.serialize()

View File

@@ -130,6 +130,10 @@ class CacheInfo:
     def pgo_artifacts(self) -> list[str]:  # type: ignore[empty-body]
         ...
 
+    @property
+    def precompile_artifacts(self) -> list[str]:  # type: ignore[empty-body]
+        ...
+
     def add(self, artifact: CacheArtifact) -> None:
         self.artifacts[artifact.type()].append(artifact.key)
@@ -307,6 +311,7 @@ class CacheArtifactManager:
         cache artifacts are registered in the cache registry. This is done by
         simply importing all the cache artifacts already wrapped with register call.
         """
+        from torch._dynamo.package import PrecompileCacheArtifact  # noqa: F401
         from torch._dynamo.pgo import PGOCacheArtifact  # noqa: F401
         from torch._functorch._aot_autograd.autograd_cache import (  # noqa: F401
             AOTAutogradCacheArtifact,