Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Megacache integration (#163533)
This diff adds megacache integration for DynamoCache. DynamoCache requires lazy serialization: it can only be serialized once all relevant backends have been compiled and we are ready to save. The DynamoCache save therefore happens only on a call to `torch.compiler.save_cache_artifacts`.

Differential Revision: [D82735763](https://our.internmc.facebook.com/intern/diff/D82735763/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163533
Approved by: https://github.com/oulgen, https://github.com/zhxchen17
This commit is contained in:
parent 53f9ae0e50
commit b54e466fd0
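For orientation, a minimal sketch of the end-to-end flow this change enables, assuming the `caching_precompile` config flag and the `torch.compiler.save_cache_artifacts` / `torch.compiler.load_cache_artifacts` APIs exercised in the diff below; the example function, shapes, and call order are illustrative only, not taken from the commit.

```python
import torch

# Opt DynamoCache into megacache (the flag the new test patches).
torch._dynamo.config.caching_precompile = True


def fn(x, y):
    return x.sin() @ y


compiled_fn = torch.compile(fn)
compiled_fn(torch.rand(100, 100), torch.rand(100, 100))  # populate the caches

# DynamoCache is serialized lazily: it is only written out here, once all
# relevant backends have been compiled and we are ready to save.
artifacts = torch.compiler.save_cache_artifacts()
if artifacts is not None:
    artifact_bytes, cache_info = artifacts
    # cache_info.precompile_artifacts lists the recorded DynamoCache entries.

    # In a fresh process, hot-load everything (Inductor, AOTAutograd,
    # autotune, and now precompile artifacts) before compiling again.
    torch.compiler.load_cache_artifacts(artifact_bytes)
```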
test/inductor/test_codecache.py

@@ -16,6 +16,8 @@ from unittest import mock
 import torch
 from torch._dynamo import reset
+from torch._dynamo.package import DynamoCache
+from torch._dynamo.precompile_context import PrecompileContext
 from torch._dynamo.utils import counters
 from torch._functorch import config as functorch_config
 from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
@@ -243,8 +245,12 @@ class TestFxGraphCache(TestCase):
     def setUp(self):
         super().setUp()
         counters.clear()
+        DynamoCache.clear()
+        PrecompileContext.clear()
+        AOTAutogradCache.clear()
         PatchCaches.setUp()
         CacheArtifactManager.clear()
+        torch._dynamo.reset()

     def tearDown(self):
         super().tearDown()
@@ -252,6 +258,8 @@ class TestFxGraphCache(TestCase):
     def reset(self):
         AOTAutogradCache.clear()
+        DynamoCache.clear()
+        PrecompileContext.clear()
         PyCodeCache.cache_clear(purge=True)
         torch._dynamo.reset()
         clear_caches()
@@ -595,6 +603,109 @@ class TestFxGraphCache(TestCase):
         self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
         self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)

+    @requires_triton()
+    @config.patch(
+        {
+            "fx_graph_cache": True,
+            "fx_graph_remote_cache": False,
+            "autotune_local_cache": True,
+        }
+    )
+    @torch._dynamo.config.patch(
+        {
+            "caching_precompile": True,
+        }
+    )
+    @parametrize("dynamic", (False, True))
+    @parametrize("device", (GPU_TYPE, "cpu"))
+    @parametrize("dtype", (torch.float32, torch.bfloat16))
+    def test_cache_hot_load_caching_precompile(self, device, dtype, dynamic):
+        """
+        Verify that we can populate and hot load functions from the cache.
+        """
+        if device == GPU_TYPE and not HAS_GPU:
+            raise unittest.SkipTest(f"requires {GPU_TYPE}")
+        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
+            raise unittest.SkipTest("requires SM80 or later")
+
+        def fn(x, y):
+            return x.sin() @ y
+
+        a = torch.rand(100, 100, dtype=dtype, device=device, requires_grad=True)
+        b = torch.rand(100, 100, dtype=dtype, device=device, requires_grad=True)
+
+        # Record artifacts
+        with fresh_cache():
+            compiled_fn = torch.compile(fn, dynamic=dynamic)
+
+            # A first call should miss in the cache.
+            eager_result = fn(a, b)
+            compiled_result = compiled_fn(a, b)
+            compiled_result.sum().backward()
+            self.assertEqual(eager_result, compiled_result)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+            self.assertEqual(counters["dynamo_cache"]["dynamo_cache_miss"], 1)
+            self.assertEqual(counters["dynamo_cache"]["dynamo_cache_hit"], 0)
+
+            artifacts = torch.compiler.save_cache_artifacts()
+
+            self.assertIsNotNone(artifacts)
+
+            artifact_bytes, cache_info = artifacts
+
+            autotune_expect = 2 if device == GPU_TYPE else 0
+            self.assertEqual(len(cache_info.inductor_artifacts), 2)
+            self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect)
+            self.assertEqual(len(cache_info.aot_autograd_artifacts), 1)
+            self.assertEqual(len(cache_info.pgo_artifacts), 0)
+            self.assertEqual(len(cache_info.precompile_artifacts), 1)
+
+        self.reset()
+
+        # Clean triton kernels
+        shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
+
+        # We did not load anything so dont hit yet
+        with fresh_cache():
+            eager_result = fn(a, b)
+            # With caching precompile, we have to re torch.compile the function
+            # to trigger cache lookup
+            compiled_fn = torch.compile(fn, dynamic=dynamic)
+            compiled_result = compiled_fn(a, b)
+            compiled_result.sum().backward()
+            self.assertEqual(eager_result, compiled_result)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+            self.assertEqual(counters["dynamo_cache"]["dynamo_cache_miss"], 2)
+            self.assertEqual(counters["dynamo_cache"]["dynamo_cache_hit"], 0)
+        self.reset()
+        # Clean triton kernels
+        shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
+
+        # Hot load and hit
+        with fresh_cache(), torch.compiler.set_stance("fail_on_recompile"):
+            cache_info = torch.compiler.load_cache_artifacts(artifact_bytes)
+            self.assertEqual(len(cache_info.inductor_artifacts), 2)
+            self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect)
+            self.assertEqual(len(cache_info.aot_autograd_artifacts), 1)
+            self.assertEqual(len(cache_info.pgo_artifacts), 0)
+            self.assertEqual(len(cache_info.precompile_artifacts), 1)
+
+            # With caching precompile, we have to re torch.compile the function
+            # to trigger cache lookup
+            compiled_fn = torch.compile(fn, dynamic=dynamic)
+
+            eager_result = fn(a, b)
+            compiled_result = compiled_fn(a, b)
+            compiled_result.sum().backward()
+            self.assertEqual(eager_result, compiled_result)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+            self.assertEqual(counters["dynamo_cache"]["dynamo_cache_miss"], 2)
+            self.assertEqual(counters["dynamo_cache"]["dynamo_cache_hit"], 1)
+
     @config.patch(
         {
             "fx_graph_cache": True,
torch/_dynamo/package.py

@@ -34,7 +34,7 @@ from torch._dynamo.exc import PackageError
 from torch._dynamo.graph_utils import _graph_device_type

 from .bytecode_transformation import get_code_keys
-from .utils import dynamo_timed, increment_frame
+from .utils import counters, dynamo_timed, increment_frame


 logger = logging.getLogger(__name__)
@@ -433,6 +433,23 @@ class _DynamoCacheEntry:
         }


+from torch.compiler._cache import (
+    CacheArtifact,
+    CacheArtifactFactory,
+    CacheArtifactManager,
+)
+
+
+@CacheArtifactFactory.register
+class PrecompileCacheArtifact(CacheArtifact):
+    def populate_cache(self) -> None:
+        DynamoCache._write_to_local_cache(self.content, self.key)
+
+    @staticmethod
+    def type() -> str:
+        return "precompile"
+
+
 @dataclasses.dataclass
 class PrecompileCacheEntry:
     """
@@ -1026,14 +1043,17 @@ class DiskDynamoStore(DynamoStore):
         Args:
             path_prefix: Prefix directory for where to put CompilePackages on disk
         """
-        self.path_prefix = path_prefix
+        self._path_prefix = path_prefix
+
+    def path_prefix(self) -> str:
+        return self._path_prefix

     def clear(self) -> None:
         """
         Clear all CompilePackages from disk.
         """
-        if self.path_prefix:
-            shutil.rmtree(self.path_prefix, ignore_errors=True)
+        if self.path_prefix():
+            shutil.rmtree(self.path_prefix(), ignore_errors=True)

     def write(
         self,
@@ -1043,12 +1063,21 @@
         """
         Write dynamo cache entry and backends to disk.
         """
+        try:
+            pickled_content: bytes = pickle.dumps(entry)
+            CacheArtifactManager.record_artifact(
+                PrecompileCacheArtifact.type(), path, pickled_content
+            )
+            self._write_to_local_cache(pickled_content, path)
+        except Exception as e:
+            raise RuntimeError(f"Failed to save package to {path}: {e}") from e
+
+    def _write_to_local_cache(self, pickled_content: bytes, path: str) -> None:
         from torch._inductor.codecache import write_atomic

-        path = os.path.join(self.path_prefix, path) if self.path_prefix else path
+        path = os.path.join(self.path_prefix(), path) if self.path_prefix() else path
         try:
             os.makedirs(path, exist_ok=True)
-            pickled_content: bytes = pickle.dumps(entry)
             write_atomic(os.path.join(path, "entry"), pickled_content)
         except Exception as e:
             raise RuntimeError(f"Failed to save package to {path}: {e}") from e
@@ -1057,7 +1086,7 @@
         """
         Read dynamo cache entry and backends from disk.
         """
-        path = os.path.join(self.path_prefix, path) if self.path_prefix else path
+        path = os.path.join(self.path_prefix(), path) if self.path_prefix() else path
         try:
             with open(os.path.join(path, "entry"), "rb") as f:
                 pickled_content = f.read()
@@ -1087,15 +1116,18 @@ class DiskDynamoCache(DiskDynamoStore):
         """
         key = CompilePackage.source_id_from_fn(fn)
         logger.info("Loading CompilePackage for %s", key)
-        path = os.path.join(self.path_prefix, key)
+        path = os.path.join(self.path_prefix(), key)
         if os.path.exists(path):
             try:
                 result = super().load_cache_entry(key)
+                counters["dynamo_cache"]["dynamo_cache_hit"] += 1
                 return result
             except Exception as e:
+                counters["dynamo_cache"]["dynamo_cache_error"] += 1
                 logger.warning("Failed to load package from path %s: %s", path, str(e))
                 return None
         logger.info("No package found for %s", key)
+        counters["dynamo_cache"]["dynamo_cache_miss"] += 1
         return None

     def load_and_install_package(
@@ -1112,6 +1144,9 @@ class DiskDynamoCache(DiskDynamoStore):
         package.install(results.backends)
         return package

+    def path_prefix(self) -> str:
+        return os.path.join(cache_dir(), "dynamo")
+

 def cache_dir() -> str:
     from torch._inductor.runtime.cache_dir_utils import cache_dir
torch/compiler/__init__.py

@@ -501,7 +501,12 @@ def save_cache_artifacts() -> Optional[tuple[bytes, "CacheInfo"]]:
     - Execute torch.compile
     - Call torch.compiler.save_cache_artifacts()
     """
-    from ._cache import CacheArtifactManager, CacheInfo
+    from ._cache import CacheArtifactManager

+    if torch._dynamo.config.caching_precompile:
+        from torch._dynamo.precompile_context import PrecompileContext
+
+        PrecompileContext.save_to_dynamo_cache()
+
     return CacheArtifactManager.serialize()
torch/compiler/_cache.py

@@ -130,6 +130,10 @@ class CacheInfo:
     def pgo_artifacts(self) -> list[str]:  # type: ignore[empty-body]
         ...

+    @property
+    def precompile_artifacts(self) -> list[str]:  # type: ignore[empty-body]
+        ...
+
     def add(self, artifact: CacheArtifact) -> None:
         self.artifacts[artifact.type()].append(artifact.key)
@@ -307,6 +311,7 @@ class CacheArtifactManager:
         cache artifacts are registered in the cache registry. This is done by
         simply importing all the cache artifacts already wrapped with register call.
         """
+        from torch._dynamo.package import PrecompileCacheArtifact  # noqa: F401
         from torch._dynamo.pgo import PGOCacheArtifact  # noqa: F401
         from torch._functorch._aot_autograd.autograd_cache import (  # noqa: F401
             AOTAutogradCacheArtifact,