[inductor] Parallelize Max Autotune step 1: Use Popen (#107982)

Summary: Step 1 in revamping subprocess autotune to support multiple GPUs: use Popen to create a new process with an entry point we control so we don't reinterpret the toplevel script.
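For context on the mechanism: the parent Popens a Python child with an entry point it controls and exchanges pickled objects over the child's stdin/stdout pipes. The following is a minimal, self-contained sketch of that pattern only; it inlines the child via `-c` and uses made-up payloads, whereas the actual change uses a dedicated entry-point module and the request/response classes shown in the diff below.

import pickle
import subprocess
import sys
import textwrap

# Child program: read pickled requests from stdin, write pickled replies to
# stdout, exit on the None sentinel. The real change puts this loop in its own
# module so that "python -m <module>" gives an entry point we control, instead
# of re-executing whatever script started the parent (as "spawn" would).
child_src = textwrap.dedent(
    """
    import pickle, sys
    while True:
        req = pickle.load(sys.stdin.buffer)
        if req is None:
            break
        pickle.dump(("pong", req), sys.stdout.buffer)
        sys.stdout.buffer.flush()
    """
)

child = subprocess.Popen(
    [sys.executable, "-c", child_src],
    stdin=subprocess.PIPE,
    stdout=subprocess.PIPE,
)

# Parent side: pickle a request onto the child's stdin, flush, read the reply.
pickle.dump("ping", child.stdin)
child.stdin.flush()
print(pickle.load(child.stdout))  # ('pong', 'ping')

# None tells the child to exit; then wait for it.
pickle.dump(None, child.stdin)
child.stdin.flush()
child.wait()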

Test Plan: `python test/inductor/test_max_autotune.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107982
Approved by: https://github.com/eellison, https://github.com/shunting314
Sam Larsen 2023-09-09 09:00:43 -07:00 committed by PyTorch MergeBot
parent 89eb7a75a2
commit d685668003
3 changed files with 131 additions and 94 deletions


@@ -1,13 +1,14 @@
import dataclasses
import queue
import logging
import pickle
import subprocess
import sys
import time
import warnings
from multiprocessing.process import BaseProcess
from multiprocessing.queues import Queue
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import torch
from torch import multiprocessing
from torch._dynamo.testing import rand_strided
from torch._inductor import ir
@@ -19,9 +20,11 @@ if TYPE_CHECKING:
from .utils import do_bench
from .virtualized import V
DEBUG = False
EXIT_HANDLER_REGISTERED = False
log = logging.getLogger(__name__)
# Used to synchronize between parent and child processes
class Ping:
@@ -34,58 +37,62 @@ class Pong:
@dataclasses.dataclass
class TuningProcess:
process: Optional[BaseProcess] = None
request_queue: Optional["Queue[Any]"] = None
response_queue: Optional["Queue[Any]"] = None
"""
Abstraction for launching a helper process to benchmark kernels. Rather
than using multiprocessing's "spawn" (which would re-execute the parent's
top-level script), the approach Popens a new process with an entry point
that we control. The subprocess communicates with the parent process via
pickled requests/responses sent over stdin/stdout pipes.
"""
process: Optional["subprocess.Popen[bytes]"] = None
@staticmethod
def process_main(
request_queue: "Queue[Any]",
response_queue: "Queue[Any]",
) -> None:
print("enter child process main")
while True:
obj = request_queue.get()
def process_main() -> None:
"""
Entry point for the child process.
"""
log.debug("Entering TuningProcess child main")
try:
TuningProcess.workloop()
except Exception:
log.exception("Exception in TuningProcess")
@staticmethod
def workloop() -> None:
"""
Work loop for the benchmarking subprocess.
"""
def reply(obj):
# Note this is subtly different than the put() method below.
pickle.dump(obj, sys.stdout.buffer)
sys.stdout.flush()
while True:
obj = pickle.load(sys.stdin.buffer)
if obj is None:
break # None is a sentinel for the child to terminate
# None is a sentinel for the child to terminate
break
elif isinstance(obj, Ping):
response_queue.put(Pong())
reply(Pong())
elif isinstance(obj, BenchmarkRequest):
response_queue.put(obj.benchmark())
reply(obj.benchmark())
else:
raise RuntimeError(f"Invalid request type {type(obj)}")
def valid(self) -> bool:
return (
self.process is not None
and self.request_queue is not None
and self.response_queue is not None
)
def clear(self) -> None:
self.process = self.request_queue = self.response_queue = None
def initialize(self) -> None:
"""
Create child process, request/response queues and do the warm up.
Create child process and do the warm up.
"""
if self.valid():
if self.process is not None:
return
# cuda runtime does not work with "fork", use "spawn" to start processes.
ctx = multiprocessing.get_context("spawn")
request_queue = self.request_queue = ctx.Queue()
response_queue = self.response_queue = ctx.Queue()
process = self.process = ctx.Process(
target=self.process_main,
args=(
self.request_queue,
self.response_queue,
),
self.process = subprocess.Popen(
[sys.executable, "-m", "torch._inductor.autotune_process_entry"],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
)
process.start()
# register the exit handler for the parent process so it will terminate
# the child processes
@@ -97,18 +104,58 @@ class TuningProcess:
atexit.register(lambda: self.terminate())
# wait for the initialization to be done
request_queue.put(Ping())
resp = response_queue.get()
self.put(Ping())
resp = self.get()
assert isinstance(resp, Pong)
def put(self, obj: Any) -> None:
"""
Push a work item to the child process.
"""
# In case of a prior crash, ensure the subprocess is running
self.initialize()
assert self.process is not None
assert self.process.stdin is not None
pickle.dump(obj, self.process.stdin)
self.process.stdin.flush()
def get(self) -> Any:
"""
Get a response from the child process.
"""
assert self.process is not None
assert self.process.stdout is not None
try:
return pickle.load(self.process.stdout)
except EOFError:
# Child crashed; clean up
self.close()
raise
except pickle.UnpicklingError as ex:
raise RuntimeError(
"Error deserializing response from the benchmarking subprocess. "
"Is the benchmark code path writing to stdout?"
) from ex
def close(self) -> None:
"""
Close the communication pipes from the child process.
"""
if self.process is not None:
assert self.process.stdin is not None
assert self.process.stdout is not None
self.process.stdin.close()
self.process.stdout.close()
self.process = None
def terminate(self) -> None:
if self.valid():
request_queue = self.request_queue
assert request_queue is not None
request_queue.put(None)
process = self.process
assert process is not None
process.join()
"""
Signal the child process to terminate and wait for it to exit.
"""
if self.process is not None:
self.put(None)
self.process.wait()
self.close()
tuning_process = TuningProcess()
@@ -180,18 +227,20 @@ class BenchmarkRequest:
def benchmark(
self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None
) -> float:
if DEBUG:
debug = log.isEnabledFor(logging.DEBUG)
if debug:
start_ts = time.time()
mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
if DEBUG:
print(
f"benchmark module key: {self.module_cache_key}, path: {self.module_path}"
)
log.debug(
"benchmark module key: %s, path: %s",
self.module_cache_key,
self.module_path,
)
run = getattr(mod, self.kernel_name).run
if DEBUG:
if debug:
load_elapse = time.time() - start_ts
start_ts = time.time()
@@ -205,7 +254,7 @@ class BenchmarkRequest:
assert isinstance(self.output_tensor, TensorMeta)
output_tensor = self.output_tensor.to_tensor()
if DEBUG:
if debug:
create_tensor_elapse = time.time() - start_ts
start_ts = time.time()
@@ -222,12 +271,16 @@ class BenchmarkRequest:
out = do_bench(worker)
torch.cuda.synchronize() # shake out any CUDA errors
if DEBUG:
if debug:
bench_elapse = time.time() - start_ts
print(
f"InChidProcess {self.module_cache_key}: load {load_elapse}, "
+ f"create tensor {create_tensor_elapse}, bench {bench_elapse}"
log.debug(
"InChildProcess %s: load %f, create tensor %f, bench %f",
self.module_cache_key,
load_elapse,
create_tensor_elapse,
bench_elapse,
)
return out
@@ -239,35 +292,15 @@ def benchmark_in_sub_process(
"""
assert choice.bmreq is not None
tuning_process.initialize()
assert tuning_process.valid()
process, request_queue, response_queue = (
tuning_process.process,
tuning_process.request_queue,
tuning_process.response_queue,
)
assert (
process is not None and request_queue is not None and response_queue is not None
)
request_queue.put(choice.bmreq)
while True:
try:
timing = response_queue.get(timeout=1.0)
except queue.Empty:
status = process.exitcode
if status is None:
# child process is still running
continue
# child process fail
assert status != 0
warnings.warn(
f"Fail to benchmark choice '{choice}'. It will be ignored. Please debug the root cause in case the choice can bring perf gains." # noqa: B950 line too long
)
tuning_process.clear()
# return INF so this choice will be ignored
return float("inf")
return timing
tuning_process.put(choice.bmreq)
try:
return tuning_process.get()
except EOFError:
warnings.warn(
f"Failed to benchmark choice '{choice}'. It will be ignored. "
"Please debug the root cause in case the choice can bring perf gains.",
stacklevel=2,
)
# return INF so this choice will be ignored
return float("inf")


@@ -0,0 +1,5 @@
from torch._inductor.autotune_process import TuningProcess
# Entry point for the subprocess supporting the TuningProcess's benchmark operation.
if __name__ == "__main__":
TuningProcess.process_main()
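
One practical consequence of routing responses over stdout, hinted at by the UnpicklingError message in TuningProcess.get(): the child's stdout is the protocol channel, so code running in the benchmarking subprocess must not print to it. A hypothetical snippet (not part of the PR) illustrating the safe alternative:

import sys

def benchmark_kernel() -> float:
    # print("timing kernel...")  # would interleave text with the pickled
    #                            # responses and surface in the parent as
    #                            # "Is the benchmark code path writing to stdout?"
    print("timing kernel...", file=sys.stderr)  # stderr (or logging) is safe
    return 0.123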


@@ -798,7 +798,6 @@ class AlgorithmSelectorCache(PersistentCache):
# do the optional warmup
tuning_process.initialize()
assert tuning_process.valid()
autotune_start_ts = time.time()
timings = self.lookup(