From d6856680039e5557b45e4cd6e95f82ca64f6435a Mon Sep 17 00:00:00 2001
From: Sam Larsen
Date: Sat, 9 Sep 2023 09:00:43 -0700
Subject: [PATCH] [inductor] Parallelize Max Autotune step 1: Use Popen
 (#107982)

Summary: Step 1 in revamping subprocess autotune to support multiple GPUs:
use Popen to create a new process with an entry point we control so we don't
reinterpret the toplevel script.

Test Plan: `python test/inductor/test_max_autotune.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107982
Approved by: https://github.com/eellison, https://github.com/shunting314
---
 torch/_inductor/autotune_process.py       | 219 +++++++++++++---------
 torch/_inductor/autotune_process_entry.py |   5 +
 torch/_inductor/select_algorithm.py       |   1 -
 3 files changed, 131 insertions(+), 94 deletions(-)
 create mode 100644 torch/_inductor/autotune_process_entry.py

diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py
index 66d73749ea7..47d297ce922 100644
--- a/torch/_inductor/autotune_process.py
+++ b/torch/_inductor/autotune_process.py
@@ -1,13 +1,14 @@
 import dataclasses
-import queue
+import logging
+import pickle
+import subprocess
+import sys
 import time
 import warnings
-from multiprocessing.process import BaseProcess
-from multiprocessing.queues import Queue
+
 from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
 
 import torch
-from torch import multiprocessing
 from torch._dynamo.testing import rand_strided
 from torch._inductor import ir
 
@@ -19,9 +20,11 @@ if TYPE_CHECKING:
 from .utils import do_bench
 from .virtualized import V
 
-DEBUG = False
+
 EXIT_HANDLER_REGISTERED = False
 
+log = logging.getLogger(__name__)
+
 
 # Used to synchronize between parent and child processes
 class Ping:
@@ -34,58 +37,62 @@ class Pong:
 
 @dataclasses.dataclass
 class TuningProcess:
-    process: Optional[BaseProcess] = None
-    request_queue: Optional["Queue[Any]"] = None
-    response_queue: Optional["Queue[Any]"] = None
+    """
+    Abstraction for launching a helper process to benchmark kernels. Rather
+    than using multiprocessing's "spawn" (which re-executes the toplevel
+    script), Popen a new process with an entry point that we control, so the
+    toplevel script is never re-entered. The subprocess communicates with the
+    parent process via pickled requests/responses over stdin/stdout pipes.
+    """
+
+    process: Optional["subprocess.Popen[bytes]"] = None
 
     @staticmethod
-    def process_main(
-        request_queue: "Queue[Any]",
-        response_queue: "Queue[Any]",
-    ) -> None:
-        print("enter child process main")
-        while True:
-            obj = request_queue.get()
+    def process_main() -> None:
+        """
+        Entry point for the child process.
+        """
+        log.debug("Entering TuningProcess child main")
+        try:
+            TuningProcess.workloop()
+        except Exception:
+            log.exception("Exception in TuningProcess")
+
+    @staticmethod
+    def workloop() -> None:
+        """
+        Work loop for the benchmarking subprocess.
+        """
+
+        def reply(obj):
+            # Note this is subtly different from the put() method below:
+            # reply() runs in the child and writes the pickled response to
+            # its own stdout, whereas put() runs in the parent and writes
+            # requests to the child's stdin.
+            pickle.dump(obj, sys.stdout.buffer)
+            sys.stdout.flush()
+
+        while True:
+            obj = pickle.load(sys.stdin.buffer)
             if obj is None:
-                break  # None is a sentinel for the child to terminate
+                # None is a sentinel for the child to terminate
+                break
             elif isinstance(obj, Ping):
-                response_queue.put(Pong())
+                reply(Pong())
             elif isinstance(obj, BenchmarkRequest):
-                response_queue.put(obj.benchmark())
+                reply(obj.benchmark())
             else:
                 raise RuntimeError(f"Invalid request type {type(obj)}")
 
-    def valid(self) -> bool:
-        return (
-            self.process is not None
-            and self.request_queue is not None
-            and self.response_queue is not None
-        )
-
-    def clear(self) -> None:
-        self.process = self.request_queue = self.response_queue = None
-
     def initialize(self) -> None:
         """
-        Create child process, request/response queues and do the warm up.
+        Create the child process and do the warm-up.
         """
-        if self.valid():
+        if self.process is not None:
             return
 
-        # cuda runtime does not work with "fork", use "spawn" to start processes.
-        ctx = multiprocessing.get_context("spawn")
-        request_queue = self.request_queue = ctx.Queue()
-        response_queue = self.response_queue = ctx.Queue()
-
-        process = self.process = ctx.Process(
-            target=self.process_main,
-            args=(
-                self.request_queue,
-                self.response_queue,
-            ),
+        self.process = subprocess.Popen(
+            [sys.executable, "-m", "torch._inductor.autotune_process_entry"],
+            stdin=subprocess.PIPE,
+            stdout=subprocess.PIPE,
         )
-        process.start()
 
         # register the exit handler for the parent process so it will terminate
        # the child processes
@@ -97,18 +104,58 @@ class TuningProcess:
             atexit.register(lambda: self.terminate())
 
         # wait for the initialization to be done
-        request_queue.put(Ping())
-        resp = response_queue.get()
+        self.put(Ping())
+        resp = self.get()
         assert isinstance(resp, Pong)
 
+    def put(self, obj: Any) -> None:
+        """
+        Push a work item to the child process.
+        """
+        # In case of a prior crash, ensure the subprocess is running
+        self.initialize()
+        assert self.process is not None
+        assert self.process.stdin is not None
+        pickle.dump(obj, self.process.stdin)
+        self.process.stdin.flush()
+
+    def get(self) -> Any:
+        """
+        Get a response from the child process.
+        """
+        assert self.process is not None
+        assert self.process.stdout is not None
+        try:
+            return pickle.load(self.process.stdout)
+        except EOFError:
+            # Child crashed; clean up
+            self.close()
+            raise
+        except pickle.UnpicklingError as ex:
+            raise RuntimeError(
+                "Error deserializing response from the benchmarking subprocess. "
+                "Is the benchmark code path writing to stdout?"
+            ) from ex
+
+    def close(self) -> None:
+        """
+        Close the communication pipes to and from the child process.
+        """
+        if self.process is not None:
+            assert self.process.stdin is not None
+            assert self.process.stdout is not None
+            self.process.stdin.close()
+            self.process.stdout.close()
+            self.process = None
+
     def terminate(self) -> None:
-        if self.valid():
-            request_queue = self.request_queue
-            assert request_queue is not None
-            request_queue.put(None)
-            process = self.process
-            assert process is not None
-            process.join()
+        """
+        Signal the child process to terminate and wait for it to exit.
+ """ + if self.process is not None: + self.put(None) + self.process.wait() + self.close() tuning_process = TuningProcess() @@ -180,18 +227,20 @@ class BenchmarkRequest: def benchmark( self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None ) -> float: - if DEBUG: + debug = log.isEnabledFor(logging.DEBUG) + if debug: start_ts = time.time() mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path) - if DEBUG: - print( - f"benchmark module key: {self.module_cache_key}, path: {self.module_path}" - ) + log.debug( + "benchmark module key: %s, path: %s", + self.module_cache_key, + self.module_path, + ) run = getattr(mod, self.kernel_name).run - if DEBUG: + if debug: load_elapse = time.time() - start_ts start_ts = time.time() @@ -205,7 +254,7 @@ class BenchmarkRequest: assert isinstance(self.output_tensor, TensorMeta) output_tensor = self.output_tensor.to_tensor() - if DEBUG: + if debug: create_tensor_elapse = time.time() - start_ts start_ts = time.time() @@ -222,12 +271,16 @@ class BenchmarkRequest: out = do_bench(worker) torch.cuda.synchronize() # shake out any CUDA errors - if DEBUG: + if debug: bench_elapse = time.time() - start_ts - print( - f"InChidProcess {self.module_cache_key}: load {load_elapse}, " - + f"create tensor {create_tensor_elapse}, bench {bench_elapse}" + log.debug( + "InChildProcess %s: load %f, create tensor %f, bench %f", + self.module_cache_key, + load_elapse, + create_tensor_elapse, + bench_elapse, ) + return out @@ -239,35 +292,15 @@ def benchmark_in_sub_process( """ assert choice.bmreq is not None tuning_process.initialize() - assert tuning_process.valid() - process, request_queue, response_queue = ( - tuning_process.process, - tuning_process.request_queue, - tuning_process.response_queue, - ) - assert ( - process is not None and request_queue is not None and response_queue is not None - ) - request_queue.put(choice.bmreq) - while True: - try: - timing = response_queue.get(timeout=1.0) - except queue.Empty: - status = process.exitcode - if status is None: - # child process is still running - continue - # child process fail - assert status != 0 - - warnings.warn( - f"Fail to benchmark choice '{choice}'. It will be ignored. Please debug the root cause in case the choice can bring perf gains." # noqa: B950 line too long - ) - - tuning_process.clear() - - # return INF so this choice will be ignored - return float("inf") - - return timing + tuning_process.put(choice.bmreq) + try: + return tuning_process.get() + except EOFError: + warnings.warn( + f"Failed to benchmark choice '{choice}'. It will be ignored. " + "Please debug the root cause in case the choice can bring perf gains.", + stacklevel=2, + ) + # return INF so this choice will be ignored + return float("inf") diff --git a/torch/_inductor/autotune_process_entry.py b/torch/_inductor/autotune_process_entry.py new file mode 100644 index 00000000000..0707db8efc0 --- /dev/null +++ b/torch/_inductor/autotune_process_entry.py @@ -0,0 +1,5 @@ +from torch._inductor.autotune_process import TuningProcess + +# Entry point for the subprocess supporting the TuningProcess's benchmark operation. 
+if __name__ == "__main__":
+    TuningProcess.process_main()
diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py
index 64495368a6c..61ecf7a1e65 100644
--- a/torch/_inductor/select_algorithm.py
+++ b/torch/_inductor/select_algorithm.py
@@ -798,7 +798,6 @@ class AlgorithmSelectorCache(PersistentCache):
 
         # do the optional warmup
         tuning_process.initialize()
-        assert tuning_process.valid()
 
         autotune_start_ts = time.time()
         timings = self.lookup(
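
Illustrative note, not part of the patch: the protocol implemented by put(),
get(), and workloop() above is plain pickle framing over the child's
stdin/stdout pipes. Below is a minimal, self-contained sketch of the same
pattern; the inlined CHILD_PROG is invented for this example as a stand-in
for the real entry point, torch._inductor.autotune_process_entry.

import pickle
import subprocess
import sys

# Stand-in child: read pickled requests from stdin, write pickled replies
# to stdout, and exit when the None sentinel arrives.
CHILD_PROG = r"""
import pickle, sys
while True:
    obj = pickle.load(sys.stdin.buffer)
    if obj is None:  # None is the termination sentinel
        break
    pickle.dump(("pong", obj), sys.stdout.buffer)
    sys.stdout.flush()
"""

proc = subprocess.Popen(
    [sys.executable, "-c", CHILD_PROG],
    stdin=subprocess.PIPE,
    stdout=subprocess.PIPE,
)

# put(): pickle a request onto the child's stdin and flush.
pickle.dump({"request": "ping"}, proc.stdin)
proc.stdin.flush()

# get(): block on the child's stdout until the pickled response arrives.
print(pickle.load(proc.stdout))  # -> ('pong', {'request': 'ping'})

# terminate(): send the sentinel, then wait for the child to exit.
pickle.dump(None, proc.stdin)
proc.stdin.flush()
proc.wait()

Anything else the child writes to stdout (e.g. a stray print()) would corrupt
the pickle stream, which is exactly the failure mode the RuntimeError raised
in get() warns about.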