mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Use process pool for precompilation of triton templates (#142450)
Perf results: https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Tue%2C%2003%20Dec%202024%2022%3A57%3A51%20GMT&stopTime=Tue%2C%2010%20Dec%202024%2022%3A57%3A51%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(a100)&lBranch=gh/eellison/740/head&lCommit=b925256c29ec43e1933e4ede94b16d1f404b595f&rBranch=gh/eellison/740/base&rCommit=a161d6362f7d9db773322d2ce2a3a70aabbecf4b Training: <img width="793" alt="image" src="https://github.com/user-attachments/assets/75f5bc0d-8005-4213-ae88-0b94fb187dfc" /> Pull Request resolved: https://github.com/pytorch/pytorch/pull/142450 Approved by: https://github.com/jansel
This commit is contained in:
parent
c06b5048ba
commit
e890d67543
|
|
@ -168,7 +168,7 @@ class AsyncCompile:
|
|||
return task()
|
||||
return cls.pool().submit(task)
|
||||
|
||||
def _use_process_pool(self):
|
||||
def use_process_pool(self):
|
||||
return (
|
||||
get_compile_threads() > 1
|
||||
and self.process_pool().ready_future.done() # type: ignore[attr-defined]
|
||||
|
|
@ -185,7 +185,7 @@ class AsyncCompile:
|
|||
)
|
||||
|
||||
kernel = TritonCodeCache.load(kernel_name, source_code)
|
||||
if self._use_process_pool():
|
||||
if self.use_process_pool():
|
||||
set_feature_use(
|
||||
"pytorch/inductor:enable_parallel_compile_version (post_warmup)", True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -90,6 +90,8 @@ PRINT_AUTOTUNE = True
|
|||
DEBUG = False
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import concurrent
|
||||
|
||||
from torch._inductor.codegen.simd import IterationRangesRoot
|
||||
|
||||
|
||||
|
|
@ -1752,16 +1754,35 @@ class AlgorithmSelectorCache(PersistentCache):
|
|||
|
||||
def precompile_with_captured_stdout(choice):
|
||||
with restore_stdout_stderr(initial_stdout, initial_stderr):
|
||||
start_time = time.time()
|
||||
choice.precompile()
|
||||
return time.time() - start_time
|
||||
|
||||
def on_complete(future):
|
||||
assert future in start_times
|
||||
elapsed_times[future] = time.time() - start_times[future]
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=num_workers)
|
||||
async_compile = torch._inductor.async_compile.AsyncCompile()
|
||||
|
||||
futures: Dict[concurrent.futures.Future[Any], ChoiceCaller] = {}
|
||||
start_times: Dict[concurrent.futures.Future[Any], float] = {}
|
||||
elapsed_times: Dict[concurrent.futures.Future[Any], float] = {}
|
||||
|
||||
futures = {}
|
||||
for c in choices:
|
||||
if hasattr(c, "precompile"):
|
||||
future = executor.submit(precompile_with_captured_stdout, c)
|
||||
triton_cuda_choice = isinstance(
|
||||
c, TritonTemplateCaller
|
||||
) and isinstance(c.bmreq, TritonGPUBenchmarkRequest)
|
||||
if triton_cuda_choice and async_compile.use_process_pool():
|
||||
with open(c.bmreq.module_path) as file:
|
||||
source_code = file.read()
|
||||
future = async_compile.triton(
|
||||
kernel_name=c.bmreq.kernel_name, source_code=source_code
|
||||
).future
|
||||
else:
|
||||
future = executor.submit(precompile_with_captured_stdout, c)
|
||||
|
||||
start_times[future] = time.time()
|
||||
future.add_done_callback(on_complete)
|
||||
futures[future] = c
|
||||
|
||||
@functools.lru_cache(None)
|
||||
|
|
@ -1780,7 +1801,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
|||
log.info(
|
||||
"Precompiling benchmark choice %s took %.02fs",
|
||||
futures[future],
|
||||
future.result(),
|
||||
elapsed_times[future],
|
||||
)
|
||||
|
||||
executor.shutdown(wait=True)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user