Revert "Recheck Autotune cache on Precompile serialization to prune compilation results (#158656)"

This reverts commit 664005662a.

Reverted https://github.com/pytorch/pytorch/pull/158656 on behalf of https://github.com/seemethere due to failing internal tests, see D80486843 ([comment](https://github.com/pytorch/pytorch/pull/158656#issuecomment-3201491561))
PyTorch MergeBot 2025-08-19 16:53:20 +00:00
parent fecc5f6001
commit eddaaa6c2a
7 changed files with 29 additions and 100 deletions


@@ -16,7 +16,7 @@ from torch._dynamo.package import CompilePackage, DiskDynamoStore, DynamoCache
 from torch._dynamo.precompile_context import PrecompileContext
 from torch._dynamo.testing import reduce_to_scalar_loss
 from torch._functorch import config as functorch_config
-from torch._inductor.mock_cache import global_stats, PatchCaches
+from torch._inductor.mock_cache import global_stats, PatchCaches, Stats
 from torch._inductor.runtime.runtime_utils import cache_dir
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -452,33 +452,27 @@ def add(x, y):
         def fn(x, y):
             return x.sin() + y

-        arg1 = torch.randn(32, 32, device=device)
-        arg2 = torch.randn(32, 32, device=device)
+        arg1 = torch.randn(3, 3, device=device)
+        arg2 = torch.randn(3, 3, device=device)
         expected = fn(arg1, arg2).clone()

         with PatchCaches():
             compiled_fn1 = torch.compile(fn, mode="max-autotune")
             result = compiled_fn1(arg1, arg2).clone()
             self.assertEqual(expected, result)
-            self.assertEqual(global_stats.autotune_local.num_get_miss, 1)
+            self.assertEqual(global_stats.autotune_local, Stats(1, 0, 1))

             DynamoCache.clear()
             total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
             self._save_and_reload(
                 expected_backends=1, expected_dynamo=1, expected_autotune=1
             )
-            # During save, we check the autotune cache another time, and now it should hit
-            self.assertEqual(global_stats.autotune_local.num_get_hit, 1)
             compiled_fn1 = torch.compile(fn, mode="max-autotune")
             with torch.compiler.set_stance("fail_on_recompile"):
                 result1 = compiled_fn1(arg1, arg2).clone()
             self.assertEqual(expected, result1)
             self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
-            # No new hits or misses
-            # Unfortunately, we don't *actually* know how many puts there will be, because
-            # it's possible the best autotune config was found by coordesc.
-            self.assertEqual(global_stats.autotune_local.num_get_hit, 1)
-            self.assertEqual(global_stats.autotune_local.num_get_miss, 1)
+            self.assertEqual(global_stats.autotune_local, Stats(2, 1, 1))

     @parametrize("device", ("cpu", "cuda", "xpu"))
     @torch._dynamo.config.patch(caching_precompile=True)
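For context on the restored assertions: the test goes back to comparing `global_stats.autotune_local` against a whole `Stats` value from `torch._inductor.mock_cache`, instead of checking individual counters as the reverted version did. A minimal sketch of what that comparison amounts to, assuming `Stats` is an equality-comparable record of put/hit/miss counters in that order (an assumption for illustration, not taken from this diff):

# Minimal sketch, not the PyTorch source. Assumes Stats is an
# equality-comparable record ordered as (num_put, num_get_hit, num_get_miss).
from dataclasses import dataclass


@dataclass(frozen=True)
class Stats:
    num_put: int = 0
    num_get_hit: int = 0
    num_get_miss: int = 0


# First max-autotune compile: under the assumed field order, Stats(1, 0, 1)
# reads as one put, no hits, one miss against the local autotune cache.
assert Stats(num_put=1, num_get_hit=0, num_get_miss=1) == Stats(1, 0, 1)

# After save/reload the restored test expects Stats(2, 1, 1): one more put
# and one hit relative to the first run, again under the assumed order.
assert Stats(num_put=2, num_get_hit=1, num_get_miss=1) == Stats(2, 1, 1)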


@@ -169,16 +169,7 @@ class PrecompileContext(CacheArtifactManager):
         by artifact type. This function transfers artifacts from _new_cache_artifacts_by_key to _new_cache_artifacts
         """
         for artifact in cls._new_cache_artifacts_by_key.values():
-            from torch._functorch._aot_autograd.autograd_cache import (
-                BundledAOTAutogradCacheEntry,
-            )
-
             if isinstance(artifact, EditablePrecompileCacheArtifact):
-                if isinstance(artifact.content, BundledAOTAutogradCacheEntry):
-                    # BundledAOTAutogradCacheEntries should update their autotune results
-                    artifact.edit_contents(
-                        BundledAOTAutogradCacheEntry.update_autotune_results
-                    )
                 artifact = artifact.real_encode()
             cls._new_cache_artifacts[artifact.__class__.type()].append(artifact)
         cls._new_cache_artifacts_by_key.clear()
@@ -204,15 +195,6 @@ class PrecompileContext(CacheArtifactManager):
         """
         result = cls._new_cache_artifacts_by_key.get(key, None)
         if isinstance(result, EditablePrecompileCacheArtifact):
-            from torch._functorch._aot_autograd.autograd_cache import (
-                BundledAOTAutogradCacheEntry,
-            )
-
-            if isinstance(result.content, BundledAOTAutogradCacheEntry):
-                # BundledAOTAutogradCacheEntries should update their autotune results
-                result.edit_contents(
-                    BundledAOTAutogradCacheEntry.update_autotune_results
-                )
             result = result.real_encode()
         return result
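After the revert, both `serialize()` and the per-key lookup treat every `EditablePrecompileCacheArtifact` the same way: no type-specific editing, just `real_encode()`. The removed branches inserted one extra step, `edit_contents(BundledAOTAutogradCacheEntry.update_autotune_results)`, before encoding. A rough sketch of that edit-then-encode pattern, using stand-in classes rather than the real PyTorch types:

# Rough sketch with stand-in types; EditableArtifact below is not the real
# EditablePrecompileCacheArtifact, it only mirrors the edit/encode calls
# visible in this diff.
import pickle
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class EditableArtifact:
    key: str
    content: Any

    def edit_contents(self, fn: Callable[[Any], Any]) -> None:
        # The removed code used this hook to refresh autotune results on
        # bundled AOTAutograd entries before they were frozen.
        self.content = fn(self.content)

    def real_encode(self) -> bytes:
        return pickle.dumps(self.content)


def encode_all(artifacts: list) -> list:
    # Post-revert behavior: encode each artifact as-is, no editing step.
    return [a.real_encode() for a in artifacts]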


@@ -535,32 +535,6 @@ class CompiledFxGraphLoadable(InductorOutput[CompiledFxGraph]):

     result: CompiledFxGraph

-    def recheck_autotune_results(self) -> None:
-        """
-        Run during PrecompileContext.serialize(). We recheck the autotune cache
-        again before saving results, to see if autotuning has completed for our generated
-        triton kernels. If so, it edits the statically compiled triton kernel so that only
-        the best config is preserved.
-        """
-        triton_bundle = self.result._triton_bundle
-        if triton_bundle is None:
-            return
-
-        static_autotuners = triton_bundle.static_autotuners
-        for autotuner in static_autotuners:
-            from torch._inductor.codecache import _load_triton_kernel_from_source
-            reload_kernel_from_src = functools.partial(
-                _load_triton_kernel_from_source,
-                autotuner.kernel_name,
-                autotuner.source_code,
-            )
-            autotuner.kernel.recheck_autotune_cache(
-                reload_kernel_from_src,
-            )
-            # Clear any extra state created by this check
-            autotuner.kernel.prepare_for_pickle()
-            autotuner.kernel.prepare_for_caching()
-
     def pre_save(self) -> None:
         disk_compiled_graph = copy(self.result)
         disk_compiled_graph.prepare_for_serialization()
@@ -1024,18 +998,6 @@ class BundledAOTAutogradCacheEntry(
    of relying on cache keys from FxGraphCache
    """

-    @staticmethod
-    def update_autotune_results(
-        entry: BundledAOTAutogradCacheEntry,
-    ) -> BundledAOTAutogradCacheEntry:
-        """
-        Update the autotune results in the cache entry.
-        """
-        entry.compiled_fw.recheck_autotune_results()
-        if entry.compiled_bw is not None:
-            entry.compiled_bw.recheck_autotune_results()
-        return entry
-

 @contextlib.contextmanager
 def sanitize_gm_for_cache(gm: torch.fx.GraphModule):
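The removed `recheck_autotune_results` / `update_autotune_results` pair implemented the behavior named in the reverted PR's title: at Precompile serialization time, re-query the autotune cache and, if a single best config is now known, prune the statically compiled kernel down to that config. A conceptual sketch of that pruning step only, using illustrative names rather than the real CachingAutotuner API:

# Conceptual sketch; CompileResult, config_key, and prune_to_best are
# illustrative names, not the CachingAutotuner API.
from dataclasses import dataclass


@dataclass
class CompileResult:
    config_key: str   # hashable identity of the triton config
    binary: bytes     # compiled artifact for that config


def prune_to_best(results: list, best_config_key: str) -> list:
    # Keep only the compile result whose config matches the cached best
    # config; if it is not present, keep everything and let autotuning
    # re-run on load.
    kept = [r for r in results if r.config_key == best_config_key]
    return kept if kept else results


results = [CompileResult("cfg_a", b"..."), CompileResult("cfg_b", b"...")]
assert [r.config_key for r in prune_to_best(results, "cfg_b")] == ["cfg_b"]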


@@ -401,9 +401,11 @@ class AsyncCompile:
         if (future := CompiledTritonKernels.get(source_code)) is not None:
             counters["inductor"]["async_compile_cache_hit"] += 1
+            # Set reload_kernel_from_src properly based on source_code
             if isinstance(future, StaticAutotunerFuture):
                 # Remove the future now that we've cache hit
                 CompiledTritonKernels.remove_future(source_code)
+                future.reload_kernel_from_src = reload_kernel_in_parent

             if is_parallel:
                 return future
             else:
@@ -457,7 +459,7 @@
             kernel.precompile(
                 warm_cache_only=False,
                 reload_kernel=reload_kernel_in_parent,
-                source_code=source_code,
+                static_triton_bundle_key=CompiledTritonKernels.key(source_code),
             )
             info = kernel.autotune_cache_info or {}
             info["compile_time_us"] = elapsed_us
@@ -486,7 +488,7 @@
             kernel.set_compile_info(compile_id, is_backward)
             kernel.precompile(
                 warm_cache_only=False,
-                source_code=source_code,
+                static_triton_bundle_key=CompiledTritonKernels.key(source_code),
             )
             elapsed_us = (time_ns() - start_ns) // 1000
             get_metrics_context().add_top_n(
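Both restored `precompile(...)` call sites pass `static_triton_bundle_key=CompiledTritonKernels.key(source_code)` instead of the raw source. The sketch below treats that key as a plain content hash of the kernel source; the real `CompiledTritonKernels.key` may incorporate additional state, so this only illustrates handing the bundler a precomputed key rather than the source itself:

# Illustration only: the real CompiledTritonKernels.key may mix in more
# than a source hash; the point is that the bundler receives a derived
# key instead of the kernel source code.
import hashlib

_static_autotuners: dict = {}


def bundle_key(source_code: str) -> str:
    return hashlib.sha256(source_code.encode("utf-8")).hexdigest()


def put_static_autotuner(key: str, kernel: object) -> None:
    # Mirrors the restored TritonBundler.put_static_autotuner(key, kernel)
    # shape: the caller derives the key, the bundler just stores it.
    _static_autotuners[key] = kernel


src = "def kernel(): ..."
put_static_autotuner(bundle_key(src), object())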


@@ -4213,28 +4213,24 @@ class StaticAutotunerFuture(CodeCacheFuture):
     A statically launchable CachingAutotuner, loaded from TritonBundler
     """

-    def __init__(
-        self, static_autotuner: CachingAutotuner, kernel_name: str, source_code: str
-    ) -> None:
+    def __init__(self, static_autotuner: CachingAutotuner) -> None:
         # Pickled version of CachingAutotuner
         self.static_autotuner = static_autotuner
-        self.kernel_name = kernel_name
-        # The python source code of the kernel is relatively small and stored by StaticallyLaunchedAutotuner.
-        # We do not store the compiled cuda code here as it's very large,
-        # it's stored via the regular TritonBundler
-        self.source_code = source_code
+        # This needs to be set in AsyncCompile.triton, in case
+        # we need to reload the CachingAutotuner from its source code
+        # We don't store the source code on the CachingAutotuner itself
+        # since it can be very large.
+        self.reload_kernel_from_src: Optional[Callable[[], Any]] = None

     def result(self) -> CachingAutotuner:
+        assert self.reload_kernel_from_src is not None
         with dynamo_timed("StaticAutotunerFuture.warm_precompile"):
-            reload_kernel_from_src = functools.partial(
-                _load_triton_kernel_from_source, self.kernel_name, self.source_code
-            )
             self.static_autotuner.recheck_autotune_cache(
-                reload_kernel_from_src=reload_kernel_from_src
+                reload_kernel_from_src=self.reload_kernel_from_src
             )
             self.static_autotuner.precompile(  # type: ignore[union-attr]
                 warm_cache_only=False,
-                reload_kernel=reload_kernel_from_src,
-                source_code=None,  # no need to save again
+                reload_kernel=self.reload_kernel_from_src,
+                static_triton_bundle_key=None,  # no need to save again
             )
             return self.static_autotuner
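With the revert, `StaticAutotunerFuture` no longer carries the kernel name and source; `reload_kernel_from_src` is injected from the outside (the AsyncCompile cache-hit path above) and `result()` asserts it is set before use. The removed code built that callback with `functools.partial` over `_load_triton_kernel_from_source`; a hedged sketch of that shape, with a stand-in loader rather than the real inductor helper:

# Hedged sketch: _reload_from_source below is a stand-in for inductor's
# _load_triton_kernel_from_source, shown only to illustrate the
# zero-argument callback shape expected by reload_kernel_from_src.
import functools
from typing import Any, Callable, Optional


def _reload_from_source(kernel_name: str, source_code: str) -> Any:
    # Stand-in: in the real code this re-materializes a CachingAutotuner
    # from the kernel's Python source.
    return (kernel_name, source_code)


def make_reload_callback(kernel_name: str, source_code: str) -> Callable[[], Any]:
    # Partial application turns the two-argument loader into the
    # zero-argument Optional[Callable[[], Any]] attribute the restored
    # future expects to have set before result() runs.
    return functools.partial(_reload_from_source, kernel_name, source_code)


reload_kernel_from_src: Optional[Callable[[], Any]] = make_reload_callback("some_kernel", "...")
assert reload_kernel_from_src is not None
reload_kernel_from_src()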


@@ -386,14 +386,13 @@ class CachingAutotuner(KernelInterface):
         assert self.is_statically_launchable()
         configs = [result.config for result in self.compile_results]
-        if len(configs) <= 1:
-            return
+
         (cached_configs, _, autotune_cache_info) = check_autotune_cache(
             configs, self.filename, self.inductor_meta
         )
         self.autotune_cache_info = autotune_cache_info

         # I.e. there was an autotune cache hit
-        if len(cached_configs) == 1:
+        if len(cached_configs) == 1 and len(configs) > 1:
             best_config = cached_configs[0]
             # Grab the best compiled config, if it's in the list of available ones
             best_config_hash = triton_config_to_hashable(best_config)
@@ -422,7 +421,7 @@ class CachingAutotuner(KernelInterface):
         self,
         warm_cache_only=False,
         reload_kernel: Optional[Callable[[], CachingAutotuner]] = None,
-        source_code: Optional[str] = None,  # Used for static_triton_bundle_key
+        static_triton_bundle_key: Optional[str] = None,
     ):
         if warm_cache_only:
             self._precompile_worker()
@@ -435,9 +434,8 @@
         if reload_kernel is not None:
             self._reload_kernel = reload_kernel
         self._precompile_worker()
-
-        if source_code is not None and self.is_statically_launchable():
-            TritonBundler.put_static_autotuner(source_code, self)
+        if static_triton_bundle_key is not None and self.is_statically_launchable():
+            TritonBundler.put_static_autotuner(static_triton_bundle_key, self)

         self._make_launchers()
         self._dynamic_scale_rblock()
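The restored `recheck_autotune_cache` always consults `check_autotune_cache`, and only swaps in the cached best config when there was a hit (`len(cached_configs) == 1`) and more than one candidate config; the reverted version returned early for single-config kernels before touching the cache, which is what shifted the hit/miss counters the test above asserts on. A small sketch of the config-matching step implied by the surrounding context, with `config_to_hashable` standing in for `triton_config_to_hashable` and plain dicts standing in for triton configs:

# Small sketch of matching a cached best config against the compiled
# candidates by a hashable key; the names and dicts are stand-ins.
def config_to_hashable(config: dict) -> tuple:
    return tuple(sorted(config.items()))


def pick_best(compile_results: list, cached_configs: list) -> list:
    configs = [r["config"] for r in compile_results]
    if len(cached_configs) == 1 and len(configs) > 1:
        best = config_to_hashable(cached_configs[0])
        # Grab the best compiled config, if it's in the list of available ones.
        matched = [r for r in compile_results if config_to_hashable(r["config"]) == best]
        if matched:
            return matched
    return compile_results


results = [{"config": {"BLOCK": 64}}, {"config": {"BLOCK": 128}}]
assert pick_best(results, [{"BLOCK": 128}]) == [{"config": {"BLOCK": 128}}]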


@@ -53,11 +53,7 @@ class StaticallyLaunchedAutotuner:
     Statically saved here have their cubin files saved by a corresponding TritonBundleEntry.
     """

-    # We store the kernel's python source code here which we use for two things:
-    # First, to calculate a cache key for CompiledTritonKernels
-    # Second, in case we need to reload the kernel on load,
-    # we can do so by reading the source code from the cache entry.
-    source_code: str
+    cache_key: str
     kernel_name: str
     kernel: "CachingAutotuner"  # type: ignore[name-defined]  # noqa: F821

@@ -168,7 +164,7 @@ class TritonBundler:
         )

     @classmethod
-    def put_static_autotuner(cls, source_code: str, kernel: "CachingAutotuner") -> None:  # type: ignore[name-defined]  # noqa: F821
+    def put_static_autotuner(cls, key: str, kernel: "CachingAutotuner") -> None:  # type: ignore[name-defined]  # noqa: F821
         from torch._inductor import config

         assert config.use_static_cuda_launcher
@@ -182,7 +178,7 @@
             entries.append(
                 StaticallyLaunchedAutotuner(
-                    source_code,
+                    key,
                     new_kernel.inductor_meta.get("kernel_name", "unknown_kernel"),
                     new_kernel,
                 )
             )
@@ -244,9 +240,8 @@
                 # kernels that are not statically launchable (i.e. cache miss)
                 # can launch a worker without waiting on the blocking step of
                 # StaticAutotunerFuture.result().
-                cache_key = CompiledTritonKernels.key(result.source_code)
-                CompiledTritonKernels._cache[cache_key] = StaticAutotunerFuture(
-                    result.kernel, result.kernel_name, result.source_code
+                CompiledTritonKernels._cache[result.cache_key] = StaticAutotunerFuture(
+                    result.kernel
                 )
                 counters["inductor"]["triton_bundler_load_static_autotuner"] += 1
                 kernel_names.append(result.kernel_name)
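After the revert, `StaticallyLaunchedAutotuner` stores a precomputed `cache_key` (plus the kernel name and the pickled autotuner) instead of the full kernel source, so the load path can populate `CompiledTritonKernels._cache[result.cache_key]` directly without re-hashing any source code. A stand-in sketch of that record shape and load step, not the real triton_bundler definitions:

# Stand-in sketch; these are illustrative classes, not the real
# triton_bundler definitions.
from dataclasses import dataclass
from typing import Any


@dataclass
class StaticallyLaunchedAutotunerEntry:
    cache_key: str      # precomputed CompiledTritonKernels key
    kernel_name: str
    kernel: Any         # stand-in for the unpickled CachingAutotuner


_compiled_triton_kernels: dict = {}


def load_static_autotuners(entries: list) -> list:
    kernel_names = []
    for result in entries:
        # Index the in-memory cache directly by the stored key; no hashing
        # of source code happens on the load path anymore.
        _compiled_triton_kernels[result.cache_key] = result.kernel
        kernel_names.append(result.kernel_name)
    return kernel_names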