[inductor] Parallelize Max Autotune step 1: Use Popen (#107982)
Summary: Step 1 in revamping subprocess autotune to support multiple GPUs: use
Popen to create a new process with an entry point we control so we don't
reinterpret the toplevel script.

Test Plan: `python test/inductor/test_max_autotune.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107982
Approved by: https://github.com/eellison, https://github.com/shunting314
This commit is contained in:
parent 89eb7a75a2
commit d685668003
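
Note: the core pattern this change adopts (a parent that Popens a plain Python
child and exchanges pickled objects over the stdin/stdout pipes) can be
sketched in isolation. The inline child script below is hypothetical and
stands in for the real entry module; it illustrates the protocol, not the
PyTorch implementation:

    import pickle
    import subprocess
    import sys

    # Hypothetical stand-in for the real entry module: echo each pickled
    # request back to the parent until the None sentinel arrives.
    CHILD = (
        "import pickle, sys\n"
        "while True:\n"
        "    obj = pickle.load(sys.stdin.buffer)\n"
        "    if obj is None:\n"
        "        break\n"
        "    pickle.dump(obj, sys.stdout.buffer)\n"
        "    sys.stdout.flush()\n"
    )

    proc = subprocess.Popen(
        [sys.executable, "-c", CHILD],  # the real code uses -m with a known module
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
    )

    pickle.dump({"request": "demo"}, proc.stdin)  # pickle onto the child's stdin
    proc.stdin.flush()
    print(pickle.load(proc.stdout))               # read back the pickled reply

    pickle.dump(None, proc.stdin)                 # None tells the child to exit
    proc.stdin.flush()
    proc.wait()
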
torch/_inductor/autotune_process.py

@@ -1,13 +1,14 @@
 import dataclasses
-import queue
 import logging
+import pickle
+import subprocess
+import sys
 import time
 import warnings
-from multiprocessing.process import BaseProcess
-from multiprocessing.queues import Queue
 
 from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
 
 import torch
-from torch import multiprocessing
 from torch._dynamo.testing import rand_strided
 
 from torch._inductor import ir
@@ -19,9 +20,11 @@ if TYPE_CHECKING:
 from .utils import do_bench
 from .virtualized import V
 
-DEBUG = False
 EXIT_HANDLER_REGISTERED = False
 
+log = logging.getLogger(__name__)
+
+
 # Used to synchronize between parent and child processes
 class Ping:
@@ -34,58 +37,62 @@ class Pong:
 
 @dataclasses.dataclass
 class TuningProcess:
-    process: Optional[BaseProcess] = None
-    request_queue: Optional["Queue[Any]"] = None
-    response_queue: Optional["Queue[Any]"] = None
+    """
+    Abstraction for launching a helper process to benchmark kernels. Rather
+    than spawning the parent process, the approach Popens a new process with
+    an entry point that we control. Avoiding the spawn means we do not re-enter
+    the toplevel script. The subprocess communicates with the parent process
+    via pickling requests/responses over stdin/stdout pipes.
+    """
+
+    process: Optional["subprocess.Popen[bytes]"] = None
 
     @staticmethod
-    def process_main(
-        request_queue: "Queue[Any]",
-        response_queue: "Queue[Any]",
-    ) -> None:
-        print("enter child process main")
-        while True:
-            obj = request_queue.get()
+    def process_main() -> None:
+        """
+        Entry point for the child process.
+        """
+        log.debug("Entering TuningProcess child main")
+        try:
+            TuningProcess.workloop()
+        except Exception:
+            log.exception("Exception in TuningProcess")
+
+    @staticmethod
+    def workloop() -> None:
+        """
+        Work loop for the benchmarking subprocess.
+        """
+
+        def reply(obj):
+            # Note this is subtly different than the put() method below.
+            pickle.dump(obj, sys.stdout.buffer)
+            sys.stdout.flush()
+
+        while True:
+            obj = pickle.load(sys.stdin.buffer)
             if obj is None:
-                break  # None is a sentinel for the child to terminate
+                # None is a sentinel for the child to terminate
+                break
             elif isinstance(obj, Ping):
-                response_queue.put(Pong())
+                reply(Pong())
             elif isinstance(obj, BenchmarkRequest):
-                response_queue.put(obj.benchmark())
+                reply(obj.benchmark())
             else:
                 raise RuntimeError(f"Invalid request type {type(obj)}")
 
-    def valid(self) -> bool:
-        return (
-            self.process is not None
-            and self.request_queue is not None
-            and self.response_queue is not None
-        )
-
-    def clear(self) -> None:
-        self.process = self.request_queue = self.response_queue = None
-
     def initialize(self) -> None:
        """
-        Create child process, request/response queues and do the warm up.
+        Create child process and do the warm up.
        """
-        if self.valid():
+        if self.process is not None:
             return
 
-        # cuda runtime does not work with "fork", use "spawn" to start processes.
-        ctx = multiprocessing.get_context("spawn")
-        request_queue = self.request_queue = ctx.Queue()
-        response_queue = self.response_queue = ctx.Queue()
-
-        process = self.process = ctx.Process(
-            target=self.process_main,
-            args=(
-                self.request_queue,
-                self.response_queue,
-            ),
+        self.process = subprocess.Popen(
+            [sys.executable, "-m", "torch._inductor.autotune_process_entry"],
+            stdin=subprocess.PIPE,
+            stdout=subprocess.PIPE,
        )
-        process.start()
 
         # register the exit handler for the parent process so it will terminate
         # the child processes
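
Note on the reply() helper above: pickle produces bytes, so the child writes
to sys.stdout.buffer, the binary layer beneath the text-mode sys.stdout. A
minimal demonstration of the distinction (standard library only):

    import pickle
    import sys

    data = {"kernel": "demo", "timing_ms": 0.42}

    pickle.dump(data, sys.stdout.buffer)  # OK: the buffer accepts bytes
    sys.stdout.buffer.flush()

    try:
        pickle.dump(data, sys.stdout)     # text stream: write() expects str
    except TypeError as ex:
        print(f"as expected: {ex}", file=sys.stderr)
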
@@ -97,18 +104,58 @@ class TuningProcess:
             atexit.register(lambda: self.terminate())
 
         # wait for the initialization to be done
-        request_queue.put(Ping())
-        resp = response_queue.get()
+        self.put(Ping())
+        resp = self.get()
         assert isinstance(resp, Pong)
 
+    def put(self, obj: Any) -> None:
+        """
+        Push a work item to the child process.
+        """
+        # In case of a prior crash, ensure the subprocess is running
+        self.initialize()
+        assert self.process is not None
+        assert self.process.stdin is not None
+        pickle.dump(obj, self.process.stdin)
+        self.process.stdin.flush()
+
+    def get(self) -> Any:
+        """
+        Get a response from the child process.
+        """
+        assert self.process is not None
+        assert self.process.stdout is not None
+        try:
+            return pickle.load(self.process.stdout)
+        except EOFError:
+            # Child crashed; clean up
+            self.close()
+            raise
+        except pickle.UnpicklingError as ex:
+            raise RuntimeError(
+                "Error deserializing response from the benchmarking subprocess. "
+                "Is the benchmark code path writing to stdout?"
+            ) from ex
+
+    def close(self) -> None:
+        """
+        Close the communication pipes from the child process.
+        """
+        if self.process is not None:
+            assert self.process.stdin is not None
+            assert self.process.stdout is not None
+            self.process.stdin.close()
+            self.process.stdout.close()
+            self.process = None
+
     def terminate(self) -> None:
-        if self.valid():
-            request_queue = self.request_queue
-            assert request_queue is not None
-            request_queue.put(None)
-
-            process = self.process
-            assert process is not None
-            process.join()
+        """
+        Signal the child process to terminate and wait for it to exit.
+        """
+        if self.process is not None:
+            self.put(None)
+            self.process.wait()
+            self.close()
 
 
 tuning_process = TuningProcess()
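
Design note: with the old multiprocessing queues, a dead child had to be
detected by polling get(timeout=1.0) and checking process.exitcode (see the
old benchmark_in_sub_process further down). With pipes, a crashed child closes
its end, so the parent's blocking pickle.load() simply raises EOFError, which
get() turns into cleanup-and-reraise. A minimal sketch, assuming nothing
beyond the standard library:

    import pickle
    import subprocess
    import sys

    # A child that dies before writing a response.
    proc = subprocess.Popen(
        [sys.executable, "-c", "import sys; sys.exit(1)"],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
    )
    proc.wait()

    try:
        pickle.load(proc.stdout)  # the pipe is at EOF, so this raises
    except EOFError:
        print("child crashed; the caller can clean up and re-initialize")
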
@@ -180,18 +227,20 @@ class BenchmarkRequest:
     def benchmark(
         self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None
     ) -> float:
-        if DEBUG:
+        debug = log.isEnabledFor(logging.DEBUG)
+        if debug:
             start_ts = time.time()
 
         mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
-        if DEBUG:
-            print(
-                f"benchmark module key: {self.module_cache_key}, path: {self.module_path}"
-            )
+        log.debug(
+            "benchmark module key: %s, path: %s",
+            self.module_cache_key,
+            self.module_path,
+        )
 
         run = getattr(mod, self.kernel_name).run
 
-        if DEBUG:
+        if debug:
             load_elapse = time.time() - start_ts
             start_ts = time.time()
 
@@ -205,7 +254,7 @@ class BenchmarkRequest:
         assert isinstance(self.output_tensor, TensorMeta)
         output_tensor = self.output_tensor.to_tensor()
 
-        if DEBUG:
+        if debug:
             create_tensor_elapse = time.time() - start_ts
             start_ts = time.time()
 
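
For context on the to_tensor() call above: benchmark requests cross the
process boundary by pickling, so tensors are not shipped directly; the request
carries layout metadata and the child materializes random tensors from it
(note the rand_strided import kept at the top of the file). A hypothetical,
simplified version of that idea (DemoTensorMeta is illustrative, not the real
TensorMeta):

    import dataclasses
    from typing import Tuple

    import torch
    from torch._dynamo.testing import rand_strided

    @dataclasses.dataclass
    class DemoTensorMeta:
        device: torch.device
        dtype: torch.dtype
        sizes: Tuple[int, ...]
        strides: Tuple[int, ...]

        def to_tensor(self) -> torch.Tensor:
            # Materialize a random tensor matching the recorded layout.
            return rand_strided(
                self.sizes, self.strides, device=self.device, dtype=self.dtype
            )

    meta = DemoTensorMeta(torch.device("cpu"), torch.float32, (4, 4), (4, 1))
    t = meta.to_tensor()  # cheap stand-in data for benchmarking
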
@@ -222,12 +271,16 @@ class BenchmarkRequest:
         out = do_bench(worker)
         torch.cuda.synchronize()  # shake out any CUDA errors
 
-        if DEBUG:
+        if debug:
             bench_elapse = time.time() - start_ts
-            print(
-                f"InChidProcess {self.module_cache_key}: load {load_elapse}, "
-                + f"create tensor {create_tensor_elapse}, bench {bench_elapse}"
+            log.debug(
+                "InChildProcess %s: load %f, create tensor %f, bench %f",
+                self.module_cache_key,
+                load_elapse,
+                create_tensor_elapse,
+                bench_elapse,
             )
 
         return out
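
The DEBUG/print() to logging migration in the hunks above uses two standard
idioms: log.debug() with %-style arguments defers string formatting until the
level check passes, and a cached log.isEnabledFor(logging.DEBUG) gates the
surrounding timing bookkeeping that would otherwise run unconditionally.
Sketched:

    import logging
    import time

    log = logging.getLogger(__name__)

    # Arguments are interpolated only if DEBUG is enabled; no f-string cost otherwise.
    log.debug("benchmark module key: %s, path: %s", "key-123", "/tmp/mod.py")

    # Work beyond formatting (here, timestamps) is gated on a cached level check.
    debug = log.isEnabledFor(logging.DEBUG)
    if debug:
        start_ts = time.time()
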
@@ -239,35 +292,15 @@ def benchmark_in_sub_process(
     """
     assert choice.bmreq is not None
     tuning_process.initialize()
-    assert tuning_process.valid()
-    process, request_queue, response_queue = (
-        tuning_process.process,
-        tuning_process.request_queue,
-        tuning_process.response_queue,
-    )
-    assert (
-        process is not None and request_queue is not None and response_queue is not None
-    )
 
-    request_queue.put(choice.bmreq)
-    while True:
-        try:
-            timing = response_queue.get(timeout=1.0)
-        except queue.Empty:
-            status = process.exitcode
-            if status is None:
-                # child process is still running
-                continue
-            # child process fail
-            assert status != 0
-
-            warnings.warn(
-                f"Fail to benchmark choice '{choice}'. It will be ignored. Please debug the root cause in case the choice can bring perf gains."  # noqa: B950 line too long
-            )
-
-            tuning_process.clear()
-
-            # return INF so this choice will be ignored
-            return float("inf")
-
-        return timing
+    tuning_process.put(choice.bmreq)
+    try:
+        return tuning_process.get()
+    except EOFError:
+        warnings.warn(
+            f"Failed to benchmark choice '{choice}'. It will be ignored. "
+            "Please debug the root cause in case the choice can bring perf gains.",
+            stacklevel=2,
+        )
+        # return INF so this choice will be ignored
+        return float("inf")
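
Taken together, the parent-side lifecycle reduces to
initialize/put/get/terminate. A hedged end-to-end sketch using the built-in
Ping/Pong handshake in place of a real BenchmarkRequest (constructing one
requires a compiled kernel module, which is out of scope here); assumes a
build containing this change:

    from torch._inductor.autotune_process import Ping, Pong, tuning_process

    tuning_process.initialize()  # Popens the child; round-trips a Ping internally
    tuning_process.put(Ping())   # pickled onto the child's stdin
    assert isinstance(tuning_process.get(), Pong)

    tuning_process.terminate()   # send the None sentinel, wait, close the pipes
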
torch/_inductor/autotune_process_entry.py (new file, +5)

@@ -0,0 +1,5 @@
+from torch._inductor.autotune_process import TuningProcess
+
+# Entry point for the subprocess supporting the TuningProcess's benchmark operation.
+if __name__ == "__main__":
+    TuningProcess.process_main()
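
The reason this tiny module exists: multiprocessing's "spawn" start method
re-imports the parent's __main__ in the child, so a user script that triggers
autotuning at toplevel would be re-executed. Launching
python -m torch._inductor.autotune_process_entry gives the child a known,
minimal __main__ instead. The hazard can be reproduced with a hypothetical
standalone script:

    # spawn_demo.py (hypothetical): run with `python spawn_demo.py`.
    # Under "spawn", the child re-imports this module, so the unguarded
    # toplevel line below prints twice: once in the parent, once in the child.
    import multiprocessing

    print("toplevel side effect")

    def work():
        print("in child")

    if __name__ == "__main__":
        ctx = multiprocessing.get_context("spawn")
        p = ctx.Process(target=work)
        p.start()
        p.join()
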
torch/_inductor/select_algorithm.py

@@ -798,7 +798,6 @@ class AlgorithmSelectorCache(PersistentCache):
 
         # do the optional warmup
         tuning_process.initialize()
-        assert tuning_process.valid()
 
         autotune_start_ts = time.time()
         timings = self.lookup(