[async-compile] add progressive compile mode (#157305)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157305
Approved by: https://github.com/aorenste
Author: bobrenjc93
Date: 2025-07-03 13:05:29 -07:00
Committed by: PyTorch MergeBot
Parent: 386bc9e2e9
Commit: d58ed04d89
3 changed files with 306 additions and 8 deletions


@@ -9,6 +9,7 @@ import importlib
import os
import sys
import time
import unittest
from unittest.mock import patch
import torch
@@ -16,7 +17,14 @@ import torch.library
from torch._inductor.compile_fx import _InProcessFxCompile, FxCompile, FxCompileMode
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import TEST_WITH_ASAN
from torch.testing._internal.inductor_utils import GPU_TYPE, RUN_CPU, RUN_GPU
from torch.testing._internal.inductor_utils import (
    GPU_TYPE,
    IS_BIG_GPU,
    requires_gpu,
    requires_triton,
    RUN_CPU,
    RUN_GPU,
)
# Make the helper files in test/ importable
@@ -75,6 +83,79 @@ class TestSubprocess(TestCase):
        TestCase.tearDown(self)
        torch._dynamo.reset()

    @requires_gpu()
    @requires_triton()
    @unittest.skipIf(
        not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
    )
    def test_progressive(self):
        from triton.testing import do_bench

        from torch._inductor.compile_fx_async import _ProgressiveFxCompile

        torch._inductor.compile_fx.fx_compile_progressive = True

        x = torch.randn(1152, 1024, device=GPU_TYPE, dtype=torch.bfloat16)
        y = torch.randn(1024, 1024, device=GPU_TYPE, dtype=torch.bfloat16)

        @torch.compile(fullgraph=True, backend="inductor")
        def optimized(x, y):
            return (x @ y).relu()

        _ProgressiveFxCompile._reset_stats()

        with contextlib.ExitStack() as stack:
            # When this bug is fixed, remove the cache disabling below
            assert torch._inductor.compile_fx_async.BUG_CACHES_DONT_WORK_WITH_ASYNC
            stack.enter_context(
                torch._inductor.config.patch(
                    autotune_local_cache=False, fx_graph_cache=False
                )
            )
            stack.enter_context(
                torch._functorch.config.patch(enable_autograd_cache=False)
            )

            # How long to wait (in seconds) before giving up.
            TIMEOUT = 300
            # If non-None then how often (in seconds) to print a TICK message.
            TICK_REPORT = None

            start = time.time()
            last_report = start

            while _ProgressiveFxCompile._stat_optimized_runs < 4:
                time.sleep(0.25)
                optimized(x, y)

                now = time.time()
                if TICK_REPORT is not None and (now - last_report > TICK_REPORT):
                    print(f"*** TICK {int(now - start)}")
                    last_report = now

                if now - start > TIMEOUT:
                    raise RuntimeError(
                        "Test timed out before producing a progressively optimized compiled artifact."
                    )

        self.assertEqual(_ProgressiveFxCompile._stat_optimized_runs, 4)
        self.assertGreater(_ProgressiveFxCompile._stat_fast_runs, 0)
        self.assertGreaterEqual(_ProgressiveFxCompile._stat_bg_started, 1)
        self.assertGreaterEqual(_ProgressiveFxCompile._stat_bg_finished, 1)

        torch._inductor.compile_fx.fx_compile_progressive = False

        @torch.compile(fullgraph=True, backend="inductor")
        def baseline(x, y):
            return (x @ y).relu()

        # Warmup
        baseline(x, y)

        self.assertGreater(
            do_bench(lambda: baseline(x, y)), do_bench(lambda: optimized(x, y))
        )

    @patch("torch._inductor.compile_fx.fx_compile_async", True)
    def test_async(self):
        # Test that async+subprocess works.
@@ -90,7 +171,7 @@ class TestSubprocess(TestCase):
        _AsyncFxCompile._reset_stats()

        with contextlib.ExitStack() as stack:
            # TODO: Turn off local caches - they don't play nice w/ async currently.
            assert torch._inductor.compile_fx_async.BUG_CACHES_DONT_WORK_WITH_ASYNC
            stack.enter_context(
                torch._inductor.config.patch(
                    autotune_local_cache=False, fx_graph_cache=False

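A condensed sketch of driving the same path outside the test harness. This is not part of the patch: the model, shapes, and the "cuda" device are illustrative, the environment value follows the parsing added in compile_fx.py below, and the cache disabling mirrors the workaround the test uses while BUG_CACHES_DONT_WORK_WITH_ASYNC is set.

import os

# Must be set before torch._inductor.compile_fx is first imported, because the
# value is read once at module import time.
os.environ["TORCHINDUCTOR_FX_COMPILE_MODE"] = "progressive+subprocess"

import torch
import torch._functorch.config as functorch_config
import torch._inductor.config as inductor_config

# Same workaround as the test: the async/progressive paths do not yet work
# with the caches, so turn them off.
inductor_config.fx_graph_cache = False
inductor_config.autotune_local_cache = False
functorch_config.enable_autograd_cache = False


@torch.compile(fullgraph=True, backend="inductor")
def f(x, y):
    return (x @ y).relu()


x = torch.randn(1152, 1024, device="cuda", dtype=torch.bfloat16)
y = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)

# Early calls are served by the fast in-process artifact; once the background
# max-autotune compile finishes in the subprocess, later calls hot-swap onto it.
for _ in range(200):
    f(x, y)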

@@ -14,6 +14,7 @@ import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
from contextlib import AbstractContextManager
from dataclasses import dataclass
from inspect import currentframe
from itertools import count
from operator import attrgetter
@@ -164,21 +165,32 @@ class FxCompileMode(enum.Enum):
    SUBPROCESS = 2


# Return compile mode and use_async flag
def _fx_compile_mode_default() -> tuple[FxCompileMode, bool]:
@dataclass
class FxCompileConfig:
    mode: FxCompileMode
    use_async: bool
    use_progressive: bool


def _fx_compile_mode_default() -> FxCompileConfig:
    name = "TORCHINDUCTOR_FX_COMPILE_MODE"
    value = os.environ.get(name)
    if value is None:
        return FxCompileMode.NORMAL, False
        return FxCompileConfig(FxCompileMode.NORMAL, False, False)

    use_async = False
    use_progressive = False

    if value.lower().startswith("progressive+"):
        use_progressive = True
        value = value[12:]

    if value.lower().startswith("async+"):
        use_async = True
        value = value[6:]

    try:
        value = value.upper()
        return FxCompileMode[value], use_async
        return FxCompileConfig(FxCompileMode[value], use_async, use_progressive)
    except KeyError:
        import logging

@@ -191,10 +203,20 @@ def _fx_compile_mode_default() -> tuple[FxCompileMode, bool]:
        )
        # Remove from the environment so subprocesses don't ALSO complain.
        os.environ.pop(name)
        return FxCompileMode.NORMAL, False
        return FxCompileConfig(FxCompileMode.NORMAL, False, False)


fx_compile_mode, fx_compile_async = _fx_compile_mode_default()
def _get_progression_configs() -> list[dict[str, Any]]:
    # TODO make this configurable
    return [
        {"max_autotune": True},
    ]


_fx_compile_config = _fx_compile_mode_default()
fx_compile_mode = _fx_compile_config.mode
fx_compile_async = _fx_compile_config.use_async
fx_compile_progressive = _fx_compile_config.use_progressive

log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
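To make the prefix handling concrete: value[12:] and value[6:] strip the literal prefixes "progressive+" (12 characters) and "async+" (6 characters). A small illustrative trace, assuming this patch is applied (the chosen values are examples only):

import os

import torch._inductor.compile_fx as compile_fx

os.environ["TORCHINDUCTOR_FX_COMPILE_MODE"] = "async+subprocess"
cfg = compile_fx._fx_compile_mode_default()
# cfg == FxCompileConfig(mode=FxCompileMode.SUBPROCESS, use_async=True, use_progressive=False)

os.environ["TORCHINDUCTOR_FX_COMPILE_MODE"] = "progressive+subprocess"
cfg = compile_fx._fx_compile_mode_default()
# cfg == FxCompileConfig(mode=FxCompileMode.SUBPROCESS, use_async=False, use_progressive=True)

del os.environ["TORCHINDUCTOR_FX_COMPILE_MODE"]
cfg = compile_fx._fx_compile_mode_default()
# cfg == FxCompileConfig(mode=FxCompileMode.NORMAL, use_async=False, use_progressive=False)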
@@ -1576,6 +1598,21 @@ def fx_codegen_and_compile(
        )
        scheme = _AsyncFxCompile(scheme)

    if fx_compile_progressive:
        from .compile_fx_async import _ProgressiveFxCompile
        from .compile_fx_ext import _OutOfProcessFxCompile

        assert isinstance(scheme, _OutOfProcessFxCompile), (
            "progressive is only valid with an out-of-process compile mode"
        )

        progression_configs = _get_progression_configs()

        # Use in-process compile for the fast version
        fast_scheme = _InProcessFxCompile()

        scheme = _ProgressiveFxCompile(fast_scheme, scheme, progression_configs)

    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
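The progression used here is whatever _get_progression_configs() returns, currently the single {"max_autotune": True} stage; the TODO in the previous hunk notes it is not yet configurable. Purely as a hypothetical experiment (not an interface this patch provides), a longer progression could be tried by overriding that private helper; max_autotune and coordinate_descent_tuning are existing inductor config options:

import torch._inductor.compile_fx as compile_fx


def _experimental_progression() -> list[dict[str, bool]]:
    # One inductor config patch per background compile stage, ordered from
    # cheapest to most aggressive. Overriding the private helper like this is
    # illustration only, not a supported interface.
    return [
        {"max_autotune": True},
        {"max_autotune": True, "coordinate_descent_tuning": True},
    ]


compile_fx._get_progression_configs = _experimental_progression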


@@ -11,6 +11,10 @@ from .compile_fx import _CompileFxKwargs, _InProcessFxCompile, FxCompile
from .output_code import complex_memory_overlap as complex_memory_overlap  # noqa: F401


# When async compile works with cache, remove the disabling below
BUG_CACHES_DONT_WORK_WITH_ASYNC = True

if TYPE_CHECKING:
    from collections.abc import Sequence
    from concurrent.futures import Future
@@ -179,3 +183,179 @@ class _AsyncFxCompile(FxCompile):
            return output.graph

        return _AsyncOutputCode(eager_output_code, f, callback)
# _ProgressiveOutputCode handles running a fast compile first, then hot-swapping
# to a more optimized version when the expensive compile finishes.
@final
class _ProgressiveOutputCode(OutputCode):
    _fast_output_code: Optional[OutputCode]
    _optimized_output_code: Optional[OutputCode]
    _progression_futures: list[Optional[Future[_WireProtocolPickledOutput]]]
    _callback: Callable[[_WireProtocolPickledOutput], OutputCode]
    _post_compile_data: Optional[_PostCompileData] = None
    _current_progression_index: int

    # _boxed_call state is effectively cached (we sometimes wrap unboxed w/
    # lambdas to box them) so we can't change it mid-way. Since _boxed_call=True
    # is more common let's default to that and we'll convert if necessary.
    _boxed_call: bool = True

    def __init__(
        self,
        # Fast compile that runs faster than the progressive compiles
        fast_output_code: OutputCode,
        # Futures for the progressive optimized compiles
        progression_futures: list[Future[_WireProtocolPickledOutput]],
        # Callback to convert the optimized result to OutputCode
        callback: Callable[[_WireProtocolPickledOutput], OutputCode],
    ) -> None:
        self._fast_output_code = fast_output_code
        self._optimized_output_code = None
        self._progression_futures = list(progression_futures)
        self._callback = callback
        self._current_progression_index = -1

    @override
    def __call__(self, args: Sequence[Any]) -> Any:
        # Check if any newer progression stage is ready and switch to it
        self._check_and_switch_progression()

        if self._optimized_output_code is not None:
            _ProgressiveFxCompile._stat_optimized_runs += 1
            output_code = self._optimized_output_code
        else:
            _ProgressiveFxCompile._stat_fast_runs += 1
            assert self._fast_output_code is not None
            output_code = self._fast_output_code

        boxed_call = getattr(output_code, "_boxed_call", False)
        if boxed_call:
            res = output_code.__call__(args)
        else:
            res = output_code.__call__(*args)
        return res

    def _check_and_switch_progression(self) -> None:
        # Check if any newer progression stage is ready (in order from latest to earliest)
        for i in range(
            len(self._progression_futures) - 1, self._current_progression_index, -1
        ):
            future = self._progression_futures[i]
            if self._post_compile_data and future and future.done():
                self._switch_to_progression_stage(i)
                break

    def _switch_to_progression_stage(self, stage_index: int) -> None:
        future = self._progression_futures[stage_index]
        assert future is not None
        optimized_output_code = self._callback(future.result())

        if pcd := self._post_compile_data:
            # Only clear post_compile_data if this is the final progression stage
            if stage_index == len(self._progression_futures) - 1:
                self._post_compile_data = None

            optimized_output_code.post_compile(
                pcd.example_inputs, pcd.constants, pcd.graph_kwargs
            )

        self._optimized_output_code = optimized_output_code
        self._fast_output_code = None
        self._current_progression_index = stage_index

        # Clear earlier progression futures to free memory
        for i in range(stage_index):
            self._progression_futures[i] = None

    @override
    def post_compile(
        self,
        example_inputs: Sequence[InputType],
        constants: CompiledFxGraphConstants,
        graph_kwargs: _CompileFxKwargs,
    ) -> None:
        assert self._fast_output_code is not None
        self._fast_output_code.post_compile(example_inputs, constants, graph_kwargs)

        # Store for later when optimized version is ready
        self._post_compile_data = _PostCompileData(
            example_inputs, constants, graph_kwargs
        )
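The swapping above boils down to a small pattern: keep a future per progression stage, scan them latest-first on every call, and permanently switch to the first completed stage found, so earlier stages can be skipped entirely. A standalone sketch of just that pattern with plain callables (none of the real OutputCode or wire-protocol types are involved):

from collections.abc import Callable, Sequence
from concurrent.futures import Future
from typing import Any, Optional


class HotSwapSketch:
    """Simplified stand-in for _ProgressiveOutputCode's swapping logic."""

    def __init__(self, fast: Callable[..., Any], futures: Sequence[Future]) -> None:
        self._fast: Optional[Callable[..., Any]] = fast
        self._optimized: Optional[Callable[..., Any]] = None
        self._futures: list[Optional[Future]] = list(futures)
        self._index = -1  # progression stage currently in use

    def __call__(self, *args: Any) -> Any:
        # Scan from the most optimized stage backwards; a stage that finishes
        # late is never visited if a later stage completed first.
        for i in range(len(self._futures) - 1, self._index, -1):
            fut = self._futures[i]
            if fut is not None and fut.done():
                self._optimized = fut.result()
                self._fast = None
                self._index = i
                break
        fn = self._optimized if self._optimized is not None else self._fast
        assert fn is not None
        return fn(*args)


stage: Future = Future()
hot = HotSwapSketch(lambda x: x + 1, [stage])
print(hot(1))  # 2   -- fast path, the background stage has not finished yet
stage.set_result(lambda x: x + 100)
print(hot(1))  # 101 -- hot-swapped onto the finished stage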
# _ProgressiveFxCompile runs a fast compile immediately, then kicks off
# progressive compiles in the background and hot-swaps when they're ready.
@final
class _ProgressiveFxCompile(FxCompile):
    _fast_compile: FxCompile
    _optimized_compile: _OutOfProcessFxCompile
    _progression_configs: list[dict[str, Any]]

    # Debugging stats
    _stat_bg_started: int = 0
    _stat_bg_finished: int = 0
    _stat_fast_runs: int = 0
    _stat_optimized_runs: int = 0

    def __init__(
        self,
        fast_compile: FxCompile,
        optimized_compile: _OutOfProcessFxCompile,
        progression_configs: list[dict[str, Any]],
    ) -> None:
        self._fast_compile = fast_compile
        self._optimized_compile = optimized_compile
        self._progression_configs = progression_configs

    @classmethod
    def _reset_stats(cls) -> None:
        cls._stat_bg_started = 0
        cls._stat_bg_finished = 0
        cls._stat_fast_runs = 0
        cls._stat_optimized_runs = 0

    @override
    def codegen_and_compile(
        self,
        gm: GraphModule,
        example_inputs: Sequence[InputType],
        inputs_to_check: Sequence[int],
        graph_kwargs: _CompileFxKwargs,
    ) -> OutputCode:
        import torch._inductor.config as inductor_config

        progression_futures: list[Future[_WireProtocolPickledOutput]] = []

        for config in self._progression_configs:
            with inductor_config.patch(config):
                _ProgressiveFxCompile._stat_bg_started += 1

                # Start the progressive compiles in the background
                serialized = self._optimized_compile.serialize_compile(
                    gm, example_inputs, inputs_to_check, graph_kwargs
                )

                if not serialized:
                    continue

                inputs, constants = serialized

                future = self._optimized_compile._send_to_child_async(inputs)
                progression_futures.append(future)

        fast_output_code = self._fast_compile.codegen_and_compile(
            gm, example_inputs, inputs_to_check, graph_kwargs
        )

        if not progression_futures:
            # All async compile attempts failed - just return the fast version
            return fast_output_code

        # Callback to handle the optimized result.
        # This callback may be called multiple times, once for each progressive level completed,
        # but may be skipped if a level either never completes or if a more optimal level
        # completes before a less optimal one is switched to.
        def callback(pickled_output: _WireProtocolPickledOutput) -> OutputCode:
            _ProgressiveFxCompile._stat_bg_finished += 1
            output = pickled_output.deserialize(constants)
            self._optimized_compile._postprocess(output)
            return output.graph

        return _ProgressiveOutputCode(fast_output_code, progression_futures, callback)
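Putting the pieces together, the control flow for a progressive compile is roughly the following (comment-only summary; every name referenced comes from this patch):

# fx_codegen_and_compile() in compile_fx.py
#   scheme = _ProgressiveFxCompile(_InProcessFxCompile(),    # fast, foreground
#                                  <out-of-process scheme>,  # optimized
#                                  _get_progression_configs())
#   scheme.codegen_and_compile(...)
#     - under each progression config, serializes the graph and starts a
#       background compile via _send_to_child_async()    -> _stat_bg_started
#     - compiles the fast version in-process
#     - returns _ProgressiveOutputCode(fast, futures, callback)
#
# _ProgressiveOutputCode.__call__()
#     - while no future has completed, runs the fast artifact -> _stat_fast_runs
#     - once a future completes (and post_compile() has stored
#       _PostCompileData), callback() deserializes the result
#       (-> _stat_bg_finished), post_compile() is replayed on it, and all
#       later calls use the optimized artifact           -> _stat_optimized_runs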