mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Use process pool for precompilation of triton templates (#142450)
Perf results (inference): https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Tue%2C%2003%20Dec%202024%2022%3A57%3A51%20GMT&stopTime=Tue%2C%2010%20Dec%202024%2022%3A57%3A51%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(a100)&lBranch=gh/eellison/740/head&lCommit=b925256c29ec43e1933e4ede94b16d1f404b595f&rBranch=gh/eellison/740/base&rCommit=a161d6362f7d9db773322d2ce2a3a70aabbecf4b

Training:
<img width="793" alt="image" src="https://github.com/user-attachments/assets/75f5bc0d-8005-4213-ae88-0b94fb187dfc" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142450
Approved by: https://github.com/jansel
This commit is contained in:
parent
c06b5048ba
commit
e890d67543
|
|
@ -168,7 +168,7 @@ class AsyncCompile:
|
||||||
return task()
|
return task()
|
||||||
return cls.pool().submit(task)
|
return cls.pool().submit(task)
|
||||||
|
|
||||||
def _use_process_pool(self):
|
def use_process_pool(self):
|
||||||
return (
|
return (
|
||||||
get_compile_threads() > 1
|
get_compile_threads() > 1
|
||||||
and self.process_pool().ready_future.done() # type: ignore[attr-defined]
|
and self.process_pool().ready_future.done() # type: ignore[attr-defined]
|
||||||
|
|
@ -185,7 +185,7 @@ class AsyncCompile:
|
||||||
)
|
)
|
||||||
|
|
||||||
kernel = TritonCodeCache.load(kernel_name, source_code)
|
kernel = TritonCodeCache.load(kernel_name, source_code)
|
||||||
if self._use_process_pool():
|
if self.use_process_pool():
|
||||||
set_feature_use(
|
set_feature_use(
|
||||||
"pytorch/inductor:enable_parallel_compile_version (post_warmup)", True
|
"pytorch/inductor:enable_parallel_compile_version (post_warmup)", True
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -90,6 +90,8 @@ PRINT_AUTOTUNE = True
|
||||||
DEBUG = False
|
DEBUG = False
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
import concurrent
|
||||||
|
|
||||||
from torch._inductor.codegen.simd import IterationRangesRoot
|
from torch._inductor.codegen.simd import IterationRangesRoot
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1752,16 +1754,35 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||||
|
|
||||||
def precompile_with_captured_stdout(choice):
|
def precompile_with_captured_stdout(choice):
|
||||||
with restore_stdout_stderr(initial_stdout, initial_stderr):
|
with restore_stdout_stderr(initial_stdout, initial_stderr):
|
||||||
start_time = time.time()
|
|
||||||
choice.precompile()
|
choice.precompile()
|
||||||
return time.time() - start_time
|
|
||||||
|
def on_complete(future):
|
||||||
|
assert future in start_times
|
||||||
|
elapsed_times[future] = time.time() - start_times[future]
|
||||||
|
|
||||||
executor = ThreadPoolExecutor(max_workers=num_workers)
|
executor = ThreadPoolExecutor(max_workers=num_workers)
|
||||||
|
async_compile = torch._inductor.async_compile.AsyncCompile()
|
||||||
|
|
||||||
|
futures: Dict[concurrent.futures.Future[Any], ChoiceCaller] = {}
|
||||||
|
start_times: Dict[concurrent.futures.Future[Any], float] = {}
|
||||||
|
elapsed_times: Dict[concurrent.futures.Future[Any], float] = {}
|
||||||
|
|
||||||
futures = {}
|
|
||||||
for c in choices:
|
for c in choices:
|
||||||
if hasattr(c, "precompile"):
|
if hasattr(c, "precompile"):
|
||||||
future = executor.submit(precompile_with_captured_stdout, c)
|
triton_cuda_choice = isinstance(
|
||||||
|
c, TritonTemplateCaller
|
||||||
|
) and isinstance(c.bmreq, TritonGPUBenchmarkRequest)
|
||||||
|
if triton_cuda_choice and async_compile.use_process_pool():
|
||||||
|
with open(c.bmreq.module_path) as file:
|
||||||
|
source_code = file.read()
|
||||||
|
future = async_compile.triton(
|
||||||
|
kernel_name=c.bmreq.kernel_name, source_code=source_code
|
||||||
|
).future
|
||||||
|
else:
|
||||||
|
future = executor.submit(precompile_with_captured_stdout, c)
|
||||||
|
|
||||||
|
start_times[future] = time.time()
|
||||||
|
future.add_done_callback(on_complete)
|
||||||
futures[future] = c
|
futures[future] = c
|
||||||
|
|
||||||
@functools.lru_cache(None)
|
@functools.lru_cache(None)
|
||||||
|
|
@ -1780,7 +1801,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||||
log.info(
|
log.info(
|
||||||
"Precompiling benchmark choice %s took %.02fs",
|
"Precompiling benchmark choice %s took %.02fs",
|
||||||
futures[future],
|
futures[future],
|
||||||
future.result(),
|
elapsed_times[future],
|
||||||
)
|
)
|
||||||
|
|
||||||
executor.shutdown(wait=True)
|
executor.shutdown(wait=True)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user