Revert "Still run TritonBundler with BundledAOTAutogradCache, save autotune results (#158048)"

This reverts commit 8e57cdb746.

Reverted https://github.com/pytorch/pytorch/pull/158048 on behalf of https://github.com/jeffdaily due to ROCm failures caused by a unit test introduced in this PR, for which no pre-merge signal was available ([comment](https://github.com/pytorch/pytorch/pull/158048#issuecomment-3098746624))
PyTorch MergeBot 2025-07-21 20:45:21 +00:00
parent b1a0c34dd3
commit bc379aebe2
4 changed files with 3 additions and 79 deletions
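
For context, tests that exercise CUDA/Triton autotuning are usually kept off ROCm runners with an explicit skip. A minimal, illustrative sketch (not the reverted test; it assumes the standard skipIfRocm helper from torch.testing._internal.common_utils):

# Illustrative only: a minimal guard pattern for keeping a CUDA/Triton
# autotune test off ROCm runners. Not the actual test removed by this revert.
import unittest

import torch
from torch.testing._internal.common_utils import run_tests, skipIfRocm, TestCase


class ExampleAutotuneTest(TestCase):
    @skipIfRocm  # keep the test off ROCm runners that have no pre-merge signal
    def test_max_autotune_compile(self):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("Requires CUDA/Triton")

        def fn(x, y):
            return x.sin() + y

        x = torch.randn(3, 3, device="cuda")
        y = torch.randn(3, 3, device="cuda")
        compiled = torch.compile(fn, mode="max-autotune")
        self.assertEqual(fn(x, y), compiled(x, y))


if __name__ == "__main__":
    run_tests()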

View File

@@ -15,7 +15,6 @@ import torch.utils.cpp_extension
 from torch._dynamo.package import CompilePackage, DiskDynamoStore, DynamoCache
 from torch._dynamo.precompile_context import PrecompileContext
 from torch._functorch import config as functorch_config
-from torch._inductor.mock_cache import global_stats, PatchCaches, Stats
 from torch._inductor.runtime.runtime_utils import cache_dir
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -429,39 +428,6 @@ def add(x, y):
         self.assertEqual(expected, [result1, result2])
         self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)

-    @parametrize("device", ("cuda", "xpu"))
-    @torch._dynamo.config.patch(caching_precompile=True)
-    def test_automatic_dynamo_autotune_cache(self, device):
-        if device == "cuda" and not HAS_CUDA:
-            raise unittest.SkipTest("Requires CUDA/Triton")
-        if device == "xpu" and not HAS_XPU:
-            raise unittest.SkipTest("Requires XPU/Triton")
-
-        def fn(x, y):
-            return x.sin() + y
-
-        arg1 = torch.randn(3, 3, device=device)
-        arg2 = torch.randn(3, 3, device=device)
-        expected = fn(arg1, arg2).clone()
-
-        with PatchCaches():
-            compiled_fn1 = torch.compile(fn, mode="max-autotune")
-            result = compiled_fn1(arg1, arg2).clone()
-            self.assertEqual(expected, result)
-            self.assertEqual(global_stats.autotune_local, Stats(1, 0, 1))
-
-            DynamoCache.clear()
-            total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
-            self._save_and_reload(
-                expected_backends=1, expected_dynamo=1, expected_autotune=1
-            )
-            compiled_fn1 = torch.compile(fn, mode="max-autotune")
-            with torch.compiler.set_stance("fail_on_recompile"):
-                result1 = compiled_fn1(arg1, arg2).clone()
-            self.assertEqual(expected, result1)
-            self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
-            self.assertEqual(global_stats.autotune_local, Stats(2, 1, 1))
-
     @parametrize("device", ("cpu", "cuda", "xpu"))
     @torch._dynamo.config.patch(caching_precompile=True)
     def test_automatic_dynamo_recompiles(self, device):
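
The removed test drives the mock cache counters from torch._inductor.mock_cache. A hedged sketch of that pattern, assuming the Stats fields are ordered (puts, hits, misses), so Stats(1, 0, 1) means one write, zero hits, one miss on the first compile:

# Sketch of the PatchCaches/global_stats pattern the removed test uses.
# Assumption: Stats(num_put, num_get_hit, num_get_miss).
import torch
from torch._inductor.mock_cache import global_stats, PatchCaches, Stats


def fn(x, y):
    return x.sin() + y


x = torch.randn(3, 3, device="cuda")
y = torch.randn(3, 3, device="cuda")

with PatchCaches():
    compiled = torch.compile(fn, mode="max-autotune")
    compiled(x, y)  # first run: autotune result misses, then is written locally
    assert global_stats.autotune_local == Stats(1, 0, 1)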

View File

@@ -70,8 +70,7 @@ class PrecompileContext(CacheArtifactManager):
     The following artifact types are supported by PrecompileContext:
      - BundledAOTAutogradCacheArtifact
-     - DynamoCodeStateArtifact
-     - AutotuneCacheArtifact (regular autotune results, same as Megacache)
+     - CodeStateArtifact (from torch._dynamo.package once available)

     """

     # Protected by the compile_lock
@@ -150,12 +149,8 @@ class PrecompileContext(CacheArtifactManager):
         artifacts_by_key = {}
         cache_info = CacheInfo()
         for artifact in chain(*artifacts.values()):
-            if artifact.type() == "autotune":
-                # Populate autotune cache artifacts
-                artifact.populate_cache()
-            else:
-                artifacts_by_key[artifact.key] = artifact
+            artifacts_by_key[artifact.key] = artifact
             cache_info.add(artifact)

         from torch._dynamo.package import _BackendId, DynamoCache
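
For readability, the pre-revert loop body removed above, restated with comments (variable shapes taken from the surrounding hunk):

# Simplified sketch of the pre-revert dispatch inside the artifact loop.
for artifact in chain(*artifacts.values()):
    if artifact.type() == "autotune":
        artifact.populate_cache()  # pre-revert: flush autotune results immediately
    else:
        artifacts_by_key[artifact.key] = artifact  # keep for keyed lookup
    cache_info.add(artifact)  # record in CacheInfo either way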

View File

@@ -909,37 +909,10 @@ def _compile_fx_inner(
         else:
             log.debug("Failed to generate FX cache key")

-        if torch._functorch.config.bundled_autograd_cache:
-            assert mb_compiled_graph is None
-            assert cache_info is None
-            # When using bundled autograd cache, we still want
-            # to use the TritonBundler, but we don't want to save
-            # the results here. The results will get saved directly
-            # to AOTAutogradCache.
-            TritonBundler.begin_compile()
-            try:
-                mb_compiled_graph = fx_codegen_and_compile(
-                    gm, example_inputs, inputs_to_check, **graph_kwargs
-                )
-                assert mb_compiled_graph is not None
-                (
-                    triton_bundle,
-                    triton_bundler_meta,
-                ) = TritonBundler.collect()
-                mb_compiled_graph.set_triton_bundle(triton_bundle)
-            except (ShortenTraceback, SkipFrame):
-                raise
-            except Exception as e:
-                raise InductorError(e, currentframe()).with_traceback(
-                    e.__traceback__
-                ) from None
-            finally:
-                TritonBundler.end_compile()
-
         # CACHE BYPASS: Compile the graph, don't save it to the cache
         # (this can happen either because cache was disabled, or we
         # determined the input is uncacheable)
-        elif cache_info is None or cache_info["cache_state"] == "bypass":
+        if cache_info is None or cache_info["cache_state"] == "bypass":
             assert mb_compiled_graph is None
             log.debug(
                 "FX cache bypass reason: %s",

View File

@ -35,7 +35,6 @@ from typing import Any, Optional, TYPE_CHECKING
from typing_extensions import override from typing_extensions import override
import torch import torch
from torch._dynamo.precompile_context import PrecompileContext
from torch._inductor.runtime.runtime_utils import cache_dir from torch._inductor.runtime.runtime_utils import cache_dir
from torch.compiler._cache import ( from torch.compiler._cache import (
CacheArtifact, CacheArtifact,
@@ -126,7 +125,6 @@ class AutotuneCache:
     ) -> Optional[AutotuneCache]:
         cache = AutotuneCache(configs_hash)
         key = AutotuneCache._prepare_key(filename)
-
         cache._setup_local_cache(inductor_meta, os.path.dirname(filename), key)
         cache._setup_remote_autotune_cache(inductor_meta, key)
         if cache.local_cache or cache.remote_cache:
@@ -302,10 +300,6 @@ class AutotuneCache:
             CacheArtifactManager.record_artifact(
                 AutotuneCacheArtifact.type(), autotune_artifact_key, data
             )
-            if torch._dynamo.config.caching_precompile:
-                PrecompileContext.record_artifact(
-                    AutotuneCacheArtifact.type(), autotune_artifact_key, data
-                )

         if log.isEnabledFor(logging.DEBUG):
             type_str = "coordesc" if found_by_coordesc else "heuristic"
@@ -631,10 +625,6 @@ class LocalAutotuneCache(RemoteCache[JsonDataTy]):
             CacheArtifactManager.record_artifact(
                 AutotuneCacheArtifact.type(), autotune_artifact_key, result
             )
-            if torch._dynamo.config.caching_precompile:
-                PrecompileContext.record_artifact(
-                    AutotuneCacheArtifact.type(), autotune_artifact_key, result
-                )
         return result

     @override
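
Both hunks above remove the same pattern: when caching_precompile is enabled, autotune results were mirrored into PrecompileContext in addition to the regular CacheArtifactManager (Megacache) record. A condensed sketch of that pre-revert behavior (record_autotune_result is an illustrative helper; "autotune" stands in for AutotuneCacheArtifact.type()):

# Condensed sketch of the pre-revert dual recording removed in the two hunks
# above. record_autotune_result is illustrative, not a function in the codebase.
import torch
from torch._dynamo.precompile_context import PrecompileContext
from torch.compiler._cache import CacheArtifactManager


def record_autotune_result(autotune_artifact_key: str, data) -> None:
    # Always recorded for Megacache.
    CacheArtifactManager.record_artifact("autotune", autotune_artifact_key, data)
    # Pre-revert: also mirrored into PrecompileContext when precompile caching
    # is enabled, so saved precompile bundles carry autotune results.
    if torch._dynamo.config.caching_precompile:
        PrecompileContext.record_artifact("autotune", autotune_artifact_key, data)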