pytorch/test/dynamo/test_precompile_context.py
James Wu 3819584f12 [precompile] Implement PrecompileContext for recording precompile artifacts, integrate with CompilePackage (#154415)
This PR implements a basic interface and test for PrecompileContext, a special CacheArtifactManager designed specifically for precompile. The job of a PrecompileContext is to record the artifacts precompile needs as torch is compiling, dump them all into bytes, and then stitch them back together into a cache of callables.

## Why use CacheArtifactManager?
Precompile needs a way to record various serializable data as torch is compiling. CacheArtifactManager already does this well today, handling much of the serialization and cache bookkeeping, so we reuse that infrastructure directly.

## How is it different from CacheArtifactManager?
Unlike regular CacheArtifactManager, PrecompileContext needs to be able to take the recorded artifacts and stitch them together after deserialization, to create a single working callable.
Since PrecompileContext doesn't need the cache keys, the "key" field of PrecompileArtifacts can be used for metadata relating to how to stitch the individual functions being compiled together into a full callable. For example, on a given dynamo compile, if there are multiple functions (via graph breaks or recompiles) being compiled, MegaCache would organize it like so:

![image](https://github.com/user-attachments/assets/49a0a75b-1e7f-4d96-8d81-6769fe5a53ca)

Whereas we'd visualize PrecompileContext's result like so:

![image](https://github.com/user-attachments/assets/fcc0dd4e-dfbf-4b13-9c08-2e99b373180b)
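Concretely (and purely illustratively; the key strings below are made up), the difference in how the two systems key the same compile might look like this:

```python
# Illustrative only: how one dynamo compile that produces two compiled
# functions (e.g. via a graph break) might be keyed in each system.

# MegaCache: artifacts grouped by cache type, keyed by opaque cache keys
megacache_layout = {
    "aot_autograd": {"cache_key_abc": b"..."},
    "inductor": {"cache_key_def": b"..."},
}

# PrecompileContext: the "key" field instead carries stitching metadata,
# e.g. which compiled function (backend id) an artifact belongs to, so the
# artifacts can be reassembled into one callable after deserialization
precompile_layout = {
    "backend_id_0": b"...",  # first compiled function
    "backend_id_1": b"...",  # function compiled after the graph break
}
```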

For now, we just handle eager mode; in the diff above, I'll hook up the other backend artifacts from PrecompileContext.

After this PR, precompile consists of three main interfaces:

### CompilePackage
- Everything needed to run one torch.compile'd function (including graph breaks); usage is sketched below
- `__init__(fn, cache_entry)`: initialize with a function and a DynamoCacheEntry
- `install(backends)`: load precompile artifacts into the function's dynamo state, given a dictionary of backends
- `cache_entry()`: return a serializable cache entry to save
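
A minimal usage sketch, assuming the signatures in the bullets above (the import path, the `cache_entry=None` default, and the shape of `backends` are not spelled out in this PR, so treat them as illustrative):

```python
# Hypothetical sketch of the CompilePackage interface described above.
import torch
from torch._dynamo.package import CompilePackage  # assumed import path


def fn(x):
    return x.sin() + x.cos()


# Fresh compile: assume a package can start from an empty cache entry
package = CompilePackage(fn, cache_entry=None)
torch.compile(fn)(torch.randn(10))

# Grab a serializable entry describing everything needed to rerun fn
# (including any graph breaks hit during compilation)
entry = package.cache_entry()

# In a later process, restore dynamo state from that entry, supplying
# the compiled backends as a dictionary keyed by backend id
restored = CompilePackage(fn, entry)
restored.install({})  # real usage passes {backend_id: compiled_backend}
```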

### DynamoStore
- Responsible for tracking CompilePackages on disk (and/or in memory); usage is sketched below
- `load_package(fn, path)`: load a package, given a torch compiled function and a path to the cache artifact
- `save_package(package, path)`: save a CompilePackage to a path; calls PrecompileContext to grab backend data
- `record_package(package)`: Record a package to PrecompileContext (for global serialization/deserialization)
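
Sketched usage, again taking the signatures from the bullets above (`package` and `fn` carry over from the CompilePackage sketch; the import path and the on-disk path are illustrative):

```python
# Hypothetical sketch of the DynamoStore interface described above.
from torch._dynamo.package import DynamoStore  # assumed import path

store = DynamoStore()

# Persist to disk: asks PrecompileContext for the backend data recorded
# during compilation and writes it alongside the dynamo cache entry
store.save_package(package, "/tmp/fn_package")

# Or just register the package for global PrecompileContext serialization
store.record_package(package)

# Later, rebuild the package for the same function from the saved artifact
loaded = store.load_package(fn, "/tmp/fn_package")
```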

### PrecompileContext
- Overarching context for serializing and deserializing precompile artifacts. Supports **global** and **local** setups; both are sketched below.
- `serialize()`: (Global) serializes all artifacts in PrecompileContext into bytes
- `populate_caches(bytes)`: (Global) takes serialized bytes and puts them into DynamoStore (TODO)
- `serialize_artifact_by_key(key)`: (Local) serialize a single artifact by its cache key
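
End to end, the global and local paths look roughly like this (mirroring the test below; `populate_caches` is still a TODO, so it only appears in a comment):

```python
# Sketch of the PrecompileContext entry points exercised by the test below.
from torch._dynamo.precompile_context import PrecompileContext

# Global: dump every recorded artifact into bytes...
result = PrecompileContext.serialize()
if result is not None:
    serialized_bytes, cache_info = result
    # ...and rehydrate them elsewhere; populate_caches(serialized_bytes)
    # will eventually hand the results to DynamoStore (TODO above)
    artifacts = PrecompileContext.deserialize(serialized_bytes)

# Local: serialize a single artifact by its cache key
for key in PrecompileContext._new_cache_artifacts_by_key:
    artifact = PrecompileContext.serialize_artifact_by_key(key)
```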

<img width="1455" alt="image" src="https://github.com/user-attachments/assets/99b61330-7607-4763-bdbc-85b366e82cdd" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154415
Approved by: https://github.com/zhxchen17
ghstack dependencies: #155118
2025-06-13 14:11:24 +00:00


# Owner(s): ["module: dynamo"]
import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._functorch
from torch._dynamo.precompile_context import PrecompileContext
from torch._functorch import config as functorch_config
from torch._functorch._aot_autograd.autograd_cache import (
    BundledAOTAutogradCacheArtifact,
    BundledAOTAutogradCacheEntry,
)
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal.inductor_utils import requires_triton


@functorch_config.patch({"enable_autograd_cache": True})
@functorch_config.patch(
    {"bundled_autograd_cache": True}
)  # Requires bundled AOTAutograd cache for now
class PrecompileContextTests(InductorTestCase):
    def setUp(self):
        """
        Reset all counters and caches before each unit test
        """
        super().setUp()
        # Clear PrecompileContext cache artifacts
        PrecompileContext.clear()
    @requires_triton()
    def test_basic(self):
        """
        Test that after torch.compile runs, PrecompileContext records exactly
        one new cache artifact, keyed by its cache key
        """

        def simple_function(x):
            return x.sin() + x.cos()

        compiled_fn = torch.compile(simple_function)

        # Run the compiled function
        x = torch.randn(10, device="cuda", requires_grad=True)
        result = compiled_fn(x)
        result.sum().backward()

        # Check that PrecompileContext._new_cache_artifacts_by_key has length 1
        self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 1)
        self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0)
        result = PrecompileContext.serialize()
        assert result is not None
        serialized, cache_info = result
        self.assertEqual(len(cache_info.precompile_aot_autograd_artifacts), 1)

        artifacts = PrecompileContext.deserialize(serialized)
        assert artifacts is not None
        deserialized = artifacts["precompile_aot_autograd"]
        assert len(deserialized) == 1
        entry = deserialized[0]
        assert isinstance(entry, BundledAOTAutogradCacheArtifact)
        entry = entry.after_deserialization()
        assert isinstance(
            entry,
            BundledAOTAutogradCacheEntry,
        )

        # Now that we've serialized, there should be no new cache artifacts
        self.assertEqual(
            len(PrecompileContext._new_cache_artifacts["precompile_aot_autograd"]), 0
        )
    @requires_triton()
    def test_serialize_by_key(self):
        """
        Test that a single artifact can be serialized on its own via
        PrecompileContext.serialize_artifact_by_key
        """

        def simple_function(x):
            return x.sin() + x.cos()

        compiled_fn = torch.compile(simple_function)

        # Run the compiled function
        x = torch.randn(10, device="cuda", requires_grad=True)
        result = compiled_fn(x)
        result.sum().backward()

        # Check that PrecompileContext._new_cache_artifacts_by_key has length 1
        # TODO: the key right now is the AOTAutogradCacheKey, but will be backend_id
        # once we have torch._dynamo.package implemented
        self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 1)
        key = next(iter(PrecompileContext._new_cache_artifacts_by_key.keys()))
        result = PrecompileContext.serialize_artifact_by_key(key)
        assert isinstance(result, BundledAOTAutogradCacheArtifact)
        self.assertEqual(result.key, key)

        self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0)
        result = PrecompileContext.serialize()
        assert result is not None
        _, cache_info = result
        self.assertEqual(len(cache_info.precompile_aot_autograd_artifacts), 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()