Store statically launchable CachingAutotuners inside CompiledFXGraph.triton_bundle (#149054)
This PR adds statically launchable CachingAutotuners to FXGraphCache's cache entry. Regular CachingAutotuners, with triton kernels attached to them, are poor candidates for caching: they are very large, since they track all of the compiled binary files along with various metadata. We could probably work out which information can be dropped from the kernel while keeping it functional, but with StaticCudaLauncher we no longer have to. Instead, we can cache every compiled triton kernel that is statically launchable.

Because StaticTritonCompileResult is serializable and designed to have a very small memory footprint, we can save it into FXGraphCache without significantly increasing the cache size. We store it as part of CompiledFxGraph.triton_bundle. Then, on load, we repopulate the CachingAutotuner into our CompiledTritonKernels cache.

The upsides of this are many:
- We no longer need to call into a separate process on a cache hit.
- We can *guarantee* that the triton kernel we got from our cache entry is the one we use to launch again, so there are no worries about triton's own caching logic.
- Once we achieve feature parity and all torch.compiled triton kernels are statically launchable, we can clean up a bunch of TritonBundler code and simplify the cache hit logic.

Fixes #149449

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149054
Approved by: https://github.com/oulgen
ghstack dependencies: #149657
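To make the mechanism above concrete: on a cache write, any CachingAutotuner whose compiled kernels are all statically launchable gets a small, pickle-friendly copy placed in the bundle that rides along with the FXGraphCache entry; on a cache hit, those entries are deserialized and dropped straight back into the in-process kernel cache, so no compile worker is needed. The snippet below is a minimal, self-contained sketch of that round trip using only the standard library; the names (FakeAutotuner, FakeTritonBundle, source_key, KERNEL_CACHE) are illustrative stand-ins, not the actual torch._inductor APIs.

import hashlib
import pickle
from dataclasses import dataclass, field


@dataclass
class FakeAutotuner:
    """Stand-in for a statically launchable CachingAutotuner (illustrative only)."""
    kernel_name: str
    cubin_bytes: bytes = field(repr=False)

    def launch(self) -> str:
        # A real autotuner would launch the cubin via the static CUDA launcher.
        return f"launched {self.kernel_name} ({len(self.cubin_bytes)} bytes)"


@dataclass
class FakeTritonBundle:
    """Stand-in for the serializable bundle stored on CompiledFxGraph.triton_bundle."""
    static_autotuners: dict[str, bytes] = field(default_factory=dict)  # key -> pickled autotuner


def source_key(kernel_src: str) -> str:
    # Mirrors the idea of hashing the kernel source to form a cache key.
    return hashlib.sha256(kernel_src.encode()).hexdigest()


# In-process cache, analogous to the CompiledTritonKernels cache.
KERNEL_CACHE: dict[str, FakeAutotuner] = {}


def save_to_bundle(bundle: FakeTritonBundle, kernel_src: str, autotuner: FakeAutotuner) -> None:
    # Cache write path: pickle the (small) autotuner into the bundle.
    bundle.static_autotuners[source_key(kernel_src)] = pickle.dumps(autotuner)


def load_bundle(bundle: FakeTritonBundle) -> None:
    # Cache hit path: repopulate the in-process kernel cache without recompiling.
    for key, payload in bundle.static_autotuners.items():
        KERNEL_CACHE[key] = pickle.loads(payload)


if __name__ == "__main__":
    src = "def triton_poi_fused_add_0(...): ..."
    bundle = FakeTritonBundle()
    save_to_bundle(bundle, src, FakeAutotuner("triton_poi_fused_add_0", b"\x00" * 16))
    KERNEL_CACHE.clear()  # simulate a fresh process
    load_bundle(bundle)   # roughly what the cache-read path does on a hit
    print(KERNEL_CACHE[source_key(src)].launch())

The real implementation, shown in the diff below, stores futures rather than kernels directly and clears process-specific state (cubin paths, loaded functions) before pickling.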
This commit is contained in:
parent 6eac3a0068
commit ac91f8765b
@@ -93,12 +93,16 @@ class TestFxGraphCache(TestCase):
@requires_triton()
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@config.patch({"compile_threads": 1})
@parametrize("device", (GPU_TYPE, "cpu"))
@parametrize("dtype", (torch.float32, torch.bfloat16))
@parametrize("dynamic", (False, True))
@parametrize("bundle_triton", (False, True))
@parametrize("use_static_cuda_launcher", (False, True))
@parametrize("grad", (False, True))
def test_cache_load_function(self, device, dtype, dynamic, bundle_triton, grad):
def test_cache_load_function(
self, device, dtype, dynamic, bundle_triton, use_static_cuda_launcher, grad
):
"""
Verify that we can populate and load functions from the cache.
"""

@@ -106,6 +110,10 @@ class TestFxGraphCache(TestCase):
raise unittest.SkipTest(f"requires {GPU_TYPE}")
if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
raise unittest.SkipTest("requires SM80 or later")
if use_static_cuda_launcher and not (device == "cuda" and bundle_triton):
raise unittest.SkipTest(
"Static cuda launcher requires cuda and triton bundling"
)
grad_multiplier = 2 if grad else 1

@@ -116,7 +124,10 @@ class TestFxGraphCache(TestCase):
a_orig = torch.rand(25, dtype=dtype, device=device)
b_orig = torch.rand(5, 5, dtype=dtype, device=device)
with config.patch(bundle_triton_into_fx_graph_cache=bundle_triton):
with config.patch(
bundle_triton_into_fx_graph_cache=bundle_triton,
use_static_cuda_launcher=use_static_cuda_launcher,
):
compiled_fn = torch.compile(fn, dynamic=dynamic)
a1 = a_orig.clone().requires_grad_(grad)

@@ -149,6 +160,14 @@ class TestFxGraphCache(TestCase):
self.assertEqual(
counters["inductor"]["triton_bundler_read_and_emit_kernel"], 0
)
if use_static_cuda_launcher:
self.assertEqual(
counters["inductor"]["triton_bundler_save_static_autotuner"],
grad_multiplier if device == "cuda" else 0,
)
self.assertEqual(
counters["inductor"]["triton_bundler_load_static_autotuner"], 0
)
# A second call should hit. (First reset so in-memory guards
# don't prevent compilation).

@@ -189,6 +208,15 @@ class TestFxGraphCache(TestCase):
counters["inductor"]["triton_bundler_read_and_emit_kernel"],
grad_multiplier * read_and_emit_kernel_count,
)
if use_static_cuda_launcher:
self.assertEqual(
counters["inductor"]["triton_bundler_save_static_autotuner"],
grad_multiplier if device == "cuda" else 0,
)
self.assertEqual(
counters["inductor"]["triton_bundler_load_static_autotuner"],
grad_multiplier if device == "cuda" else 0,
)
self.reset()

@@ -228,6 +256,15 @@ class TestFxGraphCache(TestCase):
counters["inductor"]["triton_bundler_read_and_emit_kernel"],
grad_multiplier * read_and_emit_kernel_count,
)
if use_static_cuda_launcher:
self.assertEqual(
counters["inductor"]["triton_bundler_save_static_autotuner"],
grad_multiplier * 2 if device == "cuda" else 0,
)
self.assertEqual(
counters["inductor"]["triton_bundler_load_static_autotuner"],
grad_multiplier if device == "cuda" else 0,
)
@requires_triton()
@config.patch({"fx_graph_remote_cache": True})

@@ -235,13 +272,23 @@ class TestFxGraphCache(TestCase):
@parametrize("dtype", (torch.float32, torch.bfloat16))
@parametrize("dynamic", (False, True))
@parametrize("bundle_triton", (False, True))
def test_remote_cache_load_function(self, device, dtype, dynamic, bundle_triton):
@parametrize("use_static_cuda_launcher", (False, True))
@config.patch(
{"compile_threads": 1}
) # Can't check globalStats if there are workers
def test_remote_cache_load_function(
self, device, dtype, dynamic, bundle_triton, use_static_cuda_launcher
):
from unittest.mock import patch
if device == GPU_TYPE and not HAS_GPU:
raise unittest.SkipTest(f"requires {GPU_TYPE}")
if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
raise unittest.SkipTest("requires SM80 or later")
if use_static_cuda_launcher and not (device == "cuda" and bundle_triton):
raise unittest.SkipTest(
"Static cuda launcher requires cuda and triton bundling"
)
def fn(x, y):
return (x * 2, y @ y)

@@ -253,6 +300,7 @@ class TestFxGraphCache(TestCase):
{
"fx_graph_remote_cache": True,
"bundle_triton_into_fx_graph_cache": bundle_triton,
"use_static_cuda_launcher": use_static_cuda_launcher,
}
), patch.dict(os.environ), PatchCaches():
os.environ.pop("TRITON_CACHE_MANAGER", None)

@@ -768,7 +816,9 @@ class TestFxGraphCache(TestCase):
return torch.cond(x.shape[0], true_fn, false_fn, (x,))
with config.patch(bundle_triton_into_fx_graph_cache=bundle_triton):
with config.patch(
bundle_triton_into_fx_graph_cache=bundle_triton,
):
compiled_fn = torch.compile(fn, dynamic=True, fullgraph=True)
x = torch.randn(4, 4, device=GPU_TYPE)

@@ -933,8 +983,10 @@ class TestFxGraphCache(TestCase):
@requires_triton()
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@config.patch({"compile_threads": 1})
@parametrize("bundle_triton", (False, True))
def test_triton_op(self, bundle_triton):
@parametrize("use_static_cuda_launcher", (False, True))
def test_triton_op(self, bundle_triton, use_static_cuda_launcher):
libname = "my_cool_namespace"
opname = "my_triton_operator"

@@ -952,7 +1004,12 @@ class TestFxGraphCache(TestCase):
def f(x, y):
return add(x, y)
with config.patch(bundle_triton_into_fx_graph_cache=bundle_triton):
compile_threads = 1 if use_static_cuda_launcher else config.compile_threads
with config.patch(
bundle_triton_into_fx_graph_cache=bundle_triton,
use_static_cuda_launcher=use_static_cuda_launcher,
compile_threads=compile_threads,
):
compiled_fn = torch.compile(f, fullgraph=True)
x = torch.randn(4, device=GPU_TYPE)
@@ -32,6 +32,7 @@ from torch._inductor.codecache import (
HalideCodeCache,
LambdaFuture,
ROCmCodeCache,
StaticAutotunerFuture,
torch_key,
)
from torch._inductor.compile_worker.subproc_pool import AnyPool, SubprocPool

@@ -148,7 +149,7 @@ class CompiledTritonKernels:
Currently, the cache stores Future objects, but it should be generalizable for any kernels.
"""
_cache: dict[str, LambdaFuture] = {}
_cache: dict[str, CodeCacheFuture] = {}
@staticmethod
def key(kernel_src: str):

@@ -161,7 +162,7 @@ class CompiledTritonKernels:
return code_hash(kernel_src, extra=torch_key())
@staticmethod
def save(kernel_src: str, future: LambdaFuture):
def save(kernel_src: str, future: CodeCacheFuture):
"""
Saves a compiled triton kernel to the cache.
TODO: We store a LambdaFuture as that's the callable returned by async_compile.triton,

@@ -174,9 +175,9 @@ class CompiledTritonKernels:
CompiledTritonKernels._cache[key] = future
@staticmethod
def get(kernel_src: str, default: Any) -> LambdaFuture:
def get(kernel_src: str) -> Optional[CodeCacheFuture]:
key = CompiledTritonKernels.key(kernel_src)
return CompiledTritonKernels._cache.get(key, default)
return CompiledTritonKernels._cache.get(key, None)
@staticmethod
def cache_clear():

@@ -185,6 +186,8 @@ class CompiledTritonKernels:
@staticmethod
def remove_future(kernel_src: str) -> None:
key = CompiledTritonKernels.key(kernel_src)
# Delete the LambdaFuture if there is one
if key in CompiledTritonKernels._cache:
del CompiledTritonKernels._cache[key]
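The get() change above (dropping the caller-supplied default in favor of an Optional return) reflects how callers now branch explicitly on a hit. A rough stdlib analogy of the key/save/get surface of an in-memory, source-hash-keyed future cache follows; the names and layout are made up for illustration and are not the real module:

import hashlib
from concurrent.futures import Future
from typing import Optional

# Toy analogue of an in-memory map from a hash of kernel source to a future
# that resolves to the compiled kernel.
_cache: dict[str, Future] = {}

def key(kernel_src: str) -> str:
    return hashlib.sha256(kernel_src.encode()).hexdigest()

def save(kernel_src: str, future: Future) -> None:
    _cache[key(kernel_src)] = future

def get(kernel_src: str) -> Optional[Future]:
    # Optional return instead of a caller-supplied default, mirroring the diff above.
    return _cache.get(key(kernel_src), None)

if __name__ == "__main__":
    src = "def kernel(): ..."
    done: Future = Future()
    done.set_result("compiled kernel")
    save(src, done)
    hit = get(src)
    print(hit.result() if hit is not None else "cache miss")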
@@ -282,9 +285,14 @@ class AsyncCompile:
- The AutotuneCache, if enabled, is constructed on each worker per triton config
and pickled by to us via `CachingAutotuner.save_cache_hook`.
"""
if future := CompiledTritonKernels.get(source_code, None):
counters["inductor"]["async_compile_cache_hit"] += 1
return future
load_kernel = functools.partial(
_load_triton_kernel_from_source, kernel_name, source_code
)
def reload_kernel_in_parent():
# Benchmark how often this happens
with dynamo_timed("reload_kernel_in_parent"):
return load_kernel()
counters["inductor"]["async_compile_cache_miss"] += 1

@@ -296,15 +304,22 @@ class AsyncCompile:
torch._inductor.codecache.PyCodeCache.load(source_code), kernel_name
)
load_kernel = functools.partial(
_load_triton_kernel_from_source, kernel_name, source_code
)
is_parallel = self.use_process_pool()
set_feature_use("parallel_compile_post_warmup", is_parallel)
compile_id = torch._guards.CompileContext.current_compile_id()
is_backward = getattr(V.graph, "is_backward", False)
if (future := CompiledTritonKernels.get(source_code)) is not None:
counters["inductor"]["async_compile_cache_hit"] += 1
# Set reload_kernel_from_src properly based on source_code
if isinstance(future, StaticAutotunerFuture):
future.reload_kernel_from_src = reload_kernel_in_parent
if is_parallel:
return future
else:
return future.result()
if is_parallel:
# We want to support changing these env vars after (and while) the
# process pool is running, so pass them to the subprocess to reset.

@@ -317,19 +332,16 @@ class AsyncCompile:
extra_env,
)
def reload_kernel_in_parent():
# Benchmark how often this happens
with dynamo_timed("reload_kernel_in_parent"):
return load_kernel()
def get_result() -> tuple[CachingAutotuner, int]:
def get_result() -> CachingAutotuner:
kernel, elapsed_us = task.result()
# Now that we've compiled, we should clear the future
# so it can't be used again
CompiledTritonKernels.remove_future(source_code)
kernel.set_compile_info(compile_id, is_backward)
CompiledTritonKernels.remove_future(source_code)
kernel.precompile(
warm_cache_only=False, reload_kernel=reload_kernel_in_parent
warm_cache_only=False,
reload_kernel=reload_kernel_in_parent,
static_triton_bundle_key=CompiledTritonKernels.key(source_code),
)
get_metrics_context().add_top_n(
"triton_kernel_compile_times_us", kernel_name, elapsed_us

@@ -350,7 +362,10 @@ class AsyncCompile:
_set_triton_ptxas_path()
kernel = load_kernel()
kernel.set_compile_info(compile_id, is_backward)
kernel.precompile(warm_cache_only=False)
kernel.precompile(
warm_cache_only=False,
static_triton_bundle_key=CompiledTritonKernels.key(source_code),
)
elapsed_us = (time_ns() - start_ns) // 1000
get_metrics_context().add_top_n(
"triton_kernel_compile_times_us", kernel_name, elapsed_us

@@ -444,7 +459,6 @@ class AsyncCompile:
disable=config.disable_progress,
delay=0,
)
for key, result in kernels.items():
if config.verbose_progress and not isinstance(pbar, _Faketqdm):
pbar.set_postfix_str(key)
@@ -3351,3 +3351,28 @@ class LambdaFuture(CodeCacheFuture):
def result(self) -> Callable[..., Any]: # type: ignore[override]
return self.result_fn()
class StaticAutotunerFuture(CodeCacheFuture):
"""
A statically launchable CachingAutotuner, loaded from TritonBundler
"""
def __init__(self, static_autotuner: CachingAutotuner) -> None:
# Pickled version of CachingAutotuner
self.static_autotuner = static_autotuner
# This needs to be set in AsyncCompile.triton, in case
# we need to reload the CachingAutotuner from its source code
# We don't store the source code on the CachingAutotuner itself
# since it can be very large.
self.reload_kernel_from_src: Optional[Callable[[], Any]] = None
def result(self) -> CachingAutotuner:
assert self.reload_kernel_from_src is not None
with dynamo_timed("StaticAutotunerFuture.warm_precompile"):
self.static_autotuner.precompile( # type: ignore[union-attr]
warm_cache_only=False,
reload_kernel=self.reload_kernel_from_src,
static_triton_bundle_key=None, # no need to save again
)
return self.static_autotuner
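StaticAutotunerFuture is deliberately cheap to construct: the blocking precompile work is deferred until result() is called, and the reload callback is attached later by the caller (AsyncCompile.triton on a cache hit). The following stand-alone sketch shows the same two-phase pattern; ToyKernel and ToyStaticFuture are hypothetical classes, not the real CachingAutotuner API:

from typing import Callable, Optional


class ToyKernel:
    """Stand-in for a deserialized autotuner that still needs a precompile pass."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.ready = False

    def precompile(self, reload_kernel: Callable[[], "ToyKernel"]) -> None:
        # The real code may call reload_kernel() when it needs the original
        # source again; here we just mark readiness.
        self.ready = True


class ToyStaticFuture:
    """Two-phase future: construct cheaply now, finish the work in result()."""

    def __init__(self, kernel: ToyKernel) -> None:
        self.kernel = kernel
        # Attached later by the caller, mirroring how the cache-hit path sets
        # reload_kernel_from_src before handing the future back.
        self.reload_kernel_from_src: Optional[Callable[[], ToyKernel]] = None

    def result(self) -> ToyKernel:
        assert self.reload_kernel_from_src is not None
        self.kernel.precompile(reload_kernel=self.reload_kernel_from_src)
        return self.kernel


if __name__ == "__main__":
    fut = ToyStaticFuture(ToyKernel("triton_poi_fused_mul_0"))
    fut.reload_kernel_from_src = lambda: ToyKernel("triton_poi_fused_mul_0")
    print(fut.result().ready)  # True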
@@ -65,7 +65,7 @@ if TYPE_CHECKING:
from torch._library.fake_class_registry import FakeScriptObject
from .compile_fx import _CompileFxKwargs
from .triton_bundler import TritonKernelArtifacts
from .triton_bundler import TritonBundle
log = logging.getLogger(__name__)

@@ -420,7 +420,7 @@ class CompiledFxGraph(OutputCode):
inputs_to_check: Sequence[int]
_boxed_call: Optional[bool] = None
_triton_bundle: Optional[list[TritonKernelArtifacts]] = None
_triton_bundle: Optional[TritonBundle] = None
def __init__(
self,
@@ -85,12 +85,17 @@ class StaticallyLaunchedCudaKernel:
def load_kernel(self, device: int) -> None:
from torch._C import _StaticCudaLauncher
assert hasattr(self, "cubin_path")
if self.function is not None:
return
assert hasattr(self, "cubin_path")
assert self.cubin_path is not None
(self.function, self.n_regs, self.n_spills) = _StaticCudaLauncher._load_kernel(
self.cubin_path, self.name, self.shared, device
)
# Don't need the cubin path anymore now that we've loaded
self.cubin_path = None
@staticmethod
@functools.lru_cache

@@ -161,6 +166,15 @@ class StaticallyLaunchedCudaKernel:
params.append(self.extract_type(ty))
return "".join(params)
def __getstate__(self) -> dict[str, Any]:
# Remove objects that are no longer valid for pickling
state = self.__dict__.copy()
state["function"] = None
# Cubin paths aren't consistent across processes, so we clear
# and reload them.
state["cubin_path"] = None
return state
def run(
self,
grid_x: int,

@@ -190,6 +204,7 @@ class StaticallyLaunchedCudaKernel:
# TODO: can handle grid functions here or in C++, so
# that we don't need the grid handler above.
_StaticCudaLauncher._launch_kernel(
self.function,
grid_x,
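The __getstate__ override above drops the loaded kernel function and the cubin path before pickling, since neither is meaningful in another process; load_kernel and reload_cubin_path recover them lazily on the other side. Here is the same pattern in a self-contained example, with a made-up LazyHandle class standing in for the CUDA-specific state:

import pickle
from typing import Any, Optional


class LazyHandle:
    """Pickles its configuration, never the live handle or a process-local path."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.handle: Optional[object] = None   # e.g. a loaded device function
        self.path: Optional[str] = None        # e.g. a per-process binary path

    def load(self, path: str) -> None:
        self.path = path
        self.handle = object()  # pretend we loaded something expensive

    def __getstate__(self) -> dict[str, Any]:
        state = self.__dict__.copy()
        # Neither the live handle nor the path is valid in another process,
        # so clear them and let the unpickling side reload on demand.
        state["handle"] = None
        state["path"] = None
        return state


if __name__ == "__main__":
    h = LazyHandle("kernel0")
    h.load("/tmp/this/process/only.cubin")
    clone = pickle.loads(pickle.dumps(h))
    print(clone.handle, clone.path)  # None None -> reload lazily where it lands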
@@ -275,6 +275,17 @@ class CachingAutotuner(KernelInterface):
self.compile_id: Optional[CompileId] = None
self.is_backward = False
def is_statically_launchable(self):
"""
Checks if every compiled kernel is statically launchable, which
allows us to efficiently cache it in FXGraphCache
"""
if not self.compile_results:
return False
return all(
isinstance(x, StaticTritonCompileResult) for x in self.compile_results
)
def set_compile_info(
self, compile_id: Optional[CompileId], is_backward: bool
) -> None:

@@ -285,6 +296,7 @@ class CachingAutotuner(KernelInterface):
self,
warm_cache_only=False,
reload_kernel: Optional[Callable[[], CachingAutotuner]] = None,
static_triton_bundle_key: Optional[str] = None,
):
if warm_cache_only:
self._precompile_worker()

@@ -297,6 +309,8 @@ class CachingAutotuner(KernelInterface):
if reload_kernel is not None:
self._reload_kernel = reload_kernel
self._precompile_worker()
if static_triton_bundle_key is not None and self.is_statically_launchable():
TritonBundler.put_static_autotuner(static_triton_bundle_key, self)
self._make_launchers()
self._dynamic_scale_rblock()

@@ -462,15 +476,24 @@ class CachingAutotuner(KernelInterface):
raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}")
self.launchers = launchers
def prepare_for_pickle(self):
def prepare_for_pickle(self) -> tuple[Any, Any, Any, Any, Any]:
"""Drop stuff from triton.JITFunction that does not pickle.
This must be called after precompile so that these things are no longer needed.
Returns a tuple of old values
"""
old_values = (
self.fn.fn,
self.fn.__globals__,
self.fn.used_global_vals,
self.fn.repr,
self.launchers,
)
self.fn.fn = None
self.fn.__globals__ = None
self.fn.used_global_vals = None
self.fn.repr = _ConstRepr(self.fn.repr(self.fn))
self.launchers = []
return old_values
def __getstate__(self) -> dict[str, Any]:
assert not self.launchers, (

@@ -1056,7 +1079,8 @@ class CompileResult(Generic[_T]):
f" grid_2 = {grid.z_grid}",
f" runner({', '.join(runner_args)})",
]
exec("\n".join(lines), scope)
launcher_code = "\n".join(lines)
exec(launcher_code, scope)
return scope["launcher"]
def _get_arg_lists(

@@ -1198,8 +1222,26 @@ class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
raise e
return None
def reload_cubin_path(self):
"""
When loading from cache on disk, we want to reload cubin
files from their appropriate location on disc.
"""
cubin_location = os.path.join(
triton_cache_dir(self.compile_meta.get("device", 0)),
triton_hash_to_path_key(self.kernel.hash),
f"{self.kernel.name}.cubin",
)
if not os.path.exists(cubin_location):
raise RuntimeError(
"Cubin file saved by TritonBundler not found at %s", cubin_location
)
self.kernel.cubin_path = cubin_location
def make_launcher(self) -> LauncherType:
# Load the binary on the parent
if not self.kernel.cubin_path:
self.reload_cubin_path()
self.kernel.load_kernel(self.compile_meta.get("device", 0))
scope = {
"runner": self.kernel.run,
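prepare_for_pickle is what makes the bundling in put_static_autotuner (shown in the triton_bundler.py hunks further below) possible: strip the unpicklable JITFunction state, deep-copy the now-picklable kernel for the cache, then restore the original values so the live kernel keeps working. A generic, self-contained version of that strip/copy/restore dance follows; the names are illustrative, and the try/finally is an extra safeguard rather than something the diff shows:

import copy
from typing import Any


class LiveKernel:
    def __init__(self) -> None:
        self.fn = lambda x: x * 2          # unpicklable (a local lambda)
        self.launchers = ["<live launcher>"]
        self.meta = {"kernel_name": "toy_kernel"}

    def prepare_for_pickle(self) -> tuple[Any, Any]:
        # Drop fields that don't pickle, returning the old values for restoration.
        old = (self.fn, self.launchers)
        self.fn = None
        self.launchers = []
        return old


def snapshot_for_cache(kernel: LiveKernel) -> LiveKernel:
    old_values = kernel.prepare_for_pickle()
    try:
        # Deep-copy the stripped kernel; this copy is what a cache would pickle.
        frozen = copy.deepcopy(kernel)
    finally:
        # Put the original values back so the in-process kernel keeps working.
        (kernel.fn, kernel.launchers) = old_values
    return frozen


if __name__ == "__main__":
    k = LiveKernel()
    snap = snapshot_for_cache(k)
    print(snap.fn, k.fn(21))  # None 42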
@@ -1,3 +1,4 @@
import copy
import dataclasses
import logging
import os

@@ -42,6 +43,21 @@ class TritonKernelArtifact:
payload: bytes = dataclasses.field(repr=False) # Do not display binary
@dataclasses.dataclass(frozen=True)
class StaticallyLaunchedAutotuner:
"""
Represents a statically compiled CachingAutotuner object that we can
save directly in the cache. A CachingAutotuner is made up of a list of
StaticTritonCompileResults, each of which uses the cubin from a TritonKernelArtifact.
Statically saved here have their cubin files saved by a corresponding TritonBundleEntry.
"""
cache_key: str
kernel_name: str
kernel: "CachingAutotuner" # type: ignore[name-defined] # noqa: F821
@dataclasses.dataclass(frozen=True)
class TritonKernelArtifacts:
"""

@@ -60,6 +76,17 @@ class TritonBundlerMetadata:
"""
cached_kernel_names: list[str]
statically_launched_kernel_names: list[str]
@dataclasses.dataclass(frozen=True)
class TritonBundle:
"""
Serializable bundle to save into FXGraphCache
"""
kernel_artifacts: list[TritonKernelArtifacts]
static_autotuners: list[StaticallyLaunchedAutotuner]
class TritonBundler:

@@ -79,6 +106,7 @@ class TritonBundler:
"""
_entries: Optional[list[TritonBundleEntry]] = None
_static_autotuners: Optional[list[StaticallyLaunchedAutotuner]] = None
# __grp__kernel_name.json contains metadata with source code paths
# we use this as sentinal value for search and replace

@@ -112,6 +140,7 @@ class TritonBundler:
log.debug("TritonBundler.begin_compile is called")
assert cls._entries is None
cls._entries = []
cls._static_autotuners = []
@classmethod
def end_compile(cls) -> None:

@@ -121,6 +150,7 @@ class TritonBundler:
"""
log.debug("TritonBundler.end_compile is called")
cls._entries = None
cls._static_autotuners = None
@classmethod
def put(cls, kernel_hash: str, device: int) -> None:
@@ -133,20 +163,93 @@ class TritonBundler:
TritonBundleEntry(kernel_hash, device, triton_cache_dir(device))
)
@classmethod
def put_static_autotuner(cls, key: str, kernel: "CachingAutotuner") -> None: # type: ignore[name-defined] # noqa: F821
from torch._inductor import config
assert config.use_static_cuda_launcher
if (entries := cls._static_autotuners) is not None:
# Clear a bunch of unpicklable values and make a copy to save
# for FXGraphCache
old_values = kernel.prepare_for_pickle()
new_kernel = copy.deepcopy(kernel)
new_kernel._reload_kernel = None
entries.append(
StaticallyLaunchedAutotuner(
key,
new_kernel.inductor_meta.get("kernel_name", "unknown_kernel"),
new_kernel,
)
)
# Put the values back since we need it to use now
(
kernel.fn.fn,
kernel.fn.__globals__,
kernel.fn.used_global_vals,
kernel.fn.repr,
kernel.launchers,
) = old_values
@classmethod
def collect_static_autotuners(
cls,
) -> tuple[list[StaticallyLaunchedAutotuner], list[str]]:
if not cls._static_autotuners:
return [], []
else:
log.info(
"Saving %d statically launchable CachingAutotuners",
len(cls._static_autotuners),
)
static_autotuner_names = [i.kernel_name for i in cls._static_autotuners]
counters["inductor"]["triton_bundler_save_static_autotuner"] += 1
return cls._static_autotuners, static_autotuner_names
@classmethod
def load_autotuners(
cls, static_autotuners: Optional[list[StaticallyLaunchedAutotuner]]
) -> list[str]:
"""
Load statically launchable CachingAutotuners into async_compile.CompiledTritonKernels
cache.
"""
if not static_autotuners:
return []
from torch._inductor.async_compile import CompiledTritonKernels
from torch._inductor.codecache import StaticAutotunerFuture
log.info("Loading %d statically launchable autotuners", len(static_autotuners))
kernel_names = []
with dynamo_timed("TritonBundler.load_cached_static_autotuners"):
for result in static_autotuners:
# We make a future instead of returning the kernel here so that
# kernels that are not statically launchable (i.e. cache miss)
# can launch a worker without waiting on the blocking step of
# StaticAutotunerFuture.result().
CompiledTritonKernels._cache[result.cache_key] = StaticAutotunerFuture(
result.kernel
)
counters["inductor"]["triton_bundler_load_static_autotuner"] += 1
kernel_names.append(result.kernel_name)
return kernel_names
@classmethod
def collect(
cls,
) -> tuple[list[TritonKernelArtifacts], Optional[TritonBundlerMetadata]]:
) -> tuple[TritonBundle, Optional[TritonBundlerMetadata]]:
"""
This is the main function called when a cache write happens. This function
converts all the previously remembered kernels into bundled format so that
it can be written into a cache entry.
This function also finalizes the current bundle.
"""
from torch._inductor import config
if not TritonBundler.is_enabled():
cls.end_compile()
set_feature_use("triton_bundling", False)
return [], None
return TritonBundle([], []), None
set_feature_use("triton_bundling", True)
with dynamo_timed(key="TritonBundler.collect", log_pt2_compile_event=True):

@@ -199,14 +302,21 @@ class TritonBundler:
artifacts,
)
)
if config.use_static_cuda_launcher:
static_autotuners, static_kernel_names = (
cls.collect_static_autotuners()
)
else:
static_autotuners = []
static_kernel_names = []
cls.end_compile()
return result, TritonBundlerMetadata(kernel_names)
return [], None
return TritonBundle(result, static_autotuners), TritonBundlerMetadata(
kernel_names, static_kernel_names
)
return TritonBundle([], []), None
@staticmethod
def read_and_emit(
bundle: list[TritonKernelArtifacts],
) -> Optional[TritonBundlerMetadata]:
def read_and_emit(bundle: TritonBundle) -> Optional[TritonBundlerMetadata]:
"""
This is the main function called when a cache read happens. This function
converts the bundled format back into individual files and writes them

@@ -219,6 +329,8 @@ class TritonBundler:
Exclusive access means that no other process should be writing to
or reading from the target directory.
"""
from torch._inductor import config
if not TritonBundler.is_enabled():
return None

@@ -227,7 +339,7 @@ class TritonBundler:
):
kernel_names: list[str] = []
for artifacts in bundle:
for artifacts in bundle.kernel_artifacts:
basedir = triton_cache_dir(artifacts.device)
directory = os.path.join(basedir, artifacts.kernel_hash)

@@ -272,4 +384,10 @@ class TritonBundler:
# Atomic on POSIX systems
os.replace(tmp_dir, directory)
return TritonBundlerMetadata(kernel_names)
if config.use_static_cuda_launcher:
static_kernel_names = TritonBundler.load_autotuners(
bundle.static_autotuners
)
else:
static_kernel_names = []
return TritonBundlerMetadata(kernel_names, static_kernel_names)