diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index 351750fe649..f742698f4ab 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -168,7 +168,7 @@ class AsyncCompile: return task() return cls.pool().submit(task) - def _use_process_pool(self): + def use_process_pool(self): return ( get_compile_threads() > 1 and self.process_pool().ready_future.done() # type: ignore[attr-defined] @@ -185,7 +185,7 @@ class AsyncCompile: ) kernel = TritonCodeCache.load(kernel_name, source_code) - if self._use_process_pool(): + if self.use_process_pool(): set_feature_use( "pytorch/inductor:enable_parallel_compile_version (post_warmup)", True ) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index e0e0c9afd6a..07bfd3a78dd 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -90,6 +90,8 @@ PRINT_AUTOTUNE = True DEBUG = False if TYPE_CHECKING: + import concurrent + from torch._inductor.codegen.simd import IterationRangesRoot @@ -1752,16 +1754,35 @@ class AlgorithmSelectorCache(PersistentCache): def precompile_with_captured_stdout(choice): with restore_stdout_stderr(initial_stdout, initial_stderr): - start_time = time.time() choice.precompile() - return time.time() - start_time + + def on_complete(future): + assert future in start_times + elapsed_times[future] = time.time() - start_times[future] executor = ThreadPoolExecutor(max_workers=num_workers) + async_compile = torch._inductor.async_compile.AsyncCompile() + + futures: Dict[concurrent.futures.Future[Any], ChoiceCaller] = {} + start_times: Dict[concurrent.futures.Future[Any], float] = {} + elapsed_times: Dict[concurrent.futures.Future[Any], float] = {} - futures = {} for c in choices: if hasattr(c, "precompile"): - future = executor.submit(precompile_with_captured_stdout, c) + triton_cuda_choice = isinstance( + c, TritonTemplateCaller + ) and isinstance(c.bmreq, TritonGPUBenchmarkRequest) + if triton_cuda_choice and async_compile.use_process_pool(): + with open(c.bmreq.module_path) as file: + source_code = file.read() + future = async_compile.triton( + kernel_name=c.bmreq.kernel_name, source_code=source_code + ).future + else: + future = executor.submit(precompile_with_captured_stdout, c) + + start_times[future] = time.time() + future.add_done_callback(on_complete) futures[future] = c @functools.lru_cache(None) @@ -1780,7 +1801,7 @@ class AlgorithmSelectorCache(PersistentCache): log.info( "Precompiling benchmark choice %s took %.02fs", futures[future], - future.result(), + elapsed_times[future], ) executor.shutdown(wait=True)