# Owner(s): ["module: dynamo"] import importlib import os import sys import tempfile import unittest import torch import torch._dynamo.testing import torch._inductor.config import torch._inductor.test_case import torch.onnx.operators import torch.utils.cpp_extension from torch._dynamo.package import CompilePackage, DiskDynamoStore, DynamoCache from torch._dynamo.precompile_context import PrecompileContext from torch._dynamo.testing import reduce_to_scalar_loss from torch._functorch import config as functorch_config from torch._inductor.runtime.runtime_utils import cache_dir from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, ) from torch.testing._internal.inductor_utils import ( HAS_CUDA_AND_TRITON, HAS_XPU_AND_TRITON, ) def compute_loss_helper(x): return reduce_to_scalar_loss(x) @functorch_config.patch("bundled_autograd_cache", True) @torch._dynamo.config.patch({"strict_precompile": True}) @instantiate_parametrized_tests class TestPackage(torch._inductor.test_case.TestCase): def path(self): path = os.path.join(cache_dir(), f"package_{self.id()}") os.makedirs(path, exist_ok=True) return path def setUp(self): super().setUp() torch._dynamo.reset() torch._dynamo.utils.counters.clear() DynamoCache.clear() PrecompileContext.clear() def _save_and_reload(self, expected_backends, expected_dynamo): """ Serializes all artifacts, clears all caches, then reloads the serialized artifact Simulates a new process. Args: expected_backends: Expected number of precompile_aot_autograd_artifacts expected_dynamo: Expected number of precompile_dynamo_artifacts """ debug_info = PrecompileContext.save_to_dynamo_cache() self.assertEqual(len(debug_info["dynamo"]), expected_dynamo) self.assertEqual(len(debug_info["backends"]), expected_backends) torch._dynamo.reset() PrecompileContext.clear() @unittest.expectedFailure # FUNCTION_MATCH guard not serializable today def test_nn_module(self): class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(10, 10, device="cuda") def forward(self, x): return self.linear(x) fn = MyModule() package = CompilePackage(fn.forward) compiled_fn = torch._dynamo.optimize("inductor", package=package)(fn) x = torch.randn(10, 10, device="cuda") compiled_fn(x) @parametrize("backend", ("eager", "inductor")) @parametrize("device", ("cpu", "cuda", "xpu")) def test_basic_fn(self, backend, device): if device == "cuda" and not HAS_CUDA_AND_TRITON: raise unittest.SkipTest("Requires CUDA/Triton") if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") ctx = DiskDynamoStore() def fn(x): return x + 1 args = ( torch.randn( 3, 2, device=device, ), ) # Saving package = CompilePackage(fn) compiled_fn = torch._dynamo.optimize(backend, package=package)(fn) expected = compiled_fn(*args) if backend == "eager": for backend_id, backend in package.cached_backends.items(): ctx.record_eager_backend(backend_id, backend) ctx.save_package(package, self.path()) # Loading torch._dynamo.reset() with torch.compiler.set_stance("fail_on_recompile"): with self.assertRaisesRegex( RuntimeError, "Detected recompile when torch.compile stance is 'fail_on_recompile'", ): compiled_fn(*args) package, backends = ctx.load_package(fn, self.path()) compiled_fn = torch._dynamo.optimize(package=package)(fn) package.install(backends) self.assertEqual(expected, compiled_fn(*args)) @parametrize("backend", ("eager", "inductor")) @parametrize("device", ("cpu", "cuda", "xpu")) def test_lazy_backward(self, backend, 
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        ctx = DiskDynamoStore()

        def fn(x):
            return x.sin() + x.cos()

        args = (
            torch.zeros(
                3,
                2,
                device=device,
                requires_grad=True,
            ),
        )

        # Saving
        package = CompilePackage(fn)
        compiled_fn = torch._dynamo.optimize(backend, package=package)(fn)
        expected = compiled_fn(*args)
        expected.sum().backward()
        if backend == "eager":
            for backend_id, backend in package.cached_backends.items():
                ctx.record_eager_backend(backend_id, backend)
        ctx.save_package(package, self.path())

        # Loading
        torch._dynamo.reset()
        with torch.compiler.set_stance("fail_on_recompile"):
            with self.assertRaisesRegex(
                RuntimeError,
                "Detected recompile when torch.compile stance is 'fail_on_recompile'",
            ):
                compiled_fn(*args)
            package, backends = ctx.load_package(fn, self.path())
            compiled_fn = torch._dynamo.optimize(package=package)(fn)
            package.install(backends)
            self.assertEqual(expected, compiled_fn(*args))

    @parametrize("backend", ("eager", "inductor"))
    @parametrize("device", ("cpu", "cuda", "xpu"))
    def test_graph_break_bomb(self, backend, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        ctx = DiskDynamoStore()

        def fn(x, l, r):
            if l > r:
                return x.sum()
            mid = (l + r) // 2
            if x.sum() == mid:
                return x.sum()
            elif x.sum() < mid:
                return fn(x, l, mid)
            else:
                return fn(x, mid + 1, r)

        def guard_filter_fn(guards):
            return [
                guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH")
                for guard in guards
            ]

        # Saving
        package = CompilePackage(fn)
        compiled_fn = torch._dynamo.optimize(
            backend=backend, package=package, guard_filter_fn=guard_filter_fn
        )(fn)
        N = 10
        args_list = [(torch.tensor(x, device=device), 0, N - 1) for x in range(N)]
        for args in args_list:
            compiled_fn(*args)
        if backend == "eager":
            for backend_id, backend in package.cached_backends.items():
                ctx.record_eager_backend(backend_id, backend)
        ctx.save_package(package, self.path())

        # Loading
        torch._dynamo.reset()
        with torch.compiler.set_stance("fail_on_recompile"):
            for args in args_list:
                with self.assertRaisesRegex(
                    RuntimeError,
                    "Detected recompile when torch.compile stance is 'fail_on_recompile'",
                ):
                    compiled_fn(*args)
            package, backends = ctx.load_package(fn, self.path())
            compiled_fn = torch._dynamo.optimize(
                backend="eager", package=package, guard_filter_fn=guard_filter_fn
            )(fn)
            package.install(backends)
            for args in args_list:
                self.assertEqual(compiled_fn(*args), args[0].sum())

            with self.assertRaisesRegex(
                RuntimeError,
                "Detected recompile when torch.compile stance is 'fail_on_recompile'",
            ):
                compiled_fn(torch.tensor(N), 0, N - 1)

    @parametrize("backend", ("eager", "inductor"))
    @parametrize("device", ("cpu", "cuda", "xpu"))
    def test_dynamic_shape(self, backend, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        ctx = DiskDynamoStore()

        def fn(x):
            return x + x.shape[0]

        args = (torch.randn(3, 2, device=device),)
        args1 = (torch.randn(5, 2, device=device),)
        args2 = (torch.randn(7, 2, device=device),)
        expected1 = fn(*args1)
        torch._dynamo.mark_dynamic(args[0], 0, min=3, max=5)

        # Saving
        package = CompilePackage(fn)
        compiled_fn = torch._dynamo.optimize(backend=backend, package=package)(fn)
        compiled_fn(*args)
        if backend == "eager":
== "eager": for backend_id, backend in package.cached_backends.items(): ctx.record_eager_backend(backend_id, backend) ctx.save_package(package, self.path()) # Loading torch._dynamo.reset() with torch.compiler.set_stance("fail_on_recompile"): with self.assertRaisesRegex( RuntimeError, "Detected recompile when torch.compile stance is 'fail_on_recompile'", ): compiled_fn(*args1) package, backends = ctx.load_package(fn, self.path()) compiled_fn = torch._dynamo.optimize(package=package)(fn) package.install(backends) self.assertEqual(expected1, compiled_fn(*args1)) with self.assertRaisesRegex( RuntimeError, "Detected recompile when torch.compile stance is 'fail_on_recompile'", ): compiled_fn(*args2) def test_file_change(self): ctx = DiskDynamoStore() def import_from_path(module_name, file_path): spec = importlib.util.spec_from_file_location(module_name, file_path) module = importlib.util.module_from_spec(spec) sys.modules[module_name] = module spec.loader.exec_module(module) return module mock_module_add_original = """ def add(x, y): return x + y """ mock_module_add_modified = """ def add(x, y): return x - y """ with tempfile.TemporaryDirectory() as tmp_dir: mock_module_add_original_path = os.path.join( tmp_dir, "mock_module_add_original.py" ) mock_module_add_modified_path = os.path.join( tmp_dir, "mock_module_add_modified.py" ) with open(mock_module_add_original_path, "w") as f: f.write(mock_module_add_original) with open(mock_module_add_modified_path, "w") as f: f.write(mock_module_add_modified) module = import_from_path( "torch.test_package_helper", mock_module_add_original_path, ) def fn(x): return module.add(x, 1) args = (torch.randn(3, 2),) def guard_filter_fn(guards): return [ guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH", "MODULE_MATCH") for guard in guards ] # Saving package = CompilePackage(fn) compiled_fn = torch._dynamo.optimize( backend="eager", package=package, guard_filter_fn=guard_filter_fn )(fn) compiled_fn(*args) for backend_id, backend in package.cached_backends.items(): ctx.record_eager_backend(backend_id, backend) ctx.save_package(package, self.path()) module = import_from_path( "torch.test_package_helper", mock_module_add_modified_path, ) with self.assertRaisesRegex(RuntimeError, "Source code changes detected"): ctx.load_package(fn, self.path()) module = import_from_path( "torch.test_package_helper", mock_module_add_original_path, ) ctx.load_package(fn, self.path()) @parametrize("device", ("cpu", "cuda", "xpu")) def test_dynamo_cache_manual_load(self, device): if device == "cuda" and not HAS_CUDA_AND_TRITON: raise unittest.SkipTest("Requires CUDA/Triton") if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") def fn(x): return x.sin() + x.cos() def fn2(x): return x.cos() + x package1 = CompilePackage(fn) package2 = CompilePackage(fn2) compiled_fn1 = torch._dynamo.optimize(backend="inductor", package=package1)(fn) compiled_fn2 = torch._dynamo.optimize(backend="inductor", package=package2)(fn2) arg1 = torch.randn(3, 2, device=device) arg2 = torch.randn(5, 2, device=device) expected = [compiled_fn1(arg1), compiled_fn2(arg2)] DynamoCache.save(package1) DynamoCache.save(package2) total_frames = torch._dynamo.convert_frame.FRAME_COUNTER self._save_and_reload(expected_backends=2, expected_dynamo=2) # These should exist because of populate_caches package1 = DynamoCache.load_and_install_package(fn) package2 = DynamoCache.load_and_install_package(fn2) with torch.compiler.set_stance("fail_on_recompile"): result1 = compiled_fn1(arg1) 
            result2 = compiled_fn2(arg2)
        self.assertEqual(expected, [result1, result2])
        self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)

    @parametrize("device", ("cpu", "cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
    def test_automatic_dynamo_serialize(self, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        def fn(x):
            return x.sin() + x.cos()

        def fn2(x):
            return x.cos() + x

        arg1 = torch.randn(3, 2, device=device)
        arg2 = torch.randn(5, 2, device=device)
        expected = [fn(arg1), fn2(arg2)]
        compiled_fn1 = torch.compile(fn)
        compiled_fn2 = torch.compile(fn2)
        result = [compiled_fn1(arg1), compiled_fn2(arg2)]
        self.assertEqual(expected, result)
        DynamoCache.clear()
        total_frames = torch._dynamo.convert_frame.FRAME_COUNTER

        self._save_and_reload(expected_backends=2, expected_dynamo=2)
        compiled_fn1 = torch.compile(fn)
        compiled_fn2 = torch.compile(fn2)
        with torch.compiler.set_stance("fail_on_recompile"):
            result1 = compiled_fn1(arg1)
            result2 = compiled_fn2(arg2)
        self.assertEqual(expected, [result1, result2])
        self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)

    @parametrize("device", ("cpu", "cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
    def test_automatic_dynamo_recompiles(self, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        def fn(x):
            return x.sin() + x.cos()

        arg1 = torch.randn(3, 2, device=device)
        arg2 = torch.randn(5, 2, device=device)
        compiled_fn = torch.compile(fn)
        expected1 = compiled_fn(arg1)
        # Should cause a recompile
        expected2 = compiled_fn(arg2)
        total_frames = torch._dynamo.convert_frame.FRAME_COUNTER

        self._save_and_reload(expected_backends=2, expected_dynamo=1)
        compiled_fn = torch.compile(fn)
        with torch.compiler.set_stance("fail_on_recompile"):
            result1 = compiled_fn(arg1)
            result2 = compiled_fn(arg2)
            # Because of automatic dynamic, a third random shape should also not
            # cause a recompile
            arg3 = torch.randn(7, 2, device=device)
            compiled_fn(arg3)
        self.assertEqual(result1, expected1)
        self.assertEqual(result2, expected2)
        self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)

    @parametrize("device", ("cpu", "cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
    def test_automatic_dynamo_graph_breaks(self, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        def fn(x, l, r):
            if l > r:
                return x.sum()
            mid = (l + r) // 2
            if x.sum() == mid:
                return x.sum()
            elif x.sum() < mid:
                return fn(x, l, mid)
            else:
                return fn(x, mid + 1, r)

        def guard_filter_fn(guards):
            return [
                guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH")
                for guard in guards
            ]

        # Saving
        compiled_fn = torch._dynamo.optimize(
            backend="inductor", guard_filter_fn=guard_filter_fn
        )(fn)
        N = 10
        args_list = [(torch.tensor(x, device=device), 0, N - 1) for x in range(N)]
        for args in args_list:
            compiled_fn(*args)

        total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
        self._save_and_reload(expected_backends=8, expected_dynamo=1)
        compiled_fn = torch._dynamo.optimize(
            backend="inductor", guard_filter_fn=guard_filter_fn
        )(fn)
        with torch.compiler.set_stance("fail_on_recompile"):
            for args in args_list:
                self.assertEqual(compiled_fn(*args), args[0].sum())

        # Should have same number of frames as on cold start
        self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)

    @parametrize("device", ("cpu", "cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
    def test_automatic_dynamo_lazy_backward(self, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        def fn(x):
            return x.sin() + x.cos()

        arg1 = torch.randn(3, 2, device=device, requires_grad=True)
        arg2 = arg1.clone().detach_().requires_grad_(True)
        compiled_fn = torch.compile(fn)
        expected1 = compiled_fn(arg1)
        expected1.sum().backward()
        total_frames = torch._dynamo.convert_frame.FRAME_COUNTER

        self._save_and_reload(expected_backends=1, expected_dynamo=1)
        compiled_fn = torch.compile(fn)
        # Run it again, no recompile needed
        with torch.compiler.set_stance("fail_on_recompile"):
            expected2 = compiled_fn(arg2)
            expected2.sum().backward()
        self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)

    @parametrize("device", ("cpu", "cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
    def test_graph_break_partial_backend(self, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        def fn(x):
            y = x.sin()
            torch._dynamo.graph_break()
            return x.sin() + y

        arg1 = torch.randn(3, 2, device=device, requires_grad=True)
        arg2 = arg1.clone().detach_().requires_grad_(True)
        compiled_fn = torch.compile(fn)
        expected1 = compiled_fn(arg1)
        expected1.sum().backward()
        total_frames = torch._dynamo.convert_frame.FRAME_COUNTER

        # Remove backends related to resume functions
        dynamo_entry = next(iter(PrecompileContext._dynamo_cache_entries.values()))
        for code in dynamo_entry.codes:
            module = sys.modules[code.python_module]
            if code.install_to_global:
                # Clear the fn_names from global scope, to simulate a new environment
                for fn_name in code.function_names:
                    module.__dict__.pop(fn_name)
            for fn_name in code.function_names:
                if "resume" in fn_name:
                    self.assertEqual(len(code.backend_ids), 1)
                    backend = code.backend_ids[0]
                    # Delete the backend associated with the resume function
                    del PrecompileContext._backend_artifacts_by_key[backend]

        self._save_and_reload(expected_backends=1, expected_dynamo=1)
        compiled_fn = torch.compile(fn)
        # Run it again. There will be a recompile because one of the backends is
        # deleted, but it should still work.
        expected2 = compiled_fn(arg2)
        expected2.sum().backward()
        self.assertEqual(expected1, expected2)
        # One recompile on a new frame, so total_frames should increase by 1
        self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames + 1)

    @parametrize("device", ("cpu", "cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
    def test_call_function_from_resume(self, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        mod = torch.nn.Linear(2, 3, device=device)

        def foo(x, mod):
            pred = mod(x)
            compute_loss_helper(pred).backward()
            return None

        args = (torch.randn(3, 2, device=device), mod)
        compiled_fn = torch.compile(foo)
        compiled_fn(*args)
        total_frames = torch._dynamo.convert_frame.FRAME_COUNTER

        self._save_and_reload(expected_backends=1, expected_dynamo=1)
        compiled_fn = torch.compile(foo)
        # Run it again, no recompile needed
        with torch.compiler.set_stance("fail_on_recompile"):
            compiled_fn(*args)
        self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)

    @parametrize("device", ("cpu", "cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
    def test_code_with_generator(self, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        def foo(set_of_x):
            if not all(isinstance(s, torch.Tensor) for s in set_of_x):
                raise TypeError(
                    f"Expected all elements of set_of_x to be tensors, got {set_of_x}"
                )
            return torch.cat(set_of_x, dim=0)

        args = ([torch.randn(3, 2, device=device) for _ in range(3)],)
        compiled_fn = torch.compile(foo)
        compiled_fn(*args)
        self._save_and_reload(expected_backends=1, expected_dynamo=1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()