[precompile] [easy] Refactor FxGraphCache to add cache_hit_post_compile function (#152839)
This PR refactors FxGraphCache by adding a new cache_hit_post_compile step that runs only on a cache hit. It moves the cache-hit-specific code out of _lookup_graph into its own function so that it can be reused in BundledAOTAutogradCacheEntry. There is no change in behavior.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152839
Approved by: https://github.com/oulgen
ghstack dependencies: #152836
This commit is contained in:
parent a8f727c439
commit f56bcd2408
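To illustrate the reuse the commit message describes, here is a minimal sketch, not code from this PR: the wrapper name load_cached_graph and its miss-handling are assumptions; only FxGraphCache.cache_hit_post_compile and its (Optional[CompiledFxGraph], dict) return convention come from the diff below.

    from typing import Any

    from torch._inductor.codecache import FxGraphCache

    def load_cached_graph(graph, cache_info: dict[str, Any], constants):
        # Hypothetical caller (e.g. a BundledAOTAutogradCacheEntry-style entry
        # point) that has already deserialized `graph` from the cache.
        graph, cache_info = FxGraphCache.cache_hit_post_compile(
            graph, cache_info, constants
        )
        if graph is None:
            # Per the diff, the helper returns (None, cache_info) when the
            # PyCodeCache artifact has vanished (OSError); the caller should
            # treat this as a cache miss and recompile.
            return None, cache_info
        return graph, cache_info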
@@ -1086,6 +1086,82 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
         """
         return os.path.join(FxGraphCache._get_tmp_dir(), key[1:3], key)
 
+    @staticmethod
+    def cache_hit_post_compile(
+        graph: CompiledFxGraph,
+        cache_info: dict[str, Any],
+        constants: CompiledFxGraphConstants,
+    ) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
+        """
+        Cache specific post compile steps that need to run if we find a graph in the cache.
+        This includes putting bundled triton artifacts in the right place,
+        reloading the PyCodeCache artifact, etc.
+
+        These don't always happen (i.e. on a cache miss), so they are in a
+        separate function from CompiledFxGraph.post_compile.
+        """
+        if bundle := graph._triton_bundle:
+            triton_bundler_meta = TritonBundler.read_and_emit(bundle)
+            if (meta := triton_bundler_meta) is not None:
+                cache_info["triton_bundler_meta"] = str(meta)
+                CompileEventLogger.try_add_pt2_compile(
+                    "inductor_compile", cached_kernel_names=meta.cached_kernel_names
+                )
+                CompileEventLogger.try_add_pt2_compile(
+                    "AOTAutogradCache.inductor_load",
+                    cached_kernel_names=meta.cached_kernel_names,
+                )
+                if len(meta.cached_kernel_names) > 0:
+                    CompileEventLogger.try_(
+                        CompileEventLogger.increment_toplevel, "num_triton_bundles"
+                    )
+
+        try:
+            artifact_path = graph.after_deserialization(constants)
+
+            from .graph import GraphLowering
+
+            # This is used by tests to check the output for specific details.
+            if GraphLowering.save_output_code is not None:
+                GraphLowering.save_output_code(graph.source_code)
+
+        except OSError:
+            # Not expected, but in case the PyCodeCache entry is removed from
+            # underneath us, treat it as a cache miss and recompile.
+            return None, cache_info
+
+        inductor_meta = autotune_cache.inductor_meta_from_config()
+        code = graph.source_code
+        AutotuneCacheBundler.begin_compile(inductor_meta, code=code)
+
+        # Increment the cached metrics/counters by the amounts recorded when the FX
+        # graph was compiled for this cache entry. Pretending these counters
+        # were incremented normally is useful for testing with the cache enabled.
+        metrics.CachedMetricsHelper.apply_deltas(graph.metrics_deltas)
+        counters["inductor"] += graph.counter_deltas
+
+        output_code_log.debug("Output code: \n%s", code)
+        output_code_log.debug("Output code written to: %s", artifact_path)
+        # On cache hit, use artifact path as filename
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "fx_graph_runnable",
+                "encoding": "string",
+            },
+            payload_fn=lambda: graph.runnable_graph_str,
+        )
+        trace_structured(
+            "inductor_post_grad_graph",
+            payload_fn=lambda: graph.inductor_post_grad_graph_str,
+        )
+        trace_structured(
+            "inductor_output_code",
+            lambda: {"filename": artifact_path},
+            payload_fn=lambda: code,
+        )
+        return graph, cache_info
+
     @staticmethod
     def _lookup_graph(
         key: str,
@@ -1136,40 +1212,6 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
                 CacheArtifactType.INDUCTOR, key, pickled_content
             )
 
-        if bundle := graph._triton_bundle:
-            triton_bundler_meta = TritonBundler.read_and_emit(bundle)
-            if (meta := triton_bundler_meta) is not None:
-                cache_info["triton_bundler_meta"] = str(meta)
-                CompileEventLogger.try_add_pt2_compile(
-                    "inductor_compile", cached_kernel_names=meta.cached_kernel_names
-                )
-                CompileEventLogger.try_add_pt2_compile(
-                    "AOTAutogradCache.inductor_load",
-                    cached_kernel_names=meta.cached_kernel_names,
-                )
-                if len(meta.cached_kernel_names) > 0:
-                    CompileEventLogger.try_(
-                        CompileEventLogger.increment_toplevel, "num_triton_bundles"
-                    )
-
-        try:
-            artifact_path = graph.after_deserialization(constants)
-
-            from .graph import GraphLowering
-
-            # This is used by tests to check the output for specific details.
-            if GraphLowering.save_output_code is not None:
-                GraphLowering.save_output_code(graph.source_code)
-
-        except OSError:
-            # Not expected, but in case the PyCodeCache entry is removed from
-            # underneath us, treat it as a cache miss and recompile.
-            return None, cache_info
-
-        inductor_meta = autotune_cache.inductor_meta_from_config()
-        code = graph.source_code
-        AutotuneCacheBundler.begin_compile(inductor_meta, code=code)
-
         # Now re-evaluate with the symints to add any guards to the current env.
         if graph.guards_expr:
            check = bool(evaluate_guards(graph.guards_expr, symints))
@@ -1178,33 +1220,7 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
                 "fx graph cache key %s post-load guards: %s", key, shape_env.guards
             )
 
-        # Increment the cached metrics/counters by the amounts recorded when the FX
-        # graph was compiled for this cache entry. Pretending these counters
-        # were incremented normally is useful for testing with the cache enabled.
-        metrics.CachedMetricsHelper.apply_deltas(graph.metrics_deltas)
-        counters["inductor"] += graph.counter_deltas
-
-        output_code_log.debug("Output code: \n%s", code)
-        output_code_log.debug("Output code written to: %s", artifact_path)
-        # On cache hit, use artifact path as filename
-        trace_structured(
-            "artifact",
-            metadata_fn=lambda: {
-                "name": "fx_graph_runnable",
-                "encoding": "string",
-            },
-            payload_fn=lambda: graph.runnable_graph_str,
-        )
-        trace_structured(
-            "inductor_post_grad_graph",
-            payload_fn=lambda: graph.inductor_post_grad_graph_str,
-        )
-        trace_structured(
-            "inductor_output_code",
-            lambda: {"filename": artifact_path},
-            payload_fn=lambda: code,
-        )
-        return graph, cache_info
+        return FxGraphCache.cache_hit_post_compile(graph, cache_info, constants)
 
     @staticmethod
     def _write_to_local_cache(key: str, content: bytes) -> None:
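Read together, the removal hunks show the division of labor after this change: guard re-evaluation stays in _lookup_graph, since it decides whether the cached entry applies to the current inputs at all, while everything that merely finalizes a hit (triton bundle emission, PyCodeCache reload, metric/counter deltas, structured trace logging) moves behind one call. A schematic paraphrase of the resulting _lookup_graph tail follows; the function name and the guard-failure return are assumptions for illustration, not the literal source.

    def _lookup_graph_tail(graph, cache_info, constants, symints, evaluate_guards):
        # Validity first: the entry is only usable if its guards hold for the
        # current symints (this check remains in _lookup_graph).
        if graph.guards_expr and not evaluate_guards(graph.guards_expr, symints):
            return None, cache_info  # entry exists but does not apply here
        # Hit finalization is now a single shared call.
        return FxGraphCache.cache_hit_post_compile(graph, cache_info, constants)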