[precompile] [easy] Refactor FxGraphCache to add cache_hit_post_compile function (#152839)
This PR refactors CompiledFxGraph by adding a new post_compile step that runs only on a cache hit. It moves the relevant code out of _lookup_graph into its own function so that it can also be used in BundledAOTAutogradCacheEntry. There is no change in behavior.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152839
Approved by: https://github.com/oulgen
ghstack dependencies: #152836
This commit is contained in:
parent a8f727c439
commit f56bcd2408
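As a rough usage sketch (assumptions flagged): the new helper is a static method that takes the deserialized graph, the cache_info dict, and the graph constants, and returns None for the graph when the cached artifact can no longer be reloaded. A hypothetical caller along the lines of the BundledAOTAutogradCacheEntry path mentioned above might look like this; load_cached_graph and recompile are illustrative names, not part of this PR:

    # Sketch only: FxGraphCache.cache_hit_post_compile comes from this diff;
    # load_cached_graph/recompile are hypothetical stand-ins for the caller.
    def load_cached_graph(graph, cache_info, constants):
        graph, cache_info = FxGraphCache.cache_hit_post_compile(
            graph, cache_info, constants
        )
        if graph is None:
            # The PyCodeCache entry disappeared underneath us:
            # treat it as a cache miss and recompile.
            return recompile()
        return graph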
@@ -1086,6 +1086,82 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
         """
         return os.path.join(FxGraphCache._get_tmp_dir(), key[1:3], key)
 
+    @staticmethod
+    def cache_hit_post_compile(
+        graph: CompiledFxGraph,
+        cache_info: dict[str, Any],
+        constants: CompiledFxGraphConstants,
+    ) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
+        """
+        Cache specific post compile steps that need to run if we find a graph in the cache.
+        This includes putting bundled triton artifacts in the right place,
+        reloading the PyCodeCache artifact, etc.
+
+        These don't always happen (i.e. on a cache miss), so they are in a separate
+        function from CompiledFxGraph.post_compile.
+        """
+        if bundle := graph._triton_bundle:
+            triton_bundler_meta = TritonBundler.read_and_emit(bundle)
+            if (meta := triton_bundler_meta) is not None:
+                cache_info["triton_bundler_meta"] = str(meta)
+                CompileEventLogger.try_add_pt2_compile(
+                    "inductor_compile", cached_kernel_names=meta.cached_kernel_names
+                )
+                CompileEventLogger.try_add_pt2_compile(
+                    "AOTAutogradCache.inductor_load",
+                    cached_kernel_names=meta.cached_kernel_names,
+                )
+                if len(meta.cached_kernel_names) > 0:
+                    CompileEventLogger.try_(
+                        CompileEventLogger.increment_toplevel, "num_triton_bundles"
+                    )
+
+        try:
+            artifact_path = graph.after_deserialization(constants)
+
+            from .graph import GraphLowering
+
+            # This is used by tests to check the output for specific details.
+            if GraphLowering.save_output_code is not None:
+                GraphLowering.save_output_code(graph.source_code)
+
+        except OSError:
+            # Not expected, but in case the PyCodeCache entry is removed from
+            # underneath us, treat it as a cache miss and recompile.
+            return None, cache_info
+
+        inductor_meta = autotune_cache.inductor_meta_from_config()
+        code = graph.source_code
+        AutotuneCacheBundler.begin_compile(inductor_meta, code=code)
+
+        # Increment the cached metrics/counters by the amounts recorded when the FX
+        # graph was compiled for this cache entry. Pretending these counters
+        # were incremented normally is useful for testing with the cache enabled.
+        metrics.CachedMetricsHelper.apply_deltas(graph.metrics_deltas)
+        counters["inductor"] += graph.counter_deltas
+
+        output_code_log.debug("Output code: \n%s", code)
+        output_code_log.debug("Output code written to: %s", artifact_path)
+        # On cache hit, use artifact path as filename
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "fx_graph_runnable",
+                "encoding": "string",
+            },
+            payload_fn=lambda: graph.runnable_graph_str,
+        )
+        trace_structured(
+            "inductor_post_grad_graph",
+            payload_fn=lambda: graph.inductor_post_grad_graph_str,
+        )
+        trace_structured(
+            "inductor_output_code",
+            lambda: {"filename": artifact_path},
+            payload_fn=lambda: code,
+        )
+        return graph, cache_info
+
     @staticmethod
     def _lookup_graph(
         key: str,
@@ -1136,40 +1212,6 @@
                 CacheArtifactType.INDUCTOR, key, pickled_content
             )
 
-        if bundle := graph._triton_bundle:
-            triton_bundler_meta = TritonBundler.read_and_emit(bundle)
-            if (meta := triton_bundler_meta) is not None:
-                cache_info["triton_bundler_meta"] = str(meta)
-                CompileEventLogger.try_add_pt2_compile(
-                    "inductor_compile", cached_kernel_names=meta.cached_kernel_names
-                )
-                CompileEventLogger.try_add_pt2_compile(
-                    "AOTAutogradCache.inductor_load",
-                    cached_kernel_names=meta.cached_kernel_names,
-                )
-                if len(meta.cached_kernel_names) > 0:
-                    CompileEventLogger.try_(
-                        CompileEventLogger.increment_toplevel, "num_triton_bundles"
-                    )
-
-        try:
-            artifact_path = graph.after_deserialization(constants)
-
-            from .graph import GraphLowering
-
-            # This is used by tests to check the output for specific details.
-            if GraphLowering.save_output_code is not None:
-                GraphLowering.save_output_code(graph.source_code)
-
-        except OSError:
-            # Not expected, but in case the PyCodeCache entry is removed from
-            # underneath us, treat it as a cache miss and recompile.
-            return None, cache_info
-
-        inductor_meta = autotune_cache.inductor_meta_from_config()
-        code = graph.source_code
-        AutotuneCacheBundler.begin_compile(inductor_meta, code=code)
-
         # Now re-evaluate with the symints to add any guards to the current env.
         if graph.guards_expr:
             check = bool(evaluate_guards(graph.guards_expr, symints))
@@ -1178,33 +1220,7 @@
                 "fx graph cache key %s post-load guards: %s", key, shape_env.guards
             )
 
-        # Increment the cached metrics/counters by the amounts recorded when the FX
-        # graph was compiled for this cache entry. Pretending these counters
-        # were incremented normally is useful for testing with the cache enabled.
-        metrics.CachedMetricsHelper.apply_deltas(graph.metrics_deltas)
-        counters["inductor"] += graph.counter_deltas
-
-        output_code_log.debug("Output code: \n%s", code)
-        output_code_log.debug("Output code written to: %s", artifact_path)
-        # On cache hit, use artifact path as filename
-        trace_structured(
-            "artifact",
-            metadata_fn=lambda: {
-                "name": "fx_graph_runnable",
-                "encoding": "string",
-            },
-            payload_fn=lambda: graph.runnable_graph_str,
-        )
-        trace_structured(
-            "inductor_post_grad_graph",
-            payload_fn=lambda: graph.inductor_post_grad_graph_str,
-        )
-        trace_structured(
-            "inductor_output_code",
-            lambda: {"filename": artifact_path},
-            payload_fn=lambda: code,
-        )
-        return graph, cache_info
+        return FxGraphCache.cache_hit_post_compile(graph, cache_info, constants)
 
     @staticmethod
     def _write_to_local_cache(key: str, content: bytes) -> None: