Store statically launchable CachingAutotuners inside CompiledFXGraph.triton_bundle (#149054)

This PR stores statically launchable CachingAutotuners in FXGraphCache's cache entry.

Regular CachingAutotuners, with triton kernels attached to them, are poor candidates for caching: they are very large and take up a huge amount of space, since they track all of the various binary files along with assorted metadata. We could probably figure out which information we could drop from the kernel and still have it work, but with StaticCudaLauncher we no longer have to. Instead, we can cache every compiled triton kernel that is statically launchable.
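
As a quick sketch of how this is exercised end to end (requires CUDA, since the static launcher is CUDA-only; the config flags are the same ones patched in the tests below, and the toy function is illustrative only):

```python
import torch
import torch._inductor.config as inductor_config

# Flags patched explicitly here for illustration; defaults may vary by build.
with inductor_config.patch(
    fx_graph_cache=True,                     # cache compiled FX graphs locally
    bundle_triton_into_fx_graph_cache=True,  # bundle triton artifacts into the cache entry
    use_static_cuda_launcher=True,           # allow statically launchable CachingAutotuners
):
    compiled_fn = torch.compile(lambda x: x * 2 + 1)  # illustrative function
    out = compiled_fn(torch.randn(8, device="cuda"))
```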

Because StaticTritonCompileResult is serializable and designed to have a very small memory footprint, we can save it into FXGraphCache without significantly increasing the size of the cache entry. We store it as part of CompiledFxGraph.triton_bundle.
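
Concretely, the cache entry ends up holding a bundle shaped like the dataclasses added to TritonBundler below; this is just a condensed, illustrative view of that layout:

```python
from dataclasses import dataclass

# Condensed view of the new bundle layout (see the TritonBundler diff below).
@dataclass(frozen=True)
class StaticallyLaunchedAutotuner:
    cache_key: str              # key into CompiledTritonKernels
    kernel_name: str
    kernel: "CachingAutotuner"  # pickles cheaply: only StaticTritonCompileResults inside


@dataclass(frozen=True)
class TritonBundle:
    kernel_artifacts: list["TritonKernelArtifacts"]       # existing cubin + metadata artifacts
    static_autotuners: list[StaticallyLaunchedAutotuner]  # new: statically launchable autotuners


# CompiledFxGraph._triton_bundle now stores a TritonBundle rather than a bare
# list of TritonKernelArtifacts.
```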

Then, on cache load, we repopulate each CachingAutotuner into our CompiledTritonKernels cache.
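
A condensed sketch of that load path (mirroring TritonBundler.load_autotuners and the cache-hit branch of AsyncCompile.triton in the diff below; timing and error handling omitted):

```python
# Sketch only; mirrors the new load path added in this PR.
from torch._inductor.async_compile import CompiledTritonKernels
from torch._inductor.codecache import StaticAutotunerFuture


def load_autotuners_sketch(static_autotuners):
    # Wrap each statically launchable autotuner in a future so kernels that
    # missed the cache can still be handed to compile workers without
    # blocking on StaticAutotunerFuture.result().
    for result in static_autotuners:
        CompiledTritonKernels._cache[result.cache_key] = StaticAutotunerFuture(
            result.kernel
        )


# Later, AsyncCompile.triton(kernel_name, source_code) finds the future via
# CompiledTritonKernels.get(source_code), attaches a reload-from-source
# callback, and returns it (or its .result() when not compiling in parallel)
# instead of spawning a compile worker.
```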

The upsides of this are many:
- We no longer need to call into a separate compile process on a cache hit
- We can *guarantee* that the triton kernel we got from the cache entry is the one we use to launch, so we don't have to worry about triton's own caching logic
- Once we achieve feature parity and all torch.compile'd triton kernels are statically launchable, we can clean up a bunch of TritonBundler code and simplify the cache-hit logic

Fixes #149449

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149054
Approved by: https://github.com/oulgen
ghstack dependencies: #149657
Commit ac91f8765b by James Wu, 2025-03-27 05:09:29 -07:00, committed by PyTorch MergeBot (parent 6eac3a0068).
7 changed files with 311 additions and 40 deletions

View File

@ -93,12 +93,16 @@ class TestFxGraphCache(TestCase):
@requires_triton()
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@config.patch({"compile_threads": 1})
@parametrize("device", (GPU_TYPE, "cpu"))
@parametrize("dtype", (torch.float32, torch.bfloat16))
@parametrize("dynamic", (False, True))
@parametrize("bundle_triton", (False, True))
@parametrize("use_static_cuda_launcher", (False, True))
@parametrize("grad", (False, True))
def test_cache_load_function(self, device, dtype, dynamic, bundle_triton, grad):
def test_cache_load_function(
self, device, dtype, dynamic, bundle_triton, use_static_cuda_launcher, grad
):
"""
Verify that we can populate and load functions from the cache.
"""
@ -106,6 +110,10 @@ class TestFxGraphCache(TestCase):
raise unittest.SkipTest(f"requires {GPU_TYPE}")
if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
raise unittest.SkipTest("requires SM80 or later")
if use_static_cuda_launcher and not (device == "cuda" and bundle_triton):
raise unittest.SkipTest(
"Static cuda launcher requires cuda and triton bundling"
)
grad_multiplier = 2 if grad else 1
@ -116,7 +124,10 @@ class TestFxGraphCache(TestCase):
a_orig = torch.rand(25, dtype=dtype, device=device)
b_orig = torch.rand(5, 5, dtype=dtype, device=device)
with config.patch(bundle_triton_into_fx_graph_cache=bundle_triton):
with config.patch(
bundle_triton_into_fx_graph_cache=bundle_triton,
use_static_cuda_launcher=use_static_cuda_launcher,
):
compiled_fn = torch.compile(fn, dynamic=dynamic)
a1 = a_orig.clone().requires_grad_(grad)
@ -149,6 +160,14 @@ class TestFxGraphCache(TestCase):
self.assertEqual(
counters["inductor"]["triton_bundler_read_and_emit_kernel"], 0
)
if use_static_cuda_launcher:
self.assertEqual(
counters["inductor"]["triton_bundler_save_static_autotuner"],
grad_multiplier if device == "cuda" else 0,
)
self.assertEqual(
counters["inductor"]["triton_bundler_load_static_autotuner"], 0
)
# A second call should hit. (First reset so in-memory guards
# don't prevent compilation).
@ -189,6 +208,15 @@ class TestFxGraphCache(TestCase):
counters["inductor"]["triton_bundler_read_and_emit_kernel"],
grad_multiplier * read_and_emit_kernel_count,
)
if use_static_cuda_launcher:
self.assertEqual(
counters["inductor"]["triton_bundler_save_static_autotuner"],
grad_multiplier if device == "cuda" else 0,
)
self.assertEqual(
counters["inductor"]["triton_bundler_load_static_autotuner"],
grad_multiplier if device == "cuda" else 0,
)
self.reset()
@ -228,6 +256,15 @@ class TestFxGraphCache(TestCase):
counters["inductor"]["triton_bundler_read_and_emit_kernel"],
grad_multiplier * read_and_emit_kernel_count,
)
if use_static_cuda_launcher:
self.assertEqual(
counters["inductor"]["triton_bundler_save_static_autotuner"],
grad_multiplier * 2 if device == "cuda" else 0,
)
self.assertEqual(
counters["inductor"]["triton_bundler_load_static_autotuner"],
grad_multiplier if device == "cuda" else 0,
)
@requires_triton()
@config.patch({"fx_graph_remote_cache": True})
@ -235,13 +272,23 @@ class TestFxGraphCache(TestCase):
@parametrize("dtype", (torch.float32, torch.bfloat16))
@parametrize("dynamic", (False, True))
@parametrize("bundle_triton", (False, True))
def test_remote_cache_load_function(self, device, dtype, dynamic, bundle_triton):
@parametrize("use_static_cuda_launcher", (False, True))
@config.patch(
{"compile_threads": 1}
) # Can't check globalStats if there are workers
def test_remote_cache_load_function(
self, device, dtype, dynamic, bundle_triton, use_static_cuda_launcher
):
from unittest.mock import patch
if device == GPU_TYPE and not HAS_GPU:
raise unittest.SkipTest(f"requires {GPU_TYPE}")
if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
raise unittest.SkipTest("requires SM80 or later")
if use_static_cuda_launcher and not (device == "cuda" and bundle_triton):
raise unittest.SkipTest(
"Static cuda launcher requires cuda and triton bundling"
)
def fn(x, y):
return (x * 2, y @ y)
@ -253,6 +300,7 @@ class TestFxGraphCache(TestCase):
{
"fx_graph_remote_cache": True,
"bundle_triton_into_fx_graph_cache": bundle_triton,
"use_static_cuda_launcher": use_static_cuda_launcher,
}
), patch.dict(os.environ), PatchCaches():
os.environ.pop("TRITON_CACHE_MANAGER", None)
@ -768,7 +816,9 @@ class TestFxGraphCache(TestCase):
return torch.cond(x.shape[0], true_fn, false_fn, (x,))
with config.patch(bundle_triton_into_fx_graph_cache=bundle_triton):
with config.patch(
bundle_triton_into_fx_graph_cache=bundle_triton,
):
compiled_fn = torch.compile(fn, dynamic=True, fullgraph=True)
x = torch.randn(4, 4, device=GPU_TYPE)
@ -933,8 +983,10 @@ class TestFxGraphCache(TestCase):
@requires_triton()
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@config.patch({"compile_threads": 1})
@parametrize("bundle_triton", (False, True))
def test_triton_op(self, bundle_triton):
@parametrize("use_static_cuda_launcher", (False, True))
def test_triton_op(self, bundle_triton, use_static_cuda_launcher):
libname = "my_cool_namespace"
opname = "my_triton_operator"
@ -952,7 +1004,12 @@ class TestFxGraphCache(TestCase):
def f(x, y):
return add(x, y)
with config.patch(bundle_triton_into_fx_graph_cache=bundle_triton):
compile_threads = 1 if use_static_cuda_launcher else config.compile_threads
with config.patch(
bundle_triton_into_fx_graph_cache=bundle_triton,
use_static_cuda_launcher=use_static_cuda_launcher,
compile_threads=compile_threads,
):
compiled_fn = torch.compile(f, fullgraph=True)
x = torch.randn(4, device=GPU_TYPE)

View File

@ -32,6 +32,7 @@ from torch._inductor.codecache import (
HalideCodeCache,
LambdaFuture,
ROCmCodeCache,
StaticAutotunerFuture,
torch_key,
)
from torch._inductor.compile_worker.subproc_pool import AnyPool, SubprocPool
@ -148,7 +149,7 @@ class CompiledTritonKernels:
Currently, the cache stores Future objects, but it should be generalizable for any kernels.
"""
_cache: dict[str, LambdaFuture] = {}
_cache: dict[str, CodeCacheFuture] = {}
@staticmethod
def key(kernel_src: str):
@ -161,7 +162,7 @@ class CompiledTritonKernels:
return code_hash(kernel_src, extra=torch_key())
@staticmethod
def save(kernel_src: str, future: LambdaFuture):
def save(kernel_src: str, future: CodeCacheFuture):
"""
Saves a compiled triton kernel to the cache.
TODO: We store a LambdaFuture as that's the callable returned by async_compile.triton,
@ -174,9 +175,9 @@ class CompiledTritonKernels:
CompiledTritonKernels._cache[key] = future
@staticmethod
def get(kernel_src: str, default: Any) -> LambdaFuture:
def get(kernel_src: str) -> Optional[CodeCacheFuture]:
key = CompiledTritonKernels.key(kernel_src)
return CompiledTritonKernels._cache.get(key, default)
return CompiledTritonKernels._cache.get(key, None)
@staticmethod
def cache_clear():
@ -185,6 +186,8 @@ class CompiledTritonKernels:
@staticmethod
def remove_future(kernel_src: str) -> None:
key = CompiledTritonKernels.key(kernel_src)
# Delete the LambdaFuture if there is one
if key in CompiledTritonKernels._cache:
del CompiledTritonKernels._cache[key]
@ -282,9 +285,14 @@ class AsyncCompile:
- The AutotuneCache, if enabled, is constructed on each worker per triton config
and pickled back to us via `CachingAutotuner.save_cache_hook`.
"""
if future := CompiledTritonKernels.get(source_code, None):
counters["inductor"]["async_compile_cache_hit"] += 1
return future
load_kernel = functools.partial(
_load_triton_kernel_from_source, kernel_name, source_code
)
def reload_kernel_in_parent():
# Benchmark how often this happens
with dynamo_timed("reload_kernel_in_parent"):
return load_kernel()
counters["inductor"]["async_compile_cache_miss"] += 1
@ -296,15 +304,22 @@ class AsyncCompile:
torch._inductor.codecache.PyCodeCache.load(source_code), kernel_name
)
load_kernel = functools.partial(
_load_triton_kernel_from_source, kernel_name, source_code
)
is_parallel = self.use_process_pool()
set_feature_use("parallel_compile_post_warmup", is_parallel)
compile_id = torch._guards.CompileContext.current_compile_id()
is_backward = getattr(V.graph, "is_backward", False)
if (future := CompiledTritonKernels.get(source_code)) is not None:
counters["inductor"]["async_compile_cache_hit"] += 1
# Set reload_kernel_from_src properly based on source_code
if isinstance(future, StaticAutotunerFuture):
future.reload_kernel_from_src = reload_kernel_in_parent
if is_parallel:
return future
else:
return future.result()
if is_parallel:
# We want to support changing these env vars after (and while) the
# process pool is running, so pass them to the subprocess to reset.
@ -317,19 +332,16 @@ class AsyncCompile:
extra_env,
)
def reload_kernel_in_parent():
# Benchmark how often this happens
with dynamo_timed("reload_kernel_in_parent"):
return load_kernel()
def get_result() -> tuple[CachingAutotuner, int]:
def get_result() -> CachingAutotuner:
kernel, elapsed_us = task.result()
# Now that we've compiled, we should clear the future
# so it can't be used again
CompiledTritonKernels.remove_future(source_code)
kernel.set_compile_info(compile_id, is_backward)
CompiledTritonKernels.remove_future(source_code)
kernel.precompile(
warm_cache_only=False, reload_kernel=reload_kernel_in_parent
warm_cache_only=False,
reload_kernel=reload_kernel_in_parent,
static_triton_bundle_key=CompiledTritonKernels.key(source_code),
)
get_metrics_context().add_top_n(
"triton_kernel_compile_times_us", kernel_name, elapsed_us
@ -350,7 +362,10 @@ class AsyncCompile:
_set_triton_ptxas_path()
kernel = load_kernel()
kernel.set_compile_info(compile_id, is_backward)
kernel.precompile(warm_cache_only=False)
kernel.precompile(
warm_cache_only=False,
static_triton_bundle_key=CompiledTritonKernels.key(source_code),
)
elapsed_us = (time_ns() - start_ns) // 1000
get_metrics_context().add_top_n(
"triton_kernel_compile_times_us", kernel_name, elapsed_us
@ -444,7 +459,6 @@ class AsyncCompile:
disable=config.disable_progress,
delay=0,
)
for key, result in kernels.items():
if config.verbose_progress and not isinstance(pbar, _Faketqdm):
pbar.set_postfix_str(key)

View File

@ -3351,3 +3351,28 @@ class LambdaFuture(CodeCacheFuture):
def result(self) -> Callable[..., Any]: # type: ignore[override]
return self.result_fn()
class StaticAutotunerFuture(CodeCacheFuture):
"""
A statically launchable CachingAutotuner, loaded from TritonBundler
"""
def __init__(self, static_autotuner: CachingAutotuner) -> None:
# Pickled version of CachingAutotuner
self.static_autotuner = static_autotuner
# This needs to be set in AsyncCompile.triton, in case
# we need to reload the CachingAutotuner from its source code
# We don't store the source code on the CachingAutotuner itself
# since it can be very large.
self.reload_kernel_from_src: Optional[Callable[[], Any]] = None
def result(self) -> CachingAutotuner:
assert self.reload_kernel_from_src is not None
with dynamo_timed("StaticAutotunerFuture.warm_precompile"):
self.static_autotuner.precompile( # type: ignore[union-attr]
warm_cache_only=False,
reload_kernel=self.reload_kernel_from_src,
static_triton_bundle_key=None, # no need to save again
)
return self.static_autotuner

View File

@ -65,7 +65,7 @@ if TYPE_CHECKING:
from torch._library.fake_class_registry import FakeScriptObject
from .compile_fx import _CompileFxKwargs
from .triton_bundler import TritonKernelArtifacts
from .triton_bundler import TritonBundle
log = logging.getLogger(__name__)
@ -420,7 +420,7 @@ class CompiledFxGraph(OutputCode):
inputs_to_check: Sequence[int]
_boxed_call: Optional[bool] = None
_triton_bundle: Optional[list[TritonKernelArtifacts]] = None
_triton_bundle: Optional[TritonBundle] = None
def __init__(
self,

View File

@ -85,12 +85,17 @@ class StaticallyLaunchedCudaKernel:
def load_kernel(self, device: int) -> None:
from torch._C import _StaticCudaLauncher
assert hasattr(self, "cubin_path")
if self.function is not None:
return
assert hasattr(self, "cubin_path")
assert self.cubin_path is not None
(self.function, self.n_regs, self.n_spills) = _StaticCudaLauncher._load_kernel(
self.cubin_path, self.name, self.shared, device
)
# Don't need the cubin path anymore now that we've loaded
self.cubin_path = None
@staticmethod
@functools.lru_cache
@ -161,6 +166,15 @@ class StaticallyLaunchedCudaKernel:
params.append(self.extract_type(ty))
return "".join(params)
def __getstate__(self) -> dict[str, Any]:
# Remove objects that are no longer valid for pickling
state = self.__dict__.copy()
state["function"] = None
# Cubin paths aren't consistent across processes, so we clear
# and reload them.
state["cubin_path"] = None
return state
def run(
self,
grid_x: int,
@ -190,6 +204,7 @@ class StaticallyLaunchedCudaKernel:
# TODO: can handle grid functions here or in C++, so
# that we don't need the grid handler above.
_StaticCudaLauncher._launch_kernel(
self.function,
grid_x,

View File

@ -275,6 +275,17 @@ class CachingAutotuner(KernelInterface):
self.compile_id: Optional[CompileId] = None
self.is_backward = False
def is_statically_launchable(self):
"""
Checks if every compiled kernel is statically launchable, which
allows us to efficiently cache it in FXGraphCache
"""
if not self.compile_results:
return False
return all(
isinstance(x, StaticTritonCompileResult) for x in self.compile_results
)
def set_compile_info(
self, compile_id: Optional[CompileId], is_backward: bool
) -> None:
@ -285,6 +296,7 @@ class CachingAutotuner(KernelInterface):
self,
warm_cache_only=False,
reload_kernel: Optional[Callable[[], CachingAutotuner]] = None,
static_triton_bundle_key: Optional[str] = None,
):
if warm_cache_only:
self._precompile_worker()
@ -297,6 +309,8 @@ class CachingAutotuner(KernelInterface):
if reload_kernel is not None:
self._reload_kernel = reload_kernel
self._precompile_worker()
if static_triton_bundle_key is not None and self.is_statically_launchable():
TritonBundler.put_static_autotuner(static_triton_bundle_key, self)
self._make_launchers()
self._dynamic_scale_rblock()
@ -462,15 +476,24 @@ class CachingAutotuner(KernelInterface):
raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}")
self.launchers = launchers
def prepare_for_pickle(self):
def prepare_for_pickle(self) -> tuple[Any, Any, Any, Any, Any]:
"""Drop stuff from triton.JITFunction that does not pickle.
This must be called after precompile so that these things are no longer needed.
Returns a tuple of old values
"""
old_values = (
self.fn.fn,
self.fn.__globals__,
self.fn.used_global_vals,
self.fn.repr,
self.launchers,
)
self.fn.fn = None
self.fn.__globals__ = None
self.fn.used_global_vals = None
self.fn.repr = _ConstRepr(self.fn.repr(self.fn))
self.launchers = []
return old_values
def __getstate__(self) -> dict[str, Any]:
assert not self.launchers, (
@ -1056,7 +1079,8 @@ class CompileResult(Generic[_T]):
f" grid_2 = {grid.z_grid}",
f" runner({', '.join(runner_args)})",
]
exec("\n".join(lines), scope)
launcher_code = "\n".join(lines)
exec(launcher_code, scope)
return scope["launcher"]
def _get_arg_lists(
@ -1198,8 +1222,26 @@ class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
raise e
return None
def reload_cubin_path(self):
"""
When loading from cache on disk, we want to reload cubin
files from their appropriate location on disk.
"""
cubin_location = os.path.join(
triton_cache_dir(self.compile_meta.get("device", 0)),
triton_hash_to_path_key(self.kernel.hash),
f"{self.kernel.name}.cubin",
)
if not os.path.exists(cubin_location):
raise RuntimeError(
"Cubin file saved by TritonBundler not found at %s", cubin_location
)
self.kernel.cubin_path = cubin_location
def make_launcher(self) -> LauncherType:
# Load the binary on the parent
if not self.kernel.cubin_path:
self.reload_cubin_path()
self.kernel.load_kernel(self.compile_meta.get("device", 0))
scope = {
"runner": self.kernel.run,

View File

@ -1,3 +1,4 @@
import copy
import dataclasses
import logging
import os
@ -42,6 +43,21 @@ class TritonKernelArtifact:
payload: bytes = dataclasses.field(repr=False) # Do not display binary
@dataclasses.dataclass(frozen=True)
class StaticallyLaunchedAutotuner:
"""
Represents a statically compiled CachingAutotuner object that we can
save directly in the cache. A CachingAutotuner is made up of a list of
StaticTritonCompileResults, each of which uses the cubin from a TritonKernelArtifact.
Autotuners saved here have their cubin files saved by a corresponding TritonBundleEntry.
"""
cache_key: str
kernel_name: str
kernel: "CachingAutotuner" # type: ignore[name-defined] # noqa: F821
@dataclasses.dataclass(frozen=True)
class TritonKernelArtifacts:
"""
@ -60,6 +76,17 @@ class TritonBundlerMetadata:
"""
cached_kernel_names: list[str]
statically_launched_kernel_names: list[str]
@dataclasses.dataclass(frozen=True)
class TritonBundle:
"""
Serializable bundle to save into FXGraphCache
"""
kernel_artifacts: list[TritonKernelArtifacts]
static_autotuners: list[StaticallyLaunchedAutotuner]
class TritonBundler:
@ -79,6 +106,7 @@ class TritonBundler:
"""
_entries: Optional[list[TritonBundleEntry]] = None
_static_autotuners: Optional[list[StaticallyLaunchedAutotuner]] = None
# __grp__kernel_name.json contains metadata with source code paths
# we use this as a sentinel value for search and replace
@ -112,6 +140,7 @@ class TritonBundler:
log.debug("TritonBundler.begin_compile is called")
assert cls._entries is None
cls._entries = []
cls._static_autotuners = []
@classmethod
def end_compile(cls) -> None:
@ -121,6 +150,7 @@ class TritonBundler:
"""
log.debug("TritonBundler.end_compile is called")
cls._entries = None
cls._static_autotuners = None
@classmethod
def put(cls, kernel_hash: str, device: int) -> None:
@ -133,20 +163,93 @@ class TritonBundler:
TritonBundleEntry(kernel_hash, device, triton_cache_dir(device))
)
@classmethod
def put_static_autotuner(cls, key: str, kernel: "CachingAutotuner") -> None: # type: ignore[name-defined] # noqa: F821
from torch._inductor import config
assert config.use_static_cuda_launcher
if (entries := cls._static_autotuners) is not None:
# Clear a bunch of unpicklable values and make a copy to save
# for FXGraphCache
old_values = kernel.prepare_for_pickle()
new_kernel = copy.deepcopy(kernel)
new_kernel._reload_kernel = None
entries.append(
StaticallyLaunchedAutotuner(
key,
new_kernel.inductor_meta.get("kernel_name", "unknown_kernel"),
new_kernel,
)
)
# Put the values back since we still need them in this process
(
kernel.fn.fn,
kernel.fn.__globals__,
kernel.fn.used_global_vals,
kernel.fn.repr,
kernel.launchers,
) = old_values
@classmethod
def collect_static_autotuners(
cls,
) -> tuple[list[StaticallyLaunchedAutotuner], list[str]]:
if not cls._static_autotuners:
return [], []
else:
log.info(
"Saving %d statically launchable CachingAutotuners",
len(cls._static_autotuners),
)
static_autotuner_names = [i.kernel_name for i in cls._static_autotuners]
counters["inductor"]["triton_bundler_save_static_autotuner"] += 1
return cls._static_autotuners, static_autotuner_names
@classmethod
def load_autotuners(
cls, static_autotuners: Optional[list[StaticallyLaunchedAutotuner]]
) -> list[str]:
"""
Load statically launchable CachingAutotuners into async_compile.CompiledTritonKernels
cache.
"""
if not static_autotuners:
return []
from torch._inductor.async_compile import CompiledTritonKernels
from torch._inductor.codecache import StaticAutotunerFuture
log.info("Loading %d statically launchable autotuners", len(static_autotuners))
kernel_names = []
with dynamo_timed("TritonBundler.load_cached_static_autotuners"):
for result in static_autotuners:
# We make a future instead of returning the kernel here so that
# kernels that are not statically launchable (i.e. cache miss)
# can launch a worker without waiting on the blocking step of
# StaticAutotunerFuture.result().
CompiledTritonKernels._cache[result.cache_key] = StaticAutotunerFuture(
result.kernel
)
counters["inductor"]["triton_bundler_load_static_autotuner"] += 1
kernel_names.append(result.kernel_name)
return kernel_names
@classmethod
def collect(
cls,
) -> tuple[list[TritonKernelArtifacts], Optional[TritonBundlerMetadata]]:
) -> tuple[TritonBundle, Optional[TritonBundlerMetadata]]:
"""
This is the main function called when a cache write happens. This function
converts all the previously remembered kernels into bundled format so that
it can be written into a cache entry.
This function also finalizes the current bundle.
"""
from torch._inductor import config
if not TritonBundler.is_enabled():
cls.end_compile()
set_feature_use("triton_bundling", False)
return [], None
return TritonBundle([], []), None
set_feature_use("triton_bundling", True)
with dynamo_timed(key="TritonBundler.collect", log_pt2_compile_event=True):
@ -199,14 +302,21 @@ class TritonBundler:
artifacts,
)
)
if config.use_static_cuda_launcher:
static_autotuners, static_kernel_names = (
cls.collect_static_autotuners()
)
else:
static_autotuners = []
static_kernel_names = []
cls.end_compile()
return result, TritonBundlerMetadata(kernel_names)
return [], None
return TritonBundle(result, static_autotuners), TritonBundlerMetadata(
kernel_names, static_kernel_names
)
return TritonBundle([], []), None
@staticmethod
def read_and_emit(
bundle: list[TritonKernelArtifacts],
) -> Optional[TritonBundlerMetadata]:
def read_and_emit(bundle: TritonBundle) -> Optional[TritonBundlerMetadata]:
"""
This is the main function called when a cache read happens. This function
converts the bundled format back into individual files and writes them
@ -219,6 +329,8 @@ class TritonBundler:
Exclusive access means that no other process should be writing to
or reading from the target directory.
"""
from torch._inductor import config
if not TritonBundler.is_enabled():
return None
@ -227,7 +339,7 @@ class TritonBundler:
):
kernel_names: list[str] = []
for artifacts in bundle:
for artifacts in bundle.kernel_artifacts:
basedir = triton_cache_dir(artifacts.device)
directory = os.path.join(basedir, artifacts.kernel_hash)
@ -272,4 +384,10 @@ class TritonBundler:
# Atomic on POSIX systems
os.replace(tmp_dir, directory)
return TritonBundlerMetadata(kernel_names)
if config.use_static_cuda_launcher:
static_kernel_names = TritonBundler.load_autotuners(
bundle.static_autotuners
)
else:
static_kernel_names = []
return TritonBundlerMetadata(kernel_names, static_kernel_names)