[precompile] Implement PrecompileContext for recording precompile artifacts, integrate with CompilePackage (#154415)
This PR implements a basic interface and test for PrecompileContext, a special CacheArtifactManager designed specifically for precompile. The job of a PrecompileContext is to record everything precompile needs while torch is compiling, dump it all into bytes, and then stitch it back together into a cache of callables.

## Why use CacheArtifactManager?

Precompile needs a way to record various serializable data as torch is compiling. CacheArtifactManager already does this well today, handling a lot of the serialization and cache bookkeeping, so we reuse that infrastructure directly.

## How is it different from CacheArtifactManager?

Unlike the regular CacheArtifactManager, PrecompileContext needs to be able to take the recorded artifacts and stitch them together after deserialization to create a single working callable. Since PrecompileContext doesn't need the cache keys, the "key" field of PrecompileArtifacts can be used for metadata describing how to stitch the individually compiled functions together into a full callable.

For example, on a given dynamo compile, if there are multiple functions (via graph breaks or recompiles) being compiled, MegaCache would organize it like so: *(diagram omitted)*. Whereas we'd visualize PrecompileContext's result like so: *(diagram omitted)*.

For now, we just handle eager mode; in the diff above, I'll hook up the other backend artifacts from PrecompileContext.

After this PR, precompile consists of three main interfaces:

### CompilePackage
Everything needed to run one torch.compile'd function (including graph breaks).
- `__init__(fn, cache_entry)`: initialize with a DynamoCacheEntry
- `install(backends)`: load precompile artifacts into the function's dynamo state, given a dictionary of backends
- `cache_entry()`: return a serializable cache entry to save

### DynamoStore
Responsible for tracking CompilePackages on disk (and/or in memory).
- `load_package(fn, path)`: load a package given a torch compiled function and a path to the cache artifact
- `save_package(package, path)`: save a CompilePackage to a path; calls PrecompileContext to grab backend data
- `record_package(package)`: record a package to PrecompileContext (for global serialization/deserialization)

### PrecompileContext
Overarching context for serializing and deserializing precompile artifacts. Supports **global** and **local** setups.
- `serialize()`: (global) serialize all artifacts in PrecompileContext into bytes
- `populate_caches(bytes)`: (global) take serialized bytes and put them into DynamoStore (TODO)
- `serialize_artifact_by_key(key)`: (local) serialize a single artifact by its cache key

![image](https://github.com/user-attachments/assets/99b61330-7607-4763-bdbc-85b366e82cdd)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154415
Approved by: https://github.com/zhxchen17
ghstack dependencies: #155118
This commit is contained in: parent b2fc9cfea1, commit 3819584f12
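Taken together, the three interfaces are exercised roughly as follows. This is a condensed sketch distilled from the updated `test_basic_fn` below, not a stable public API: it assumes `CompilePackage(fn)` can be constructed without a cache entry for a fresh compile, and the example `fn`, arguments, and cache path are illustrative.

```python
import os

import torch
from torch._dynamo.package import CompilePackage, DynamoStore
from torch._inductor.runtime.runtime_utils import cache_dir


def fn(x):
    return x + 1


args = (torch.randn(3),)
path = os.path.join(cache_dir(), "package_example")  # illustrative location
os.makedirs(path, exist_ok=True)

# Compile once, recording dynamo state into a CompilePackage.
ctx = DynamoStore()
package = CompilePackage(fn)  # assumes the cache-entry argument is optional
compiled_fn = torch._dynamo.optimize(backend="eager", package=package)(fn)
expected = compiled_fn(*args)

# Hand the compiled eager backends to the DynamoStore and persist everything.
for backend_id, backend in package.cached_backends.items():
    ctx.record_eager_backend(backend_id, backend)
ctx.save_package(package, path)

# In a fresh dynamo state, reload the package and install its backends.
torch._dynamo.reset()
package, backends = ctx.load_package(fn, path)
compiled_fn = torch._dynamo.optimize(backend="eager", package=package)(fn)
package.install(backends)
assert torch.equal(expected, compiled_fn(*args))
```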
Diff of the existing `CompilePackage` tests (`StorageForTesting` is replaced by the new `DynamoStore`):

```diff
@@ -1,7 +1,6 @@
 # Owner(s): ["module: dynamo"]
 
 import os
-import pickle
 
 import torch
 import torch._dynamo.testing
@@ -9,60 +8,18 @@ import torch._inductor.config
 import torch._inductor.test_case
 import torch.onnx.operators
 import torch.utils.cpp_extension
-from torch._dynamo.package import CompilePackage
+from torch._dynamo.package import CompilePackage, DynamoStore
 from torch._inductor.runtime.runtime_utils import cache_dir
 
 
-class StorageForTesting:
-    def __init__(self, path: str):
-        self.path = path
-        self.backends = {}
-
-    def _write_pickle(self, data, *path: str):
-        with open(os.path.join(self.path, *path) + ".pickle", "wb") as f:
-            pickle.dump(data, f)
-
-    def write_dynamo(self, dynamo):
-        self._write_pickle(dynamo, "dynamo")
-
-    def write_backend(self, backend_id):
-        os.makedirs(os.path.join(self.path, backend_id), exist_ok=True)
-        self._write_pickle(self.backends[backend_id], backend_id, "fx_graph")
-
-    def _read_pickle(self, *path):
-        with open(os.path.join(self.path, *path) + ".pickle", "rb") as f:
-            return pickle.load(f)
-
-    def read_backend(self, backend_id):
-        return self._read_pickle(backend_id, "fx_graph")
-
-    def read_dynamo(self):
-        return self._read_pickle("dynamo")
-
-    def add_backend(self, backend_id, backend):
-        self.backends[backend_id] = backend
-
-    def save_package(self, dynamo_cache_entry):
-        self.write_dynamo(dynamo_cache_entry)
-        for backend_id in dynamo_cache_entry.backend_ids:
-            self.write_backend(backend_id)
-
-    def load_package(self):
-        dynamo = self.read_dynamo()
-        self.backends = {}
-        for backend_id in dynamo.backend_ids:
-            self.backends[backend_id] = self.read_backend(backend_id)
-        return dynamo
-
-
 class TestPackage(torch._inductor.test_case.TestCase):
-    def storage(self):
+    def path(self):
         path = os.path.join(cache_dir(), f"package_{self.id()}")
         os.makedirs(path, exist_ok=True)
-        return StorageForTesting(path)
+        return path
 
     def test_basic_fn(self):
-        storage = self.storage()
+        ctx = DynamoStore()
 
         def fn(x):
             return x + 1
@@ -74,8 +31,8 @@ class TestPackage(torch._inductor.test_case.TestCase):
         compiled_fn = torch._dynamo.optimize(backend="eager", package=package)(fn)
         expected = compiled_fn(*args)
         for backend_id, backend in package.cached_backends.items():
-            storage.add_backend(backend_id, backend)
-        storage.save_package(package.save())
+            ctx.record_eager_backend(backend_id, backend)
+        ctx.save_package(package, self.path())
 
         # Loading
         torch._dynamo.reset()
@@ -86,13 +43,13 @@ class TestPackage(torch._inductor.test_case.TestCase):
         ):
             compiled_fn(*args)
 
-        package = CompilePackage(fn, storage.load_package())
+        package, backends = ctx.load_package(fn, self.path())
         compiled_fn = torch._dynamo.optimize(package=package)(fn)
-        package.install(storage.backends)
+        package.install(backends)
         self.assertEqual(expected, compiled_fn(*args))
 
     def test_graph_break_bomb(self):
-        storage = self.storage()
+        ctx = DynamoStore()
 
         def fn(x, l, r):
             if l > r:
@@ -121,8 +78,8 @@ class TestPackage(torch._inductor.test_case.TestCase):
         for args in args_list:
             compiled_fn(*args)
         for backend_id, backend in package.cached_backends.items():
-            storage.add_backend(backend_id, backend)
-        storage.save_package(package.save())
+            ctx.record_eager_backend(backend_id, backend)
+        ctx.save_package(package, self.path())
 
         # Loading
         torch._dynamo.reset()
@@ -133,11 +90,11 @@ class TestPackage(torch._inductor.test_case.TestCase):
                 "Detected recompile when torch.compile stance is 'fail_on_recompile'",
             ):
                 compiled_fn(*args)
-        package = CompilePackage(fn, storage.load_package())
+        package, backends = ctx.load_package(fn, self.path())
         compiled_fn = torch._dynamo.optimize(
             backend="eager", package=package, guard_filter_fn=guard_filter_fn
         )(fn)
-        package.install(storage.backends)
+        package.install(backends)
         for args in args_list:
             self.assertEqual(compiled_fn(*args), args[0].sum())
 
@@ -148,7 +105,7 @@ class TestPackage(torch._inductor.test_case.TestCase):
             compiled_fn(torch.tensor(N), 0, N - 1)
 
     def test_dynamic_shape(self):
-        storage = self.storage()
+        ctx = DynamoStore()
 
         def fn(x):
             return x + x.shape[0]
@@ -165,8 +122,8 @@ class TestPackage(torch._inductor.test_case.TestCase):
         compiled_fn = torch._dynamo.optimize(backend="eager", package=package)(fn)
         compiled_fn(*args)
         for backend_id, backend in package.cached_backends.items():
-            storage.add_backend(backend_id, backend)
-        storage.save_package(package.save())
+            ctx.record_eager_backend(backend_id, backend)
+        ctx.save_package(package, self.path())
 
         # Loading
         torch._dynamo.reset()
@@ -177,9 +134,9 @@ class TestPackage(torch._inductor.test_case.TestCase):
         ):
             compiled_fn(*args1)
 
-        package = CompilePackage(fn, storage.load_package())
+        package, backends = ctx.load_package(fn, self.path())
         compiled_fn = torch._dynamo.optimize(package=package)(fn)
-        package.install(storage.backends)
+        package.install(backends)
 
         self.assertEqual(expected1, compiled_fn(*args1))
 
```
`test/dynamo/test_precompile_context.py` (new file, 105 lines):

```python
# Owner(s): ["module: dynamo"]

import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._functorch
from torch._dynamo.precompile_context import PrecompileContext
from torch._functorch import config as functorch_config
from torch._functorch._aot_autograd.autograd_cache import (
    BundledAOTAutogradCacheArtifact,
    BundledAOTAutogradCacheEntry,
)
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal.inductor_utils import requires_triton


@functorch_config.patch({"enable_autograd_cache": True})
@functorch_config.patch(
    {"bundled_autograd_cache": True}
)  # Requires bundledaotautograd cache for now
class PrecompileContextTests(InductorTestCase):
    def setUp(self):
        """
        Reset all counters and caches before each unit test
        """
        super().setUp()
        # Clear PrecompileContext cache artifacts
        PrecompileContext.clear()

    @requires_triton()
    def test_basic(self):
        """
        Test that after torch.compile, PrecompileContext._new_cache_artifacts length is 1
        """

        def simple_function(x):
            return x.sin() + x.cos()

        compiled_fn = torch.compile(simple_function)

        # Run the compiled function
        x = torch.randn(10, device="cuda", requires_grad=True)
        result = compiled_fn(x)
        result.sum().backward()
        # Check that PrecompileContext._new_cache_artifacts_by_key has length 1
        self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 1)

        self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0)
        result = PrecompileContext.serialize()
        assert result is not None
        serialized, cache_info = result
        self.assertEqual(len(cache_info.precompile_aot_autograd_artifacts), 1)

        artifacts = PrecompileContext.deserialize(serialized)
        assert artifacts is not None
        deserialized = artifacts["precompile_aot_autograd"]
        assert len(deserialized) == 1
        entry = deserialized[0]
        assert isinstance(entry, BundledAOTAutogradCacheArtifact)
        entry = entry.after_deserialization()
        assert isinstance(
            entry,
            BundledAOTAutogradCacheEntry,
        )

        # Now that we've serialized, there should be no new cache artifacts
        self.assertEqual(
            len(PrecompileContext._new_cache_artifacts["precompile_aot_autograd"]), 0
        )

    @requires_triton()
    def test_serialize_by_key(self):
        """
        Test that after torch.compile, PrecompileContext._new_cache_artifacts length is 1
        """

        def simple_function(x):
            return x.sin() + x.cos()

        compiled_fn = torch.compile(simple_function)

        # Run the compiled function
        x = torch.randn(10, device="cuda", requires_grad=True)
        result = compiled_fn(x)
        result.sum().backward()
        # Check that PrecompileContext._new_cache_artifacts_by_key has length 1
        # TODO: the key right now is the AOTAutogradCacheKey, but will be backend_id once
        # we have torch._dynamo.package implemented
        self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 1)
        key = next(iter(PrecompileContext._new_cache_artifacts_by_key.keys()))
        result = PrecompileContext.serialize_artifact_by_key(key)
        assert isinstance(result, BundledAOTAutogradCacheArtifact)
        self.assertEqual(result.key, key)

        self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0)
        result = PrecompileContext.serialize()
        assert result is not None
        _, cache_info = result
        self.assertEqual(len(cache_info.precompile_aot_autograd_artifacts), 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
```
Diff of `torch._dynamo.package` (adds `_DynamoCacheArtifact`, `EagerCacheArtifact`, and the new `DynamoStore`; `CompilePackage.save()` becomes `cache_entry()`):

```diff
@@ -14,6 +14,7 @@ import functools
 import hashlib
 import importlib
 import logging
+import os
 import pickle
 import platform
 import sys
@@ -23,6 +24,8 @@ from typing import Any, NewType, Optional
 
 import torch
 import torch._inductor.package
+from torch._dynamo.precompile_context import PrecompileCacheArtifact, PrecompileContext
+from torch.compiler._cache import CacheArtifactFactory
 
 from .bytecode_transformation import get_code_keys
 
@@ -128,6 +131,16 @@ class _DynamoCacheEntry:
         return {backend_id for code in self.codes for backend_id in code.backend_ids}
 
 
+@CacheArtifactFactory.register
+class _DynamoCacheArtifact(PrecompileCacheArtifact[_DynamoCacheEntry]):
+    @staticmethod
+    def type() -> str:
+        return "precompile_dynamo"
+
+    def after_deserialization(self) -> _DynamoCacheEntry:
+        return pickle.loads(self.content)
+
+
 class CompilePackage:
     """
     CompilePackage is considered a low level component and should not be directly exposed to
@@ -303,6 +316,10 @@ class CompilePackage:
                 fn = types.FunctionType(code, module.__dict__, function_name)
                 self._install_global(module, function_name, fn)
             for backend_id in entry.backend_ids:
+                if backend_id not in backends:
+                    raise RuntimeError(
+                        f"Backend {backend_id} is not found in the given backends"
+                    )
                 backend = backends[backend_id]
                 self._install_global(
                     module,
@@ -326,6 +343,69 @@ class CompilePackage:
                 SerializedCode.to_code_object(guarded_code.dynamo_code),
             )
 
-    def save(self) -> _DynamoCacheEntry:
+    def cache_entry(self) -> _DynamoCacheEntry:
         self.validate()
         return _DynamoCacheEntry(codes=list(self._codes.values()))
+
+
+@CacheArtifactFactory.register
+class EagerCacheArtifact(PrecompileCacheArtifact[Any]):
+    @staticmethod
+    def type() -> str:
+        return "precompile_eager"
+
+    def after_deserialization(self) -> Any:
+        return pickle.loads(self.content)
+
+
+class DynamoStore:
+    """
+    A DynamoStore tracks active CompilePackages, and provides methods to store and retrieve them.
+    """
+
+    def record_package(self, package: CompilePackage) -> None:
+        """Records a package to PrecompileContext, so that it can be serialized later."""
+        cache_entry = package.cache_entry()
+        pickled_result = pickle.dumps(cache_entry)
+        PrecompileContext.record_artifact(
+            _DynamoCacheArtifact.type(), key=package.source_id, content=pickled_result
+        )
+
+    def record_eager_backend(self, backend_id: _BackendId, backend: Any) -> None:
+        """Records eager fx graphs to PrecompileContext for testing purposes."""
+        pickled_result = pickle.dumps(backend)
+        PrecompileContext.record_artifact(
+            EagerCacheArtifact.type(), key=backend_id, content=pickled_result
+        )
+
+    def save_package(self, package: CompilePackage, path: str) -> None:
+        """Saves a package to a given path. Grabs backends from PrecompileContext."""
+        backend_content = {}
+        cache_entry = package.cache_entry()
+        for backend_id in cache_entry.backend_ids:
+            backend_content[backend_id] = PrecompileContext.serialize_artifact_by_key(
+                backend_id
+            )
+        try:
+            with open(os.path.join(path, "dynamo"), "wb") as dynamo_path:
+                pickle.dump(cache_entry, dynamo_path)
+            with open(os.path.join(path, "backends"), "wb") as backend_path:
+                pickle.dump(backend_content, backend_path)
+        except Exception as e:
+            raise RuntimeError(f"Failed to save package to {path}: {e}") from e
+
+    def load_package(
+        self, fn: Any, path: str
+    ) -> tuple[CompilePackage, dict[_BackendId, Any]]:
+        """Loads a package from a given path and returns it plus a list of deserialized backends"""
+        try:
+            with open(os.path.join(path, "dynamo"), "rb") as dynamo_path:
+                cache_entry = pickle.load(dynamo_path)
+            with open(os.path.join(path, "backends"), "rb") as backend_path:
+                backend_content = pickle.load(backend_path)
+        except Exception as e:
+            raise RuntimeError(f"Failed to load package from path {path}: {e}") from e
+        for backend_id, backend in backend_content.items():
+            backend_content[backend_id] = backend.after_deserialization()
+        package = CompilePackage(fn, cache_entry)
+        return package, backend_content
```
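The `record_package` / `record_eager_backend` methods above feed the global path described in the PR summary, where everything goes through `PrecompileContext` instead of straight to disk. A minimal sketch of that flow (the function and inputs are made up, `CompilePackage(fn)` is assumed to accept no cache entry for a fresh compile, and `PrecompileContext.populate_caches` is still a TODO in this PR, so the sketch stops at the serialized bytes):

```python
import torch
from torch._dynamo.package import CompilePackage, DynamoStore
from torch._dynamo.precompile_context import PrecompileContext


def fn(x):
    return x * 2


package = CompilePackage(fn)  # assumes a fresh package needs no cache entry
compiled_fn = torch._dynamo.optimize(backend="eager", package=package)(fn)
compiled_fn(torch.randn(4))

store = DynamoStore()
# Record the dynamo cache entry (keyed by package.source_id) and the eager
# backends into the global PrecompileContext rather than writing them to disk.
store.record_package(package)
for backend_id, backend in package.cached_backends.items():
    store.record_eager_backend(backend_id, backend)

# Global serialization: every recorded artifact is flattened into one byte blob.
result = PrecompileContext.serialize()
if result is not None:
    serialized_bytes, cache_info = result
    artifacts = PrecompileContext.deserialize(serialized_bytes)
    # Stitching these back into a DynamoStore (populate_caches) is the TODO above.
```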
`torch/_dynamo/precompile_context.py` (new file, 146 lines):

```python
from abc import abstractmethod
from collections import defaultdict
from typing import Any, Generic, Optional, TypeVar
from typing_extensions import override

from torch.compiler._cache import (
    _serialize_single_cache,
    CacheArtifact,
    CacheArtifactFactory,
    CacheArtifactManager,
    CacheArtifactsResult,
    CacheInfo,
)
from torch.utils._appending_byte_serializer import AppendingByteSerializer
from torch.utils._ordered_set import OrderedSet


"""
Classes and implementations related to precompile
"""

T = TypeVar("T")


class PrecompileCacheArtifact(CacheArtifact, Generic[T]):
    """
    Data for each cache artifact that will be serialized and deserialized by
    PrecompileContext, rather than CacheArtifactManager.
    T represents the deserialized type of the artifact, i.e. the return type of after_deserialization

    PrecompileCacheArtifact is a frozen dataclass - you can add new serializable fields and metadata specific to your own artifacts
    as needed, and use them in after_deserialization.

    Example implementation:

    class MyPrecompileCacheArtifact(PrecompileCacheArtifact[MySerializableType]):
        my_field: int

        def after_deserialization(self) -> MySerializableType:
            result = pickle.loads(self.content)
            # Do some extra work post deserialization
            result.my_post_deserialization_function(self.my_field)
            return result
    """

    @override
    def populate_cache(self) -> None:
        raise RuntimeError("Precompile cache artifacts do not populate caches")

    @override
    def precompile_compatible(self) -> bool:
        return True

    @abstractmethod
    def after_deserialization(self) -> T:
        """
        Code to be run after reading raw byte contents from disk.
        Generally converts self.content from raw bytes back into its original form.
        """
        ...


class PrecompileContext(CacheArtifactManager):
    """
    PrecompileContext is a special CacheArtifactManager for handling precompilation
    It uses the same interface as CacheArtifactManager, but handles deserialization differently: instead
    of placing each artifact into respective caches, it will stitch all the cache artifacts for a single key
    together and place it into a global Precompile Cache.

    The following artifact types are supported by PrecompileContext:
    - BundledAOTAutogradCacheArtifact
    - CodeStateArtifact (from torch._dynamo.package once available)
    """

    # Protected by the compile_lock
    # _new_cache_artifacts_by_key organizes results by the key of each artifact.
    # This allows us to implement serialize_by_key easily.
    # On call to `serialize()`, all cache artifacts in _new_cache_artifacts_by_key
    # are transferred to _new_cache_artifacts before serialization.
    _new_cache_artifacts_by_key: dict[str, CacheArtifact] = {}
    _new_cache_artifacts: CacheArtifactsResult = defaultdict(list)
    # Keep a seperate seen artifacts list to make avoid unnecessary duplicates
    # This list will not be cleared between serialize() calls
    _seen_artifacts: OrderedSet[CacheArtifact] = OrderedSet()
    # When serialize() is called, artifacts are transferred from _cache_artifacts to
    # internal data structure of the _serializer
    # This allows us to only pay the cost of serialization if serialize() is called
    _serializer: AppendingByteSerializer[tuple[str, list[CacheArtifact]]] = (
        AppendingByteSerializer(serialize_fn=_serialize_single_cache)
    )
    _cache_info: CacheInfo = CacheInfo()

    @classmethod
    def clear(cls) -> None:
        cls._new_cache_artifacts_by_key.clear()
        super().clear()

    @override
    @classmethod
    def record_artifact(
        cls,
        artifact_type: str,
        key: str,
        content: Any,
    ) -> None:
        """
        Called from each caching operation to record the artifact in this
        "mega" list
        """
        artifact = CacheArtifactFactory.encode_create(artifact_type, key, content)
        if artifact in cls._seen_artifacts:
            return
        cls._new_cache_artifacts_by_key[key] = artifact
        cls._seen_artifacts.add(artifact)

    @classmethod
    def _save_artifacts_by_type(cls) -> None:
        """
        We normally record artifacts by key, but serialization expects them to be organized
        by artifact type. This function transfers artifacts from _new_cache_artifacts_by_key to _new_cache_artifacts
        """
        for artifact in cls._new_cache_artifacts_by_key.values():
            cls._new_cache_artifacts[artifact.__class__.type()].append(artifact)
        cls._new_cache_artifacts_by_key.clear()

    @classmethod
    def serialize_artifact_by_key(cls, key: str) -> Optional[CacheArtifact]:
        """
        Serialize all artifacts with the given key returned in a list.
        """
        return cls._new_cache_artifacts_by_key.get(key, None)

    @classmethod
    def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]:
        cls._save_artifacts_by_type()
        return super().serialize()

    @staticmethod
    def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo:
        raise NotImplementedError("TODO")

    @classmethod
    def _ensure_cache_artifacts_registered(cls) -> None:
        from torch._functorch._aot_autograd.autograd_cache import (  # noqa: F401
            BundledAOTAutogradCacheArtifact,
        )
```
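As a concrete illustration of the extension point documented in `PrecompileCacheArtifact`'s docstring, the sketch below registers a hypothetical artifact type the same way the PR registers `EagerCacheArtifact` and `_DynamoCacheArtifact`. The `JsonBlobArtifact` name, its key, and its payload are invented for the example; it assumes `record_artifact` accepts raw bytes as content, as the artifacts in this PR do.

```python
import json
from typing import Any

from torch._dynamo.precompile_context import PrecompileCacheArtifact, PrecompileContext
from torch.compiler._cache import CacheArtifactFactory


@CacheArtifactFactory.register
class JsonBlobArtifact(PrecompileCacheArtifact[Any]):  # hypothetical artifact type
    @staticmethod
    def type() -> str:
        return "precompile_json_blob"  # must be unique among registered artifacts

    def after_deserialization(self) -> Any:
        # self.content holds the raw bytes that were passed to record_artifact.
        return json.loads(self.content)


# Record one artifact under a key of our choosing...
PrecompileContext.record_artifact(
    JsonBlobArtifact.type(), key="example_key", content=json.dumps({"a": 1}).encode()
)

# ...and fetch it back by key before serialize() moves it into the by-type table.
artifact = PrecompileContext.serialize_artifact_by_key("example_key")
assert artifact is not None
assert artifact.after_deserialization() == {"a": 1}
```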
Diff of `torch._functorch._aot_autograd.autograd_cache` (registers `BundledAOTAutogradCacheArtifact` and records it into `PrecompileContext` when `bundled_autograd_cache` is enabled):

```diff
@@ -21,6 +21,7 @@ from typing import Any, Callable, Generic, Optional, TYPE_CHECKING, TypeVar, Uni
 from typing_extensions import override
 
 import torch
+from torch._dynamo.precompile_context import PrecompileCacheArtifact, PrecompileContext
 from torch._dynamo.trace_rules import torch_non_c_binding_in_graph_functions
 from torch._dynamo.utils import (
     chromium_event_log_active,
@@ -916,6 +917,21 @@ class AOTAutogradCacheArtifact(CacheArtifact):
         return "aot_autograd"
 
 
+@CacheArtifactFactory.register
+class BundledAOTAutogradCacheArtifact(
+    PrecompileCacheArtifact[BundledAOTAutogradCacheEntry]
+):
+    @override
+    @staticmethod
+    def type():
+        return "precompile_aot_autograd"
+
+    @override
+    def after_deserialization(self) -> BundledAOTAutogradCacheEntry:
+        entry = pickle.loads(self.content)
+        return entry
+
+
 class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
     """
     Caches the results of running AOTAutograd. This class mostly handles the save and load logic, whereas
@@ -1167,6 +1183,10 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
                 CacheArtifactManager.record_artifact(
                     AOTAutogradCacheArtifact.type(), key, pickled_content
                 )
+                if config.bundled_autograd_cache:
+                    PrecompileContext.record_artifact(
+                        BundledAOTAutogradCacheArtifact.type(), key, pickled_content
+                    )
         except Exception as e:
             log.info("AOTAutograd cache unable to load compiled graph: %s", e)
             if config.strict_autograd_cache:
@@ -1196,6 +1216,11 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
             CacheArtifactManager.record_artifact(
                 AOTAutogradCacheArtifact.type(), key, content
            )
+            if config.bundled_autograd_cache:
+                # TODO: the key here isn't correct
+                PrecompileContext.record_artifact(
+                    BundledAOTAutogradCacheArtifact.type(), key, content
+                )
             AOTAutogradCache._write_to_local_cache(key, content)
             counters["aot_autograd"]["autograd_cache_saved"] += 1
         except BypassAOTAutogradCache as e:
```
Diff of `torch.compiler.load_cache_artifacts`:

```diff
@@ -473,4 +473,7 @@ def load_cache_artifacts(serialized_artifacts: bytes) -> Optional["CacheInfo"]:
     """
     from ._cache import CacheArtifactManager, CacheInfo
 
-    return CacheArtifactManager.deserialize(serialized_artifacts)
+    artifacts = CacheArtifactManager.deserialize(serialized_artifacts)
+    if artifacts is not None:
+        return CacheArtifactManager.populate_caches(artifacts)
+    return None
```
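For context, the MegaCache round trip that calls this function now has two explicit phases, deserialize and populate. A sketch of that round trip, assuming the companion `torch.compiler.save_cache_artifacts()` API as the producer of the bytes consumed here:

```python
import torch


@torch.compile
def f(x):
    return x.sin() + 1


f(torch.randn(8))  # warm the compile caches so there is something to save

# Producer side: flatten everything that was cached into portable bytes.
saved = torch.compiler.save_cache_artifacts()
assert saved is not None
artifact_bytes, cache_info = saved

# Consumer side (typically a fresh process): deserialize into CacheArtifacts,
# then populate the local caches, as the updated load_cache_artifacts does.
info = torch.compiler.load_cache_artifacts(artifact_bytes)
```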
Diff of `torch.compiler._cache`:

```diff
@@ -48,6 +48,9 @@ class CacheArtifact(ABC):
     def populate_cache(self) -> None:
         pass
 
+    def precompile_compatible(self) -> bool:
+        return False
+
     @staticmethod
     def type() -> str:
         """
@@ -128,6 +131,10 @@ class CacheInfo:
     def pgo_artifacts(self) -> list[str]:  # type: ignore[empty-body]
         ...
 
+    @property
+    def precompile_aot_autograd_artifacts(self) -> list[str]:  # type: ignore[empty-body]
+        ...
+
     def add(self, artifact: CacheArtifact) -> None:
         self.artifacts[artifact.type()].append(artifact.key)
 
@@ -159,6 +166,9 @@ def _deserialize_single_cache(
     return artifact_type_key, artifacts
 
 
+CacheArtifactsResult = dict[str, list[CacheArtifact]]
+
+
 class CacheArtifactManager:
     """
     Lightweight manager class for collecting and processing cache artifacts for
@@ -177,7 +187,7 @@ class CacheArtifactManager:
     """
 
     # Protected by the compile_lock
-    _new_cache_artifacts: defaultdict[str, list[CacheArtifact]] = defaultdict(list)
+    _new_cache_artifacts: CacheArtifactsResult = defaultdict(list)
     # Keep a seperate seen artifacts list to make avoid unnecessary duplicates
     # This list will not be cleared between serialize() calls
     _seen_artifacts: OrderedSet[CacheArtifact] = OrderedSet()
@@ -207,7 +217,7 @@
         cls._new_cache_artifacts = defaultdict(list)
         cls._seen_artifacts = OrderedSet()
         cls._serializer = AppendingByteSerializer(serialize_fn=_serialize_single_cache)
-        cls._cache_info = CacheInfo()
+        cls._cache_info = cls._cache_info.__class__()
         try:
             yield
         finally:
@@ -268,9 +278,9 @@
             return None
 
     @staticmethod
-    def deserialize(serialized_artifacts: bytes) -> Optional[CacheInfo]:
+    def deserialize(serialized_artifacts: bytes) -> Optional[CacheArtifactsResult]:
         """
-        Converts the portable format back into various filesystem caches
+        Converts the portable format back into CacheArtifacts
         """
         try:
             CacheArtifactManager._ensure_cache_artifacts_registered()
@@ -284,6 +294,10 @@
             log.warning("Failed to un-pickle cache artifacts", exc_info=True)
             return None
 
+        return artifacts
+
+    @staticmethod
+    def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo:
         info = CacheInfo()
         for artifact in chain(*artifacts.values()):
             log.debug("writing: %s", artifact)
@@ -292,8 +306,8 @@
         return info
 
-    @staticmethod
-    def _ensure_cache_artifacts_registered() -> None:
+    @classmethod
+    def _ensure_cache_artifacts_registered(cls) -> None:
         """When deserializing caches in fresh process, we need to ensure that all
         cache artifacts are registered in the cache registry. This is done by
         simply importing all the cache artifacts already wrapped with register call.
 
```