When we attempt prologue or epilogue fusion with a TritonTemplate, we benchmark it at compile time in order to determine profitability. This avoids slowdowns and register spilling, and lets us pick a fusion when the base triton template is slower than cublas but faster once the epilogue is taken into account. However, that fused benchmarking does not use the same async compilation path as the base TritonTemplate: the base TritonTemplate is async-compiled during lowering, then later waited on and benchmarked.

This PR extends a similar process to benchmarking fused TritonTemplates in the scheduler. We keep a list of pending fusions that have outstanding async compilations, and we resolve any pending fusions a node participates in before attempting to fuse it with any other node.

Initially I saw some slowdowns with this, because we kick off async compilations of identical fusions in parallel. To address this, I added source-code caching at the `async_compile` level (we also already cache benchmark runs, but that caching would not happen in parallel).

Compilation speedups: see https://github.com/user-attachments/assets/8e8f7d6c-7824-4210-83f9-a2a0f6db5ac9

This should also let us be more aggressive with either configs or with benchmarking other fusions whose profitability is hard to determine.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143408
Approved by: https://github.com/jansel, https://github.com/shunting314
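As a minimal sketch (illustrative only, not the actual PyTorch implementation), the source-code-keyed caching described above amounts to remembering one future per kernel source string, so identical fusions submitted in parallel share a single in-flight compilation. The pool, `_future_cache`, `_compile_kernel`, and `async_compile_cached` names below are hypothetical stand-ins; the real logic lives in `get_future_cache()` and `AsyncCompile.triton()` in the module that follows.

from concurrent.futures import Future, ThreadPoolExecutor

_pool = ThreadPoolExecutor(max_workers=4)  # stand-in for the compile worker pool
_future_cache: dict[str, Future] = {}  # keyed by kernel source code


def _compile_kernel(source_code: str):
    # Stand-in for the expensive Triton compilation step.
    return compile(source_code, "<generated kernel>", "exec")


def async_compile_cached(source_code: str) -> Future:
    # Cache hit: reuse the in-flight (or already finished) compilation.
    if (future := _future_cache.get(source_code)) is not None:
        return future
    # Cache miss: submit the compile and remember the future for later callers.
    future = _pool.submit(_compile_kernel, source_code)
    _future_cache[source_code] = future
    return future

The full module source follows.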
# mypy: allow-untyped-defs
from __future__ import annotations

import atexit
import functools
import logging
import multiprocessing
import os
import sys
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from functools import partial
from time import time
from typing import Any, Callable, Optional, TYPE_CHECKING

import torch
from torch._dynamo.device_interface import get_registered_device_interfaces
from torch._dynamo.utils import counters, dynamo_timed, set_feature_use
from torch._inductor import config
from torch._inductor.codecache import (
    CodeCacheFuture,
    CppCodeCache,
    CppPythonBindingsCodeCache,
    CUDACodeCache,
    HalideCodeCache,
    LambdaFuture,
    ROCmCodeCache,
    TritonCodeCache,
    TritonFuture,
)
from torch._inductor.compile_worker.subproc_pool import AnyPool, SubprocPool
from torch._inductor.compile_worker.watchdog import _async_compile_initializer
from torch._inductor.runtime.compile_tasks import (
    _set_triton_ptxas_path,
    _worker_compile_triton,
)
from torch._inductor.utils import clear_on_fresh_inductor_cache
from torch.hub import _Faketqdm, tqdm
from torch.utils._ordered_set import OrderedSet
from torch.utils._triton import has_triton_package


if TYPE_CHECKING:
    from torch._inductor.runtime.hints import HalideMeta


# timing metrics for time spent in the compilation
_cumulative_compile_time = 0.0
_t0: Optional[float] = None

kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code")

log = logging.getLogger(__name__)


def pre_fork_setup():
    """
    Setup that must be done prior to forking with a process pool.
    """
    # ensure properties have been calculated before processes
    # are forked
    caching_device_properties()

    # Computing the triton key can be slow. If we call it before fork,
    # it will be cached for the forked subprocesses.
    try:
        from triton.compiler.compiler import triton_key

        triton_key()
    except ImportError:
        # Triton might not be installed or might be an old version.
        pass


def caching_device_properties():
    for _, device_interface in get_registered_device_interfaces():
        if device_interface.is_available():
            device_interface.Worker.get_device_properties()


def _compile_start() -> None:
    global _t0
    if _t0 is None:
        _t0 = time()


def _compile_end() -> None:
    global _cumulative_compile_time, _t0
    if _t0 is not None:
        t1 = time()
        _cumulative_compile_time += t1 - _t0
        _t0 = None
        # print("CUMULATIVE COMPILE TIME", _cumulative_compile_time)


_IS_WINDOWS = sys.platform == "win32"

log = logging.getLogger(__name__)

# Used to keep track of all process pools invoked so far.
_pool_set = OrderedSet[AnyPool]()


def shutdown_compile_workers() -> None:
    """Shut down all outstanding compile-worker pools."""
    for pool in _pool_set:
        pool.shutdown()
    after_fork()


def after_fork():
    """Reset pools to initial state without shutting them down"""
    _pool_set.clear()
    AsyncCompile.process_pool.cache_clear()


try:
    os.register_at_fork(after_in_child=after_fork)
except AttributeError:
    pass  # register_at_fork does not exist on Windows


def get_compile_threads() -> int:
    """
    Temporary for internal rollout. Assign config.compile_threads lazily and return it.
    TODO: remove after rollout.
    """
    if config.compile_threads is None:
        config.compile_threads = config.decide_compile_threads()
    return config.compile_threads


@clear_on_fresh_inductor_cache
@functools.lru_cache(None)
def get_future_cache():
    return {}


class AsyncCompile:
    def __init__(self) -> None:
        pass

    @staticmethod
    @functools.lru_cache(1)
    def pool() -> ThreadPoolExecutor:
        assert get_compile_threads() > 1
        return ThreadPoolExecutor(get_compile_threads())

    @staticmethod
    def _get_ready():
        """No-op function to help mark when the subprocess pool is ready."""
        return "ready"

    @staticmethod
    @functools.lru_cache(1)
    def process_pool() -> AnyPool:
        assert get_compile_threads() > 1
        log.info(
            "Creating '%s' pool with %d workers",
            config.worker_start_method,
            get_compile_threads(),
        )

        pool: AnyPool
        if config.worker_start_method == "subprocess":
            # Wrapper around ProcessPoolExecutor forks in a new process we control
            pool = SubprocPool(get_compile_threads())
        else:
            if config.worker_start_method == "spawn":
                # Avoid creating pools in the spawned subprocs themselves:
                os.environ["TORCH_WARM_POOL"] = "0"
            pre_fork_setup()
            ctx = multiprocessing.get_context(config.worker_start_method)
            pool = ProcessPoolExecutor(
                get_compile_threads(),
                mp_context=ctx,
                initializer=partial(_async_compile_initializer, os.getpid()),
            )
            # when this pool is created in a subprocess object, the normal exit handler
            # doesn't run, and we need to register our own handler.
            # exitpriority has to be high, because another one of the finalizers will
            # kill the worker thread that sends the shutdown message to the workers...
            multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize)

        # Set an attribute we can check to see if the pool is ready.
        pool.ready_future = pool.submit(AsyncCompile._get_ready)  # type: ignore[union-attr]
        _pool_set.add(pool)
        return pool

    @classmethod
    def warm_pool(cls) -> None:
        if get_compile_threads() <= 1:
            return
        _compile_start()
        # Pool is initialized on first access
        cls.process_pool()
        _compile_end()

    @classmethod
    def submit(cls, task: Callable[..., Any]) -> Any:
        if get_compile_threads() <= 1:
            return task()
        return cls.pool().submit(task)

    def use_process_pool(self):
        return (
            get_compile_threads() > 1
            and self.process_pool().ready_future.done()  # type: ignore[union-attr]
        )

    def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"):
        kernel_code_log.info("Triton Kernel:\n%s", source_code)
        _compile_start()
        _set_triton_ptxas_path()

        if os.environ.get("TRITON_INTERPRET", "0") == "1":
            return getattr(
                torch._inductor.codecache.PyCodeCache.load(source_code), kernel_name
            )

        kernel = TritonCodeCache.load(kernel_name, source_code)
        if self.use_process_pool():
            set_feature_use("parallel_compile_post_warmup", True)
            # We want to support changing these env vars after (and while) the
            # process pool is running, so pass them to the subprocess to reset.
            env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"]
            extra_env = {v: os.environ[v] for v in env_vars if v in os.environ}

            future_cache = get_future_cache()

            if future := future_cache.get(source_code, None):
                counters["inductor"]["async_compile_cache_hit"] += 1
                return future

            counters["inductor"]["async_compile_cache_miss"] += 1
            future = TritonFuture(
                kernel,
                self.process_pool().submit(
                    _worker_compile_triton,
                    kernel._reload_in_subproc,
                    extra_env,
                ),
            )
            future_cache[source_code] = future
            return future

        else:
            set_feature_use("parallel_compile_post_warmup", False)
            with dynamo_timed(
                "async_compile.precompile",
                log_pt2_compile_event=True,
                dynamo_compile_column_us="triton_compile_time_us",
                log_waitcounter=True,
            ):
                kernel.precompile()
            return kernel

    def multi_kernel(self, *args, **kwargs) -> Any:
        from torch._inductor.codegen.multi_kernel import MultiKernelCall

        # no need to call this in parallel since the sub-kernels are already parallel tasks
        return MultiKernelCall(*args, **kwargs)

    def cpp(self, source_code: str):
        kernel_code_log.info("CPP Kernel:\n%s", source_code)
        if get_compile_threads() <= 1:
            return CppCodeCache.load(source_code).kernel
        else:
            get_result = CppCodeCache.load_async(source_code, submit_fn=self.submit)
            return LambdaFuture(lambda: get_result().kernel)

    def cpp_pybinding(self, argtypes: list[str], source_code: str):
        kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code)
        if get_compile_threads() <= 1:
            return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code)
        else:
            get_result = CppPythonBindingsCodeCache.load_pybinding_async(
                argtypes, source_code, submit_fn=self.submit
            )
            return LambdaFuture(get_result)

    def cuda(self, source_code, dst_file_ext, aot_compile=False):
        kernel_code_log.info("CUDA Kernel:\n%s", source_code)

        def task():
            if aot_compile:
                # We rely on JITInductor to compile the CUDA code,
                # so that we can load it into AOTInductor.
                CUDACodeCache.compile(source_code, "o")
            return CUDACodeCache.load(source_code, dst_file_ext)[0]

        return self.submit(task)

    def rocm(
        self,
        source_code,
        dst_file_ext,
        aot_compile=False,
    ):
        kernel_code_log.info("ROCm Kernel:\n%s", source_code)

        def task():
            if aot_compile:
                _ = ROCmCodeCache.compile(source_code, dst_file_ext="o")
            if config.rocm.generate_test_runner:
                _ = ROCmCodeCache.compile(source_code, dst_file_ext="exe")
            return ROCmCodeCache.load(source_code, dst_file_ext)[0]

        return self.submit(task)

    def halide(self, meta: HalideMeta, source_code: str):
        kernel_code_log.info("Halide Kernel:\n%r\n%s", meta, source_code)
        if get_compile_threads() <= 1:
            return HalideCodeCache.generate_halide(meta, source_code)
        else:
            get_result = HalideCodeCache.generate_halide_async(
                meta, source_code, submit_fn=self.submit
            )
            return LambdaFuture(get_result)

    def wait(self, scope: dict[str, Any]) -> None:
        with dynamo_timed(
            "async_compile.wait",
            log_pt2_compile_event=True,
            dynamo_compile_column_us="triton_compile_time_us",
            log_waitcounter=True,
        ):
            num_kernels = len(
                [
                    value
                    for key, value in scope.items()
                    if isinstance(value, (Future, CodeCacheFuture))
                ]
            )
            pbar = tqdm(
                total=num_kernels,
                desc="Inductor Compilation",
                disable=config.disable_progress,
                delay=0,
            )
            if get_compile_threads() > 1:
                for key, result in scope.items():
                    if config.verbose_progress and not isinstance(pbar, _Faketqdm):
                        pbar.set_postfix_str(key)
                    if isinstance(result, (Future, CodeCacheFuture)):
                        try:
                            scope[key] = result.result()
                        except BrokenProcessPool as e:
                            raise RuntimeError(
                                "A compilation subprocess exited unexpectedly. This "
                                "is likely due to a crash. To facilitate debugging, "
                                "you can re-run with TORCHINDUCTOR_COMPILE_THREADS=1 "
                                "to cause compilation to occur in the main process."
                            ) from e
                        pbar.update(1)

            _compile_end()


if (
    os.environ.get("TORCH_TNT_IN_USE", "0") == "1"
    or os.environ.get("TORCH_WARM_POOL", "1") != "1"
    # The subprocess pool is only used for the Triton backend
    or not has_triton_package()
    # Skip for fbcode. We have internal reports of usages inside multiprocessing
    # pools that lead to a multiplicative number of compile subprocesses.
    or config.is_fbcode()
):
    pass
else:
    AsyncCompile.warm_pool()

# On exit give the workers a chance to clean themselves up. Without this the
# resource_tracker can complain about leaked semaphores coming from the
# ProcessPoolExecutor:
# UserWarning: resource_tracker: There appear to be 5 leaked semaphore objects
# to clean up at shutdown
atexit.register(shutdown_compile_workers)
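
# --- Illustrative usage sketch (added for exposition; not part of the module above) ---
# Inductor-generated wrapper code typically drives this module roughly as follows;
# the kernel name and source string here are placeholders, not real generated kernels:
#
#     async_compile = AsyncCompile()
#     triton_kernel_0 = async_compile.triton(
#         "triton_kernel_0", "<generated triton kernel source>"
#     )
#     ...
#     async_compile.wait(globals())  # resolve every pending Future / CodeCacheFuture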