mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[async-compile] add progressive compile mode (#157305)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157305 Approved by: https://github.com/aorenste
This commit is contained in:
parent
386bc9e2e9
commit
d58ed04d89
|
|
@ -9,6 +9,7 @@ import importlib
|
|||
import os
|
||||
import sys
|
||||
import time
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
|
@ -16,7 +17,14 @@ import torch.library
|
|||
from torch._inductor.compile_fx import _InProcessFxCompile, FxCompile, FxCompileMode
|
||||
from torch._inductor.test_case import TestCase
|
||||
from torch.testing._internal.common_utils import TEST_WITH_ASAN
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE, RUN_CPU, RUN_GPU
|
||||
from torch.testing._internal.inductor_utils import (
|
||||
GPU_TYPE,
|
||||
IS_BIG_GPU,
|
||||
requires_gpu,
|
||||
requires_triton,
|
||||
RUN_CPU,
|
||||
RUN_GPU,
|
||||
)
|
||||
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
|
|
@ -75,6 +83,79 @@ class TestSubprocess(TestCase):
|
|||
TestCase.tearDown(self)
|
||||
torch._dynamo.reset()
|
||||
|
||||
@requires_gpu()
@requires_triton()
@unittest.skipIf(
    not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
)
def test_progressive(self):
    """Check that progressive compile hot-swaps to an optimized artifact.

    The compiled function is called repeatedly until four calls have been
    served by the optimized (background-compiled) code, or until TIMEOUT
    seconds have elapsed. The optimized function is then benchmarked
    against a baseline compiled with progressive mode turned off.
    """
    from triton.testing import do_bench

    from torch._inductor.compile_fx_async import _ProgressiveFxCompile

    # The flag is a module global. Scope every change with `patch` (as
    # test_async does) so the previous value is restored even when the test
    # fails or times out, instead of leaking into later tests.
    progressive_flag = "torch._inductor.compile_fx.fx_compile_progressive"

    x = torch.randn(1152, 1024, device=GPU_TYPE, dtype=torch.bfloat16)
    y = torch.randn(1024, 1024, device=GPU_TYPE, dtype=torch.bfloat16)

    @torch.compile(fullgraph=True, backend="inductor")
    def optimized(x, y):
        return (x @ y).relu()

    _ProgressiveFxCompile._reset_stats()

    with contextlib.ExitStack() as stack:
        stack.enter_context(patch(progressive_flag, True))

        # When this bug is fixed, remove the cache disabling below
        assert torch._inductor.compile_fx_async.BUG_CACHES_DONT_WORK_WITH_ASYNC
        stack.enter_context(
            torch._inductor.config.patch(
                autotune_local_cache=False, fx_graph_cache=False
            )
        )
        stack.enter_context(
            torch._functorch.config.patch(enable_autograd_cache=False)
        )

        # How long to wait (in seconds) before giving up.
        TIMEOUT = 300
        # If non-None then how often (in seconds) to print a TICK message.
        TICK_REPORT = None

        start = time.time()
        last_report = start
        while _ProgressiveFxCompile._stat_optimized_runs < 4:
            time.sleep(0.25)

            optimized(x, y)

            now = time.time()
            if TICK_REPORT is not None and (now - last_report > TICK_REPORT):
                print(f"*** TICK {int(now - start)}")
                last_report = now

            if now - start > TIMEOUT:
                raise RuntimeError(
                    "Test timed out before producing a progressively optimized compiled artifact."
                )

        self.assertEqual(_ProgressiveFxCompile._stat_optimized_runs, 4)
        self.assertGreater(_ProgressiveFxCompile._stat_fast_runs, 0)
        self.assertGreaterEqual(_ProgressiveFxCompile._stat_bg_started, 1)
        self.assertGreaterEqual(_ProgressiveFxCompile._stat_bg_finished, 1)

    # The baseline must be compiled with progressive mode explicitly off,
    # whatever the ambient default is.
    with patch(progressive_flag, False):

        @torch.compile(fullgraph=True, backend="inductor")
        def baseline(x, y):
            return (x @ y).relu()

        # Warmup
        baseline(x, y)

        self.assertGreater(
            do_bench(lambda: baseline(x, y)), do_bench(lambda: optimized(x, y))
        )
|
||||
|
||||
@patch("torch._inductor.compile_fx.fx_compile_async", True)
|
||||
def test_async(self):
|
||||
# Test that async+subprocess works.
|
||||
|
|
@ -90,7 +171,7 @@ class TestSubprocess(TestCase):
|
|||
_AsyncFxCompile._reset_stats()
|
||||
|
||||
with contextlib.ExitStack() as stack:
|
||||
# TODO: Turn off local caches - they don't play nice w/ async currently.
|
||||
assert torch._inductor.compile_fx_async.BUG_CACHES_DONT_WORK_WITH_ASYNC
|
||||
stack.enter_context(
|
||||
torch._inductor.config.patch(
|
||||
autotune_local_cache=False, fx_graph_cache=False
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ import warnings
|
|||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from contextlib import AbstractContextManager
|
||||
from dataclasses import dataclass
|
||||
from inspect import currentframe
|
||||
from itertools import count
|
||||
from operator import attrgetter
|
||||
|
|
@ -164,21 +165,32 @@ class FxCompileMode(enum.Enum):
|
|||
SUBPROCESS = 2
|
||||
|
||||
|
||||
# Return compile mode and use_async flag
|
||||
def _fx_compile_mode_default() -> tuple[FxCompileMode, bool]:
|
||||
@dataclass
class FxCompileConfig:
    """Bundle of the FX compile mode and its two modifier flags."""

    # Which FxCompileMode to compile with.
    mode: FxCompileMode
    # Whether async compilation was requested.
    use_async: bool
    # Whether progressive compilation was requested.
    use_progressive: bool
|
||||
|
||||
|
||||
def _fx_compile_mode_default() -> FxCompileConfig:
|
||||
name = "TORCHINDUCTOR_FX_COMPILE_MODE"
|
||||
value = os.environ.get(name)
|
||||
if value is None:
|
||||
return FxCompileMode.NORMAL, False
|
||||
return FxCompileConfig(FxCompileMode.NORMAL, False, False)
|
||||
|
||||
use_async = False
|
||||
use_progressive = False
|
||||
|
||||
if value.lower().startswith("progressive+"):
|
||||
use_progressive = True
|
||||
value = value[12:]
|
||||
if value.lower().startswith("async+"):
|
||||
use_async = True
|
||||
value = value[6:]
|
||||
|
||||
try:
|
||||
value = value.upper()
|
||||
return FxCompileMode[value], use_async
|
||||
return FxCompileConfig(FxCompileMode[value], use_async, use_progressive)
|
||||
except KeyError:
|
||||
import logging
|
||||
|
||||
|
|
@ -191,10 +203,20 @@ def _fx_compile_mode_default() -> tuple[FxCompileMode, bool]:
|
|||
)
|
||||
# Remove from the environment so subprocesses don't ALSO complain.
|
||||
os.environ.pop(name)
|
||||
return FxCompileMode.NORMAL, False
|
||||
return FxCompileConfig(FxCompileMode.NORMAL, False, False)
|
||||
|
||||
|
||||
fx_compile_mode, fx_compile_async = _fx_compile_mode_default()
|
||||
def _get_progression_configs() -> list[dict[str, Any]]:
|
||||
# TODO make this configurable
|
||||
return [
|
||||
{"max_autotune": True},
|
||||
]
|
||||
|
||||
|
||||
# Resolve the compile configuration once at import time, then expose each
# field as its own module-level flag.
_fx_compile_config = _fx_compile_mode_default()
fx_compile_mode, fx_compile_async, fx_compile_progressive = (
    _fx_compile_config.mode,
    _fx_compile_config.use_async,
    _fx_compile_config.use_progressive,
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
|
||||
|
|
@ -1576,6 +1598,21 @@ def fx_codegen_and_compile(
|
|||
)
|
||||
scheme = _AsyncFxCompile(scheme)
|
||||
|
||||
if fx_compile_progressive:
|
||||
from .compile_fx_async import _ProgressiveFxCompile
|
||||
from .compile_fx_ext import _OutOfProcessFxCompile
|
||||
|
||||
assert isinstance(scheme, _OutOfProcessFxCompile), (
|
||||
"progressive is only valid with an out-of-process compile mode"
|
||||
)
|
||||
|
||||
progression_configs = _get_progression_configs()
|
||||
|
||||
# Use in-process compile for the fast version
|
||||
fast_scheme = _InProcessFxCompile()
|
||||
|
||||
scheme = _ProgressiveFxCompile(fast_scheme, scheme, progression_configs)
|
||||
|
||||
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,10 @@ from .compile_fx import _CompileFxKwargs, _InProcessFxCompile, FxCompile
|
|||
from .output_code import complex_memory_overlap as complex_memory_overlap # noqa: F401
|
||||
|
||||
|
||||
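# Async compile does not currently work with the caches, so code that uses it
# asserts this flag and then disables the caches. Once async compile works with
# the caches, remove this flag together with the cache disabling it guards.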
|
||||
BUG_CACHES_DONT_WORK_WITH_ASYNC = True
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import Future
|
||||
|
|
@ -179,3 +183,179 @@ class _AsyncFxCompile(FxCompile):
|
|||
return output.graph
|
||||
|
||||
return _AsyncOutputCode(eager_output_code, f, callback)
|
||||
|
||||
|
||||
# _ProgressiveOutputCode handles running a fast compile first, then hot-swapping
|
||||
# to a more optimized version when the expensive compile finishes.
|
||||
@final
class _ProgressiveOutputCode(OutputCode):
    """OutputCode that starts on a fast compile and hot-swaps to optimized ones.

    Each call first checks whether a later progression stage has finished
    compiling. If so, it switches to the most advanced finished stage and
    runs that; otherwise it keeps running the current code. No switch
    happens until `post_compile` has been called, because the stored
    post-compile data has to be replayed on the newly adopted code.
    """

    # The fast compile result; dropped once an optimized stage is adopted.
    _fast_output_code: Optional[OutputCode]
    # The most advanced optimized stage adopted so far, if any.
    _optimized_output_code: Optional[OutputCode]
    # One future per progression stage; entries become None once released.
    _progression_futures: list[Optional[Future[_WireProtocolPickledOutput]]]
    # Converts a finished stage's pickled result into an OutputCode.
    _callback: Callable[[_WireProtocolPickledOutput], OutputCode]
    # Arguments of `post_compile`, kept so they can be replayed on later stages.
    _post_compile_data: Optional[_PostCompileData] = None
    # Index of the stage currently in use; -1 while still on the fast code.
    _current_progression_index: int
    # _boxed_call state is effectively cached (we sometimes wrap unboxed w/
    # lambdas to box them) so we can't change it mid-way. Since _boxed_call=True
    # is more common let's default to that and we'll convert if necessary.
    _boxed_call: bool = True

    def __init__(
        self,
        # Fast compile that runs faster than the progressive compiles
        fast_output_code: OutputCode,
        # Futures for the progressive optimized compiles
        progression_futures: list[Future[_WireProtocolPickledOutput]],
        # Callback to convert the optimized result to OutputCode
        callback: Callable[[_WireProtocolPickledOutput], OutputCode],
    ) -> None:
        self._fast_output_code = fast_output_code
        self._optimized_output_code = None
        # Copy the list: entries are overwritten with None as stages are consumed.
        self._progression_futures = list(progression_futures)
        self._callback = callback
        self._current_progression_index = -1

    @override
    def __call__(self, args: Sequence[Any]) -> Any:
        """Run the best available code, switching stages first if one is ready."""
        # Check if any newer progression stage is ready and switch to it
        self._check_and_switch_progression()

        if self._optimized_output_code is not None:
            _ProgressiveFxCompile._stat_optimized_runs += 1
            output_code = self._optimized_output_code
        else:
            _ProgressiveFxCompile._stat_fast_runs += 1
            assert self._fast_output_code is not None
            output_code = self._fast_output_code

        # This object always presents itself as boxed, so adapt the call to
        # whatever convention the wrapped code uses.
        boxed_call = getattr(output_code, "_boxed_call", False)
        if boxed_call:
            res = output_code.__call__(args)
        else:
            res = output_code.__call__(*args)
        return res

    def _check_and_switch_progression(self) -> None:
        """Adopt the most advanced finished stage beyond the current one, if any."""
        # Check if any newer progression stage is ready (in order from latest to earliest)
        for i in range(
            len(self._progression_futures) - 1, self._current_progression_index, -1
        ):
            future = self._progression_futures[i]
            if self._post_compile_data and future and future.done():
                self._switch_to_progression_stage(i)
                break

    def _switch_to_progression_stage(self, stage_index: int) -> None:
        """Make the finished stage at `stage_index` the code that gets run."""
        future = self._progression_futures[stage_index]
        assert future is not None
        optimized_output_code = self._callback(future.result())

        if pcd := self._post_compile_data:
            # Only clear post_compile_data if this is the final progression stage
            if stage_index == len(self._progression_futures) - 1:
                self._post_compile_data = None
            optimized_output_code.post_compile(
                pcd.example_inputs, pcd.constants, pcd.graph_kwargs
            )

        self._optimized_output_code = optimized_output_code
        self._fast_output_code = None
        self._current_progression_index = stage_index

        # Release the consumed future as well as every earlier one to free
        # memory. None of these slots is read again: the readiness scan only
        # looks at indices above _current_progression_index.
        for i in range(stage_index + 1):
            self._progression_futures[i] = None

    @override
    def post_compile(
        self,
        example_inputs: Sequence[InputType],
        constants: CompiledFxGraphConstants,
        graph_kwargs: _CompileFxKwargs,
    ) -> None:
        """Post-compile the fast code and remember the arguments for later stages."""
        assert self._fast_output_code is not None
        self._fast_output_code.post_compile(example_inputs, constants, graph_kwargs)
        # Store for later when optimized version is ready
        self._post_compile_data = _PostCompileData(
            example_inputs, constants, graph_kwargs
        )
|
||||
|
||||
|
||||
# _ProgressiveFxCompile runs a fast compile immediately, then kicks off
|
||||
# progressive compiles in the background and hot-swaps when they're ready.
|
||||
@final
class _ProgressiveFxCompile(FxCompile):
    """Compile quickly now and hot-swap to optimized compiles when they finish.

    `codegen_and_compile` sends one background compile per progression
    config to the out-of-process compiler, runs the fast compile
    immediately, and returns a `_ProgressiveOutputCode` wrapping both.
    """

    _fast_compile: FxCompile
    _optimized_compile: _OutOfProcessFxCompile
    _progression_configs: list[dict[str, Any]]

    # Debugging stats
    _stat_bg_started: int = 0
    _stat_bg_finished: int = 0
    _stat_fast_runs: int = 0
    _stat_optimized_runs: int = 0

    def __init__(
        self,
        fast_compile: FxCompile,
        optimized_compile: _OutOfProcessFxCompile,
        progression_configs: list[dict[str, Any]],
    ) -> None:
        self._fast_compile = fast_compile
        self._optimized_compile = optimized_compile
        self._progression_configs = progression_configs

    @classmethod
    def _reset_stats(cls) -> None:
        """Zero all of the class-level debugging counters."""
        cls._stat_bg_started = 0
        cls._stat_bg_finished = 0
        cls._stat_fast_runs = 0
        cls._stat_optimized_runs = 0

    @override
    def codegen_and_compile(
        self,
        gm: GraphModule,
        example_inputs: Sequence[InputType],
        inputs_to_check: Sequence[int],
        graph_kwargs: _CompileFxKwargs,
    ) -> OutputCode:
        """Start the background compiles, then run and return the fast compile.

        Returns the plain fast OutputCode when no background compile could
        be started, otherwise a `_ProgressiveOutputCode`.
        """
        import torch._inductor.config as inductor_config

        progression_futures: list[Future[_WireProtocolPickledOutput]] = []

        for config in self._progression_configs:
            # Serialize under the patched config so this stage's overrides
            # are in effect for it.
            with inductor_config.patch(config):
                # Start the progressive compiles in the background
                serialized = self._optimized_compile.serialize_compile(
                    gm, example_inputs, inputs_to_check, graph_kwargs
                )

                if not serialized:
                    continue

                inputs, constants = serialized
                future = self._optimized_compile._send_to_child_async(inputs)
                # Count only compiles that were actually handed to the child,
                # so the stat does not include configs that failed to serialize.
                _ProgressiveFxCompile._stat_bg_started += 1
                progression_futures.append(future)

        fast_output_code = self._fast_compile.codegen_and_compile(
            gm, example_inputs, inputs_to_check, graph_kwargs
        )

        if not progression_futures:
            # All async compile attempts failed - just return the fast version
            return fast_output_code

        # Callback to handle the optimized result.
        # This callback may be called multiple times, once for each progressive level completed,
        # but may be skipped if a level either never completes or if a more optimal level
        # completes before a less optimal one is switched to.
        def callback(pickled_output: _WireProtocolPickledOutput) -> OutputCode:
            _ProgressiveFxCompile._stat_bg_finished += 1
            # NOTE(review): `constants` is whatever the last successful
            # serialization bound, and is used for every stage. Presumably
            # the constants are the same for all configs - confirm.
            output = pickled_output.deserialize(constants)
            self._optimized_compile._postprocess(output)
            return output.graph

        return _ProgressiveOutputCode(fast_output_code, progression_futures, callback)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user