[precompile] [easy] Refactor FxGraphCache to add cache_hit_post_compile function (#152839)

This PR refactors CompiledFxGraph by adding a new post_compile step that runs only on a cache hit. It moves a chunk of code out of _lookup_graph into its own function, cache_hit_post_compile, so that BundledAOTAutogradCacheEntry can reuse it. No change in behavior.
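For orientation, here is a minimal sketch of the resulting caller contract, pieced together from the diff below (the surrounding lookup logic is condensed and hypothetical, not the literal _lookup_graph body):

    # On a cache hit, run the cache-hit-only post-compile steps in one place.
    # A None graph in the return value means the cached artifact could not be
    # reloaded, and the caller should treat the lookup as a cache miss.
    graph, cache_info = FxGraphCache.cache_hit_post_compile(graph, cache_info, constants)
    if graph is None:
        return None, cache_info  # fall back to recompilation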

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152839
Approved by: https://github.com/oulgen
ghstack dependencies: #152836
commit f56bcd2408, parent a8f727c439
Authored by James Wu on 2025-05-05 10:16:54 -07:00; committed by PyTorch MergeBot

@@ -1086,6 +1086,82 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
        """
        return os.path.join(FxGraphCache._get_tmp_dir(), key[1:3], key)
    @staticmethod
    def cache_hit_post_compile(
        graph: CompiledFxGraph,
        cache_info: dict[str, Any],
        constants: CompiledFxGraphConstants,
    ) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
        """
        Cache-specific post-compile steps that need to run if we find a graph in the cache.
        These include putting bundled Triton artifacts in the right place,
        reloading the PyCodeCache artifact, etc.
        These steps don't always happen (they are skipped on a cache miss), so they
        live in a separate function from CompiledFxGraph.post_compile.
        """
        if bundle := graph._triton_bundle:
            triton_bundler_meta = TritonBundler.read_and_emit(bundle)
            if (meta := triton_bundler_meta) is not None:
                cache_info["triton_bundler_meta"] = str(meta)
                CompileEventLogger.try_add_pt2_compile(
                    "inductor_compile", cached_kernel_names=meta.cached_kernel_names
                )
                CompileEventLogger.try_add_pt2_compile(
                    "AOTAutogradCache.inductor_load",
                    cached_kernel_names=meta.cached_kernel_names,
                )
                if len(meta.cached_kernel_names) > 0:
                    CompileEventLogger.try_(
                        CompileEventLogger.increment_toplevel, "num_triton_bundles"
                    )
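        # Reload the compiled artifact from PyCodeCache (the "reloading the
        # PyCodeCache artifact" step from the docstring); if the entry has
        # vanished from disk, fall back to a cache miss below.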
        try:
            artifact_path = graph.after_deserialization(constants)
            from .graph import GraphLowering

            # This is used by tests to check the output for specific details.
            if GraphLowering.save_output_code is not None:
                GraphLowering.save_output_code(graph.source_code)
        except OSError:
            # Not expected, but in case the PyCodeCache entry is removed from
            # underneath us, treat it as a cache miss and recompile.
            return None, cache_info

        inductor_meta = autotune_cache.inductor_meta_from_config()
        code = graph.source_code
        AutotuneCacheBundler.begin_compile(inductor_meta, code=code)

        # Increment the cached metrics/counters by the amounts recorded when the FX
        # graph was compiled for this cache entry. Pretending these counters
        # were incremented normally is useful for testing with the cache enabled.
        metrics.CachedMetricsHelper.apply_deltas(graph.metrics_deltas)
        counters["inductor"] += graph.counter_deltas

        output_code_log.debug("Output code: \n%s", code)
        output_code_log.debug("Output code written to: %s", artifact_path)
        # On cache hit, use artifact path as filename
        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "fx_graph_runnable",
                "encoding": "string",
            },
            payload_fn=lambda: graph.runnable_graph_str,
        )
        trace_structured(
            "inductor_post_grad_graph",
            payload_fn=lambda: graph.inductor_post_grad_graph_str,
        )
        trace_structured(
            "inductor_output_code",
            lambda: {"filename": artifact_path},
            payload_fn=lambda: code,
        )
        return graph, cache_info
    @staticmethod
    def _lookup_graph(
        key: str,
@@ -1136,40 +1212,6 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
            CacheArtifactType.INDUCTOR, key, pickled_content
        )

        if bundle := graph._triton_bundle:
            triton_bundler_meta = TritonBundler.read_and_emit(bundle)
            if (meta := triton_bundler_meta) is not None:
                cache_info["triton_bundler_meta"] = str(meta)
                CompileEventLogger.try_add_pt2_compile(
                    "inductor_compile", cached_kernel_names=meta.cached_kernel_names
                )
                CompileEventLogger.try_add_pt2_compile(
                    "AOTAutogradCache.inductor_load",
                    cached_kernel_names=meta.cached_kernel_names,
                )
                if len(meta.cached_kernel_names) > 0:
                    CompileEventLogger.try_(
                        CompileEventLogger.increment_toplevel, "num_triton_bundles"
                    )
        try:
            artifact_path = graph.after_deserialization(constants)
            from .graph import GraphLowering

            # This is used by tests to check the output for specific details.
            if GraphLowering.save_output_code is not None:
                GraphLowering.save_output_code(graph.source_code)
        except OSError:
            # Not expected, but in case the PyCodeCache entry is removed from
            # underneath us, treat it as a cache miss and recompile.
            return None, cache_info

        inductor_meta = autotune_cache.inductor_meta_from_config()
        code = graph.source_code
        AutotuneCacheBundler.begin_compile(inductor_meta, code=code)

        # Now re-evaluate with the symints to add any guards to the current env.
        if graph.guards_expr:
            check = bool(evaluate_guards(graph.guards_expr, symints))
@@ -1178,33 +1220,7 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
                "fx graph cache key %s post-load guards: %s", key, shape_env.guards
            )
        # Increment the cached metrics/counters by the amounts recorded when the FX
        # graph was compiled for this cache entry. Pretending these counters
        # were incremented normally is useful for testing with the cache enabled.
        metrics.CachedMetricsHelper.apply_deltas(graph.metrics_deltas)
        counters["inductor"] += graph.counter_deltas

        output_code_log.debug("Output code: \n%s", code)
        output_code_log.debug("Output code written to: %s", artifact_path)
        # On cache hit, use artifact path as filename
        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "fx_graph_runnable",
                "encoding": "string",
            },
            payload_fn=lambda: graph.runnable_graph_str,
        )
        trace_structured(
            "inductor_post_grad_graph",
            payload_fn=lambda: graph.inductor_post_grad_graph_str,
        )
        trace_structured(
            "inductor_output_code",
            lambda: {"filename": artifact_path},
            payload_fn=lambda: code,
        )
        return graph, cache_info
        return FxGraphCache.cache_hit_post_compile(graph, cache_info, constants)
    @staticmethod
    def _write_to_local_cache(key: str, content: bytes) -> None: