eellison 2024-12-17 11:48:13 -08:00 committed by PyTorch MergeBot
parent c06b5048ba
commit e890d67543
2 changed files with 28 additions and 7 deletions

View File

@ -168,7 +168,7 @@ class AsyncCompile:
return task()
return cls.pool().submit(task)
def _use_process_pool(self):
def use_process_pool(self):
return (
get_compile_threads() > 1
and self.process_pool().ready_future.done() # type: ignore[attr-defined]
@ -185,7 +185,7 @@ class AsyncCompile:
)
kernel = TritonCodeCache.load(kernel_name, source_code)
if self._use_process_pool():
if self.use_process_pool():
set_feature_use(
"pytorch/inductor:enable_parallel_compile_version (post_warmup)", True
)

View File

@ -90,6 +90,8 @@ PRINT_AUTOTUNE = True
DEBUG = False
if TYPE_CHECKING:
import concurrent
from torch._inductor.codegen.simd import IterationRangesRoot
@ -1752,16 +1754,35 @@ class AlgorithmSelectorCache(PersistentCache):
# Precompile one choice inside the
# restore_stdout_stderr(initial_stdout, initial_stderr) context and, as written
# here, return the wall-clock seconds spent in choice.precompile().
# NOTE(review): this text is a rendered diff with old and new lines interleaved;
# the two timing lines below look like the removed side, since elapsed time is
# also recorded by the on_complete callback — confirm against the commit.
def precompile_with_captured_stdout(choice):
with restore_stdout_stderr(initial_stdout, initial_stderr):
start_time = time.time()
choice.precompile()
return time.time() - start_time
# Callback invoked with a future: stores in elapsed_times the time between
# the timestamp recorded for that future in start_times and the moment this
# callback runs.
def on_complete(future):
# A start time must have been recorded for this future before the callback runs.
assert future in start_times
elapsed_times[future] = time.time() - start_times[future]
executor = ThreadPoolExecutor(max_workers=num_workers)
async_compile = torch._inductor.async_compile.AsyncCompile()
futures: Dict[concurrent.futures.Future[Any], ChoiceCaller] = {}
start_times: Dict[concurrent.futures.Future[Any], float] = {}
elapsed_times: Dict[concurrent.futures.Future[Any], float] = {}
futures = {}
for c in choices:
if hasattr(c, "precompile"):
future = executor.submit(precompile_with_captured_stdout, c)
triton_cuda_choice = isinstance(
c, TritonTemplateCaller
) and isinstance(c.bmreq, TritonGPUBenchmarkRequest)
if triton_cuda_choice and async_compile.use_process_pool():
with open(c.bmreq.module_path) as file:
source_code = file.read()
future = async_compile.triton(
kernel_name=c.bmreq.kernel_name, source_code=source_code
).future
else:
future = executor.submit(precompile_with_captured_stdout, c)
start_times[future] = time.time()
future.add_done_callback(on_complete)
futures[future] = c
@functools.lru_cache(None)
@ -1780,7 +1801,7 @@ class AlgorithmSelectorCache(PersistentCache):
log.info(
"Precompiling benchmark choice %s took %.02fs",
futures[future],
future.result(),
elapsed_times[future],
)
executor.shutdown(wait=True)