From bc379aebe2e69d306d1b05938a9e86c80f6b98cb Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 21 Jul 2025 20:45:21 +0000 Subject: [PATCH] Revert "Still run TritonBundler with BundledAOTAutogradCache, save autotune results (#158048)" This reverts commit 8e57cdb746b4ab28865fdf01532f87b0d21700e9. Reverted https://github.com/pytorch/pytorch/pull/158048 on behalf of https://github.com/jeffdaily due to rocm failures due to unit test introduced in this PR, but no pre-merge signal available ([comment](https://github.com/pytorch/pytorch/pull/158048#issuecomment-3098746624)) --- test/dynamo/test_package.py | 34 ----------------------- torch/_dynamo/precompile_context.py | 9 ++---- torch/_inductor/compile_fx.py | 29 +------------------ torch/_inductor/runtime/autotune_cache.py | 10 ------- 4 files changed, 3 insertions(+), 79 deletions(-) diff --git a/test/dynamo/test_package.py b/test/dynamo/test_package.py index 51f6ca91136..31600077740 100644 --- a/test/dynamo/test_package.py +++ b/test/dynamo/test_package.py @@ -15,7 +15,6 @@ import torch.utils.cpp_extension from torch._dynamo.package import CompilePackage, DiskDynamoStore, DynamoCache from torch._dynamo.precompile_context import PrecompileContext from torch._functorch import config as functorch_config -from torch._inductor.mock_cache import global_stats, PatchCaches, Stats from torch._inductor.runtime.runtime_utils import cache_dir from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -429,39 +428,6 @@ def add(x, y): self.assertEqual(expected, [result1, result2]) self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames) - @parametrize("device", ("cuda", "xpu")) - @torch._dynamo.config.patch(caching_precompile=True) - def test_automatic_dynamo_autotune_cache(self, device): - if device == "cuda" and not HAS_CUDA: - raise unittest.SkipTest("Requires CUDA/Triton") - if device == "xpu" and not HAS_XPU: - raise unittest.SkipTest("Requires XPU/Triton") - - def fn(x, y): - return x.sin() + y - - arg1 = torch.randn(3, 3, device=device) - arg2 = torch.randn(3, 3, device=device) - expected = fn(arg1, arg2).clone() - - with PatchCaches(): - compiled_fn1 = torch.compile(fn, mode="max-autotune") - result = compiled_fn1(arg1, arg2).clone() - self.assertEqual(expected, result) - self.assertEqual(global_stats.autotune_local, Stats(1, 0, 1)) - DynamoCache.clear() - - total_frames = torch._dynamo.convert_frame.FRAME_COUNTER - self._save_and_reload( - expected_backends=1, expected_dynamo=1, expected_autotune=1 - ) - compiled_fn1 = torch.compile(fn, mode="max-autotune") - with torch.compiler.set_stance("fail_on_recompile"): - result1 = compiled_fn1(arg1, arg2).clone() - self.assertEqual(expected, result1) - self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames) - self.assertEqual(global_stats.autotune_local, Stats(2, 1, 1)) - @parametrize("device", ("cpu", "cuda", "xpu")) @torch._dynamo.config.patch(caching_precompile=True) def test_automatic_dynamo_recompiles(self, device): diff --git a/torch/_dynamo/precompile_context.py b/torch/_dynamo/precompile_context.py index 040f54ce70d..6bb42bb34bc 100644 --- a/torch/_dynamo/precompile_context.py +++ b/torch/_dynamo/precompile_context.py @@ -70,8 +70,7 @@ class PrecompileContext(CacheArtifactManager): The following artifact types are supported by PrecompileContext: - BundledAOTAutogradCacheArtifact - - DynamoCodeStateArtifact - - AutotuneCacheArtifact (regular autotune results, same as Megacache) + - CodeStateArtifact (from torch._dynamo.package 
once available) """ # Protected by the compile_lock @@ -150,12 +149,8 @@ class PrecompileContext(CacheArtifactManager): artifacts_by_key = {} cache_info = CacheInfo() for artifact in chain(*artifacts.values()): - if artifact.type() == "autotune": - # Populate autotune cache artifacts - artifact.populate_cache() - else: - artifacts_by_key[artifact.key] = artifact cache_info.add(artifact) + artifacts_by_key[artifact.key] = artifact from torch._dynamo.package import _BackendId, DynamoCache diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 8e712a28a3b..95c12d12c78 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -909,37 +909,10 @@ def _compile_fx_inner( else: log.debug("Failed to generate FX cache key") - if torch._functorch.config.bundled_autograd_cache: - assert mb_compiled_graph is None - assert cache_info is None - # When using bundled autograd cache, we still want - # to use the TritonBundler, but we don't want to save - # the results here. The results will get saved directly - # to AOTAutogradCache. - TritonBundler.begin_compile() - try: - mb_compiled_graph = fx_codegen_and_compile( - gm, example_inputs, inputs_to_check, **graph_kwargs - ) - assert mb_compiled_graph is not None - ( - triton_bundle, - triton_bundler_meta, - ) = TritonBundler.collect() - mb_compiled_graph.set_triton_bundle(triton_bundle) - except (ShortenTraceback, SkipFrame): - raise - except Exception as e: - raise InductorError(e, currentframe()).with_traceback( - e.__traceback__ - ) from None - finally: - TritonBundler.end_compile() - # CACHE BYPASS: Compile the graph, don't save it to the cache # (this can happen either because cache was disabled, or we # determined the input is uncacheable) - elif cache_info is None or cache_info["cache_state"] == "bypass": + if cache_info is None or cache_info["cache_state"] == "bypass": assert mb_compiled_graph is None log.debug( "FX cache bypass reason: %s", diff --git a/torch/_inductor/runtime/autotune_cache.py b/torch/_inductor/runtime/autotune_cache.py index 88b9c80c771..01d038aab8e 100644 --- a/torch/_inductor/runtime/autotune_cache.py +++ b/torch/_inductor/runtime/autotune_cache.py @@ -35,7 +35,6 @@ from typing import Any, Optional, TYPE_CHECKING from typing_extensions import override import torch -from torch._dynamo.precompile_context import PrecompileContext from torch._inductor.runtime.runtime_utils import cache_dir from torch.compiler._cache import ( CacheArtifact, @@ -126,7 +125,6 @@ class AutotuneCache: ) -> Optional[AutotuneCache]: cache = AutotuneCache(configs_hash) key = AutotuneCache._prepare_key(filename) - cache._setup_local_cache(inductor_meta, os.path.dirname(filename), key) cache._setup_remote_autotune_cache(inductor_meta, key) if cache.local_cache or cache.remote_cache: @@ -302,10 +300,6 @@ class AutotuneCache: CacheArtifactManager.record_artifact( AutotuneCacheArtifact.type(), autotune_artifact_key, data ) - if torch._dynamo.config.caching_precompile: - PrecompileContext.record_artifact( - AutotuneCacheArtifact.type(), autotune_artifact_key, data - ) if log.isEnabledFor(logging.DEBUG): type_str = "coordesc" if found_by_coordesc else "heuristic" @@ -631,10 +625,6 @@ class LocalAutotuneCache(RemoteCache[JsonDataTy]): CacheArtifactManager.record_artifact( AutotuneCacheArtifact.type(), autotune_artifact_key, result ) - if torch._dynamo.config.caching_precompile: - PrecompileContext.record_artifact( - AutotuneCacheArtifact.type(), autotune_artifact_key, result - ) return result @override
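
For context on the torch/_dynamo/precompile_context.py hunk above: after this revert, PrecompileContext.populate_caches records every artifact uniformly again, instead of special-casing artifacts whose type() is "autotune" (which #158048 had populated eagerly via populate_cache()). The snippet below is a minimal, self-contained sketch of that restored loop shape only; StubArtifact and StubCacheInfo are hypothetical stand-ins for illustration, not the real torch.compiler._cache classes.

# Sketch of the artifact-handling loop restored by this revert.
# StubArtifact / StubCacheInfo are made-up stand-ins, not torch types.
from dataclasses import dataclass, field
from itertools import chain


@dataclass
class StubArtifact:
    key: str
    kind: str

    def type(self) -> str:
        return self.kind


@dataclass
class StubCacheInfo:
    seen: list = field(default_factory=list)

    def add(self, artifact: StubArtifact) -> None:
        self.seen.append(artifact.key)


artifacts = {
    "precompile_aot_autograd": [StubArtifact("compiled_fn_0", "precompile_aot_autograd")],
    "autotune": [StubArtifact("triton_kernel_0", "autotune")],
}

artifacts_by_key = {}
cache_info = StubCacheInfo()

# Post-revert shape: every artifact is added to cache_info and indexed by key;
# autotune artifacts are no longer populated eagerly inside this loop.
for artifact in chain(*artifacts.values()):
    cache_info.add(artifact)
    artifacts_by_key[artifact.key] = artifact

print(sorted(artifacts_by_key))  # ['compiled_fn_0', 'triton_kernel_0']
print(cache_info.seen)           # ['compiled_fn_0', 'triton_kernel_0']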