pytorch/test/dynamo/test_package.py
James Wu bfe9e60ffb Simplify PrecompileContext to no longer be a CacheArtifactManager (#162886)
Summary:
This diff does a big refactor of PrecompileContext to make it considerably simpler: instead of being a CacheArtifactManager that manages a bunch of raw bytes, it simply stores two things, dynamo cache entries and backend cache entries. When asked, it stitches them together into PrecompileCacheEntries, which are stored by DynamoCache.

This structure then allows us to register DynamoCache with the regular Megacache API, instead of having two separate, confusing APIs. It also lets us remove the autotune cache integration, since the Megacache API automatically stores autotune cache entries.

The intent is that users who want caching precompile can simply call torch.compiler.save_cache_artifacts as before, just with `torch._dynamo.config.caching_precompile` set to True. They can also interact with PrecompileContext directly, via PrecompileContext.create_cache_entries(), if they specifically want to load only precompile entries.

Saving individual entries directly with DynamoCache still works as before.
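
A minimal sketch of the intended flow (`fn` and `x` are placeholders; this assumes the APIs described above and is illustrative rather than exact):

    import torch

    torch._dynamo.config.caching_precompile = True

    compiled_fn = torch.compile(fn)
    compiled_fn(x)  # warms up: records dynamo + backend entries in PrecompileContext

    # Same Megacache entry point as before; with the flag on, the returned bytes
    # also cover the precompile entries.
    artifacts = torch.compiler.save_cache_artifacts()

    # In a new process, load the bytes back before compiling again:
    # torch.compiler.load_cache_artifacts(artifact_bytes)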

Test Plan:
All existing unit tests pass.

Rollback Plan:

Differential Revision: D82380307

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162886
Approved by: https://github.com/zhxchen17
2025-09-20 01:24:37 +00:00

# Owner(s): ["module: dynamo"]
import importlib
import os
import sys
import tempfile
import unittest
import torch
import torch._dynamo.testing
import torch._inductor.config
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils.cpp_extension
from torch._dynamo.package import CompilePackage, DiskDynamoStore, DynamoCache
from torch._dynamo.precompile_context import PrecompileContext
from torch._dynamo.testing import reduce_to_scalar_loss
from torch._functorch import config as functorch_config
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
from torch.testing._internal.inductor_utils import (
HAS_CUDA_AND_TRITON,
HAS_XPU_AND_TRITON,
)
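# Module-level helper used by test_call_function_from_resume below; it just reduces
# a tensor to a scalar loss via torch._dynamo.testing.reduce_to_scalar_loss.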
def compute_loss_helper(x):
return reduce_to_scalar_loss(x)
@functorch_config.patch("bundled_autograd_cache", True)
@torch._dynamo.config.patch({"strict_precompile": True})
@instantiate_parametrized_tests
class TestPackage(torch._inductor.test_case.TestCase):
def path(self):
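        # Per-test directory under the inductor cache dir, used as the on-disk package store.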
path = os.path.join(cache_dir(), f"package_{self.id()}")
os.makedirs(path, exist_ok=True)
return path
def setUp(self):
super().setUp()
torch._dynamo.reset()
torch._dynamo.utils.counters.clear()
DynamoCache.clear()
PrecompileContext.clear()
def _save_and_reload(self, expected_backends, expected_dynamo):
"""
Serializes all artifacts, clears all caches, then reloads the serialized artifact
Simulates a new process.
Args:
expected_backends: Expected number of precompile_aot_autograd_artifacts
expected_dynamo: Expected number of precompile_dynamo_artifacts
"""
debug_info = PrecompileContext.save_to_dynamo_cache()
self.assertEqual(len(debug_info["dynamo"]), expected_dynamo)
self.assertEqual(len(debug_info["backends"]), expected_backends)
torch._dynamo.reset()
PrecompileContext.clear()
@unittest.expectedFailure # FUNCTION_MATCH guard not serializable today
def test_nn_module(self):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10, device="cuda")
def forward(self, x):
return self.linear(x)
fn = MyModule()
package = CompilePackage(fn.forward)
compiled_fn = torch._dynamo.optimize("inductor", package=package)(fn)
x = torch.randn(10, 10, device="cuda")
compiled_fn(x)
@parametrize("backend", ("eager", "inductor"))
@parametrize("device", ("cpu", "cuda", "xpu"))
def test_basic_fn(self, backend, device):
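        # Round-trips a trivial function through CompilePackage + DiskDynamoStore and
        # checks that the reloaded package serves the call without recompiling.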
if device == "cuda" and not HAS_CUDA_AND_TRITON:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU_AND_TRITON:
raise unittest.SkipTest("Requires XPU/Triton")
ctx = DiskDynamoStore()
def fn(x):
return x + 1
args = (
torch.randn(
3,
2,
device=device,
),
)
# Saving
package = CompilePackage(fn)
compiled_fn = torch._dynamo.optimize(backend, package=package)(fn)
expected = compiled_fn(*args)
if backend == "eager":
for backend_id, backend in package.cached_backends.items():
ctx.record_eager_backend(backend_id, backend)
ctx.save_package(package, self.path())
# Loading
torch._dynamo.reset()
with torch.compiler.set_stance("fail_on_recompile"):
with self.assertRaisesRegex(
RuntimeError,
"Detected recompile when torch.compile stance is 'fail_on_recompile'",
):
compiled_fn(*args)
package, backends = ctx.load_package(fn, self.path())
compiled_fn = torch._dynamo.optimize(package=package)(fn)
package.install(backends)
self.assertEqual(expected, compiled_fn(*args))
@parametrize("backend", ("eager", "inductor"))
@parametrize("device", ("cpu", "cuda", "xpu"))
def test_lazy_backward(self, backend, device):
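        # Same round trip as test_basic_fn, but a backward pass runs before saving so
        # the lazily compiled backward is triggered before serialization.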
if device == "cuda" and not HAS_CUDA_AND_TRITON:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU_AND_TRITON:
raise unittest.SkipTest("Requires XPU/Triton")
ctx = DiskDynamoStore()
def fn(x):
return x.sin() + x.cos()
args = (
torch.zeros(
3,
2,
device=device,
requires_grad=True,
),
)
# Saving
package = CompilePackage(fn)
compiled_fn = torch._dynamo.optimize(backend, package=package)(fn)
expected = compiled_fn(*args)
expected.sum().backward()
if backend == "eager":
for backend_id, backend in package.cached_backends.items():
ctx.record_eager_backend(backend_id, backend)
ctx.save_package(package, self.path())
# Loading
torch._dynamo.reset()
with torch.compiler.set_stance("fail_on_recompile"):
with self.assertRaisesRegex(
RuntimeError,
"Detected recompile when torch.compile stance is 'fail_on_recompile'",
):
compiled_fn(*args)
package, backends = ctx.load_package(fn, self.path())
compiled_fn = torch._dynamo.optimize(package=package)(fn)
package.install(backends)
self.assertEqual(expected, compiled_fn(*args))
@parametrize("backend", ("eager", "inductor"))
@parametrize("device", ("cpu", "cuda", "xpu"))
def test_graph_break_bomb(self, backend, device):
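        # A recursive function that specializes on x.sum() produces many frames; every
        # specialization must round-trip, and an unseen input must still trip
        # fail_on_recompile after loading.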
if device == "cuda" and not HAS_CUDA_AND_TRITON:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU_AND_TRITON:
raise unittest.SkipTest("Requires XPU/Triton")
ctx = DiskDynamoStore()
def fn(x, l, r):
if l > r:
return x.sum()
mid = (l + r) // 2
if x.sum() == mid:
return x.sum()
elif x.sum() < mid:
return fn(x, l, mid)
else:
return fn(x, mid + 1, r)
def guard_filter_fn(guards):
return [
guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH")
for guard in guards
]
# Saving
package = CompilePackage(fn)
compiled_fn = torch._dynamo.optimize(
backend=backend, package=package, guard_filter_fn=guard_filter_fn
)(fn)
N = 10
args_list = [(torch.tensor(x, device=device), 0, N - 1) for x in range(N)]
for args in args_list:
compiled_fn(*args)
if backend == "eager":
for backend_id, backend in package.cached_backends.items():
ctx.record_eager_backend(backend_id, backend)
ctx.save_package(package, self.path())
# Loading
torch._dynamo.reset()
with torch.compiler.set_stance("fail_on_recompile"):
for args in args_list:
with self.assertRaisesRegex(
RuntimeError,
"Detected recompile when torch.compile stance is 'fail_on_recompile'",
):
compiled_fn(*args)
package, backends = ctx.load_package(fn, self.path())
compiled_fn = torch._dynamo.optimize(
backend="eager", package=package, guard_filter_fn=guard_filter_fn
)(fn)
package.install(backends)
for args in args_list:
self.assertEqual(compiled_fn(*args), args[0].sum())
with self.assertRaisesRegex(
RuntimeError,
"Detected recompile when torch.compile stance is 'fail_on_recompile'",
):
compiled_fn(torch.tensor(N), 0, N - 1)
@parametrize("backend", ("eager", "inductor"))
@parametrize("device", ("cpu", "cuda", "xpu"))
def test_dynamic_shape(self, backend, device):
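        # Dim 0 is marked dynamic with min=3/max=5: after reload, a size-5 input hits
        # the saved entry while a size-7 input trips fail_on_recompile.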
if device == "cuda" and not HAS_CUDA_AND_TRITON:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU_AND_TRITON:
raise unittest.SkipTest("Requires XPU/Triton")
ctx = DiskDynamoStore()
def fn(x):
return x + x.shape[0]
args = (torch.randn(3, 2, device=device),)
args1 = (torch.randn(5, 2, device=device),)
args2 = (torch.randn(7, 2, device=device),)
expected1 = fn(*args1)
torch._dynamo.mark_dynamic(args[0], 0, min=3, max=5)
# Saving
package = CompilePackage(fn)
compiled_fn = torch._dynamo.optimize(backend=backend, package=package)(fn)
compiled_fn(*args)
if backend == "eager":
for backend_id, backend in package.cached_backends.items():
ctx.record_eager_backend(backend_id, backend)
ctx.save_package(package, self.path())
# Loading
torch._dynamo.reset()
with torch.compiler.set_stance("fail_on_recompile"):
with self.assertRaisesRegex(
RuntimeError,
"Detected recompile when torch.compile stance is 'fail_on_recompile'",
):
compiled_fn(*args1)
package, backends = ctx.load_package(fn, self.path())
compiled_fn = torch._dynamo.optimize(package=package)(fn)
package.install(backends)
self.assertEqual(expected1, compiled_fn(*args1))
with self.assertRaisesRegex(
RuntimeError,
"Detected recompile when torch.compile stance is 'fail_on_recompile'",
):
compiled_fn(*args2)
def test_file_change(self):
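        # Loading must fail with "Source code changes detected" when an inlined module's
        # source changes on disk, and succeed again once the original source is restored.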
ctx = DiskDynamoStore()
def import_from_path(module_name, file_path):
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
mock_module_add_original = """
def add(x, y):
return x + y
"""
mock_module_add_modified = """
def add(x, y):
return x - y
"""
with tempfile.TemporaryDirectory() as tmp_dir:
mock_module_add_original_path = os.path.join(
tmp_dir, "mock_module_add_original.py"
)
mock_module_add_modified_path = os.path.join(
tmp_dir, "mock_module_add_modified.py"
)
with open(mock_module_add_original_path, "w") as f:
f.write(mock_module_add_original)
with open(mock_module_add_modified_path, "w") as f:
f.write(mock_module_add_modified)
module = import_from_path(
"torch.test_package_helper",
mock_module_add_original_path,
)
def fn(x):
return module.add(x, 1)
args = (torch.randn(3, 2),)
def guard_filter_fn(guards):
return [
guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH")
for guard in guards
]
# Saving
package = CompilePackage(fn)
compiled_fn = torch._dynamo.optimize(
backend="eager", package=package, guard_filter_fn=guard_filter_fn
)(fn)
compiled_fn(*args)
for backend_id, backend in package.cached_backends.items():
ctx.record_eager_backend(backend_id, backend)
ctx.save_package(package, self.path())
module = import_from_path(
"torch.test_package_helper",
mock_module_add_modified_path,
)
with self.assertRaisesRegex(RuntimeError, "Source code changes detected"):
ctx.load_package(fn, self.path())
module = import_from_path(
"torch.test_package_helper",
mock_module_add_original_path,
)
ctx.load_package(fn, self.path())
@parametrize("device", ("cpu", "cuda", "xpu"))
def test_dynamo_cache_manual_load(self, device):
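        # Explicitly saves two packages with DynamoCache, simulates a new process, then
        # reloads them via load_and_install_package with no recompiled frames.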
if device == "cuda" and not HAS_CUDA_AND_TRITON:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU_AND_TRITON:
raise unittest.SkipTest("Requires XPU/Triton")
def fn(x):
return x.sin() + x.cos()
def fn2(x):
return x.cos() + x
package1 = CompilePackage(fn)
package2 = CompilePackage(fn2)
compiled_fn1 = torch._dynamo.optimize(backend="inductor", package=package1)(fn)
compiled_fn2 = torch._dynamo.optimize(backend="inductor", package=package2)(fn2)
arg1 = torch.randn(3, 2, device=device)
arg2 = torch.randn(5, 2, device=device)
expected = [compiled_fn1(arg1), compiled_fn2(arg2)]
DynamoCache.save(package1)
DynamoCache.save(package2)
total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
self._save_and_reload(expected_backends=2, expected_dynamo=2)
        # These entries should exist on disk because they were saved to DynamoCache above
package1 = DynamoCache.load_and_install_package(fn)
package2 = DynamoCache.load_and_install_package(fn2)
with torch.compiler.set_stance("fail_on_recompile"):
result1 = compiled_fn1(arg1)
result2 = compiled_fn2(arg2)
self.assertEqual(expected, [result1, result2])
self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
@parametrize("device", ("cpu", "cuda", "xpu"))
@torch._dynamo.config.patch(caching_precompile=True)
def test_automatic_dynamo_serialize(self, device):
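        # With caching_precompile=True, plain torch.compile populates PrecompileContext
        # automatically; after save-and-reload the same inputs run with zero recompiles.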
if device == "cuda" and not HAS_CUDA_AND_TRITON:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU_AND_TRITON:
raise unittest.SkipTest("Requires XPU/Triton")
def fn(x):
return x.sin() + x.cos()
def fn2(x):
return x.cos() + x
arg1 = torch.randn(3, 2, device=device)
arg2 = torch.randn(5, 2, device=device)
expected = [fn(arg1), fn2(arg2)]
compiled_fn1 = torch.compile(fn)
compiled_fn2 = torch.compile(fn2)
result = [compiled_fn1(arg1), compiled_fn2(arg2)]
self.assertEqual(expected, result)
DynamoCache.clear()
total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
self._save_and_reload(expected_backends=2, expected_dynamo=2)
compiled_fn1 = torch.compile(fn)
compiled_fn2 = torch.compile(fn2)
with torch.compiler.set_stance("fail_on_recompile"):
result1 = compiled_fn1(arg1)
result2 = compiled_fn2(arg2)
self.assertEqual(expected, [result1, result2])
self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
@parametrize("device", ("cpu", "cuda", "xpu"))
@torch._dynamo.config.patch(caching_precompile=True)
def test_automatic_dynamo_recompiles(self, device):
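        # The second input shape triggers an automatic-dynamic recompile before saving;
        # after reload, both shapes (and a third, unseen one) run without recompiling.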
if device == "cuda" and not HAS_CUDA_AND_TRITON:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU_AND_TRITON:
raise unittest.SkipTest("Requires XPU/Triton")
def fn(x):
return x.sin() + x.cos()
arg1 = torch.randn(3, 2, device=device)
arg2 = torch.randn(5, 2, device=device)
compiled_fn = torch.compile(fn)
expected1 = compiled_fn(arg1)
# Should cause a recompile
expected2 = compiled_fn(arg2)
total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
self._save_and_reload(expected_backends=2, expected_dynamo=1)
compiled_fn = torch.compile(fn)
with torch.compiler.set_stance("fail_on_recompile"):
result1 = compiled_fn(arg1)
result2 = compiled_fn(arg2)
# Because of automatic dynamic, a third random shape should also not cause a recompile
arg3 = torch.randn(7, 2, device=device)
compiled_fn(arg3)
self.assertEqual(result1, expected1)
self.assertEqual(result2, expected2)
self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
@parametrize("device", ("cpu", "cuda", "xpu"))
@torch._dynamo.config.patch(caching_precompile=True)
def test_automatic_dynamo_graph_breaks(self, device):
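        # Same recursive, graph-break-heavy function as test_graph_break_bomb, but driven
        # through caching_precompile and the automatic save/reload path.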
if device == "cuda" and not HAS_CUDA_AND_TRITON:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU_AND_TRITON:
raise unittest.SkipTest("Requires XPU/Triton")
def fn(x, l, r):
if l > r:
return x.sum()
mid = (l + r) // 2
if x.sum() == mid:
return x.sum()
elif x.sum() < mid:
return fn(x, l, mid)
else:
return fn(x, mid + 1, r)
def guard_filter_fn(guards):
return [
guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH")
for guard in guards
]
# Saving
compiled_fn = torch._dynamo.optimize(
backend="inductor", guard_filter_fn=guard_filter_fn
)(fn)
N = 10
args_list = [(torch.tensor(x, device=device), 0, N - 1) for x in range(N)]
for args in args_list:
compiled_fn(*args)
total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
self._save_and_reload(expected_backends=8, expected_dynamo=1)
compiled_fn = torch._dynamo.optimize(
backend="inductor", guard_filter_fn=guard_filter_fn
)(fn)
with torch.compiler.set_stance("fail_on_recompile"):
for args in args_list:
self.assertEqual(compiled_fn(*args), args[0].sum())
# Should have same number of frames as on cold start
self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
@parametrize("device", ("cpu", "cuda", "xpu"))
@torch._dynamo.config.patch(caching_precompile=True)
def test_automatic_dynamo_lazy_backward(self, device):
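        # Forward and backward run before saving; after reload, a fresh leaf tensor must
        # go through forward and backward without any recompiled frames.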
if device == "cuda" and not HAS_CUDA_AND_TRITON:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU_AND_TRITON:
raise unittest.SkipTest("Requires XPU/Triton")
def fn(x):
return x.sin() + x.cos()
arg1 = torch.randn(3, 2, device=device, requires_grad=True)
arg2 = arg1.clone().detach_().requires_grad_(True)
compiled_fn = torch.compile(fn)
expected1 = compiled_fn(arg1)
expected1.sum().backward()
total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
self._save_and_reload(expected_backends=1, expected_dynamo=1)
compiled_fn = torch.compile(fn)
# Run it again, no recompile needed
with torch.compiler.set_stance("fail_on_recompile"):
expected2 = compiled_fn(arg2)
expected2.sum().backward()
self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
@parametrize("device", ("cpu", "cuda", "xpu"))
@torch._dynamo.config.patch(caching_precompile=True)
def test_call_function_from_resume(self, device):
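        # The compiled function calls the module-level compute_loss_helper and then
        # .backward(); the resulting package must reload with no recompiled frames.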
if device == "cuda" and not HAS_CUDA_AND_TRITON:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU_AND_TRITON:
raise unittest.SkipTest("Requires XPU/Triton")
mod = torch.nn.Linear(2, 3, device=device)
def foo(x, mod):
pred = mod(x)
compute_loss_helper(pred).backward()
return None
args = (torch.randn(3, 2, device=device), mod)
compiled_fn = torch.compile(foo)
compiled_fn(*args)
total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
self._save_and_reload(expected_backends=1, expected_dynamo=1)
compiled_fn = torch.compile(foo)
# Run it again, no recompile needed
with torch.compiler.set_stance("fail_on_recompile"):
compiled_fn(*args)
self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
@parametrize("device", ("cpu", "cuda", "xpu"))
@torch._dynamo.config.patch(caching_precompile=True)
def test_code_with_generator(self, device):
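        # The compiled function contains a generator expression (inside all(...)); the
        # resulting package must serialize and reload cleanly.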
if device == "cuda" and not HAS_CUDA_AND_TRITON:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU_AND_TRITON:
raise unittest.SkipTest("Requires XPU/Triton")
def foo(set_of_x):
if not all(isinstance(s, torch.Tensor) for s in set_of_x):
raise TypeError(
f"Expected all elements of set_of_x to be tensors, got {set_of_x}"
)
return torch.cat(set_of_x, dim=0)
args = ([torch.randn(3, 2, device=device) for _ in range(3)],)
compiled_fn = torch.compile(foo)
compiled_fn(*args)
self._save_and_reload(expected_backends=1, expected_dynamo=1)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()