import base64
import dataclasses
import functools
import getpass
import hashlib
import logging
import multiprocessing
import os
import re
import shutil
import signal
import subprocess
import sys
import sysconfig
import tempfile
import types
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
from ctypes import cdll
from threading import Thread
from time import sleep, time
from typing import Any, Callable, Dict, List

import torch
from torch.utils import cpp_extension
from . import config, cuda_properties, exc

LOCK_TIMEOUT = 600

# timing metrics for time spent in the compilation
_cumulative_compile_time = 0
_t0 = None


def _compile_start():
    global _t0
    if _t0 is None:
        _t0 = time()


def _compile_end():
    global _cumulative_compile_time, _t0
    if _t0 is not None:
        t1 = time()
        _cumulative_compile_time += t1 - _t0
        _t0 = None
        # print("CUMULATIVE COMPILE TIME", _cumulative_compile_time)


log = logging.getLogger(__name__)
logging.getLogger("filelock").setLevel(logging.DEBUG if config.debug else logging.INFO)


@functools.lru_cache(None)
def cache_dir():
    return os.environ.get(
        "TORCHINDUCTOR_CACHE_DIR", f"/tmp/torchinductor_{getpass.getuser()}"
    )


def get_lock_dir():
    lock_dir = os.path.join(cache_dir(), "locks")
    if not os.path.exists(lock_dir):
        os.makedirs(lock_dir, exist_ok=True)
    return lock_dir


def code_hash(code):
    return (
        "c"
        + base64.b32encode(hashlib.sha256(code.encode("utf-8")).digest())[:51]
        .decode("utf-8")
        .lower()
    )


def get_code_path(source_code, ext, extra):
    basename = code_hash(source_code + extra)
    subdir = os.path.join(cache_dir(), basename[1:3])
    path = os.path.join(subdir, f"{basename}.{ext}")
    return basename, subdir, path


def write(source_code, ext, extra=""):
    basename, subdir, path = get_code_path(source_code, ext, extra)
    if not os.path.exists(subdir):
        os.makedirs(subdir, exist_ok=True)
    if not os.path.exists(path):
        # use a temp file for thread safety
        fd, tmp_path = tempfile.mkstemp(dir=subdir)
        with os.fdopen(fd, "w") as f:
            f.write(source_code)
        os.rename(tmp_path, path)
    return basename, path
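
# Illustrative note on the cache layout (a sketch added for clarity, not part of
# the original logic; the hash value below is made up):
#
#   >>> basename, path = write("x = 1\n", "py")
#   >>> basename     # "c" + 51 chars of lowercased base32(sha256(source + extra))
#   'c3vqk...'
#   >>> path         # entries are sharded by basename[1:3] under cache_dir()
#   '/tmp/torchinductor_<user>/3v/c3vqk....py'
#
# Because the filename is derived only from the content (plus `extra`), identical
# source text is written to disk once and reused across processes.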


def cpp_compiler():
    if isinstance(config.cpp.cxx, (list, tuple)):
        search = tuple(config.cpp.cxx)
    else:
        search = (config.cpp.cxx,)
    return cpp_compiler_search(search)


@functools.lru_cache(1)
def cpp_compiler_search(search):
    for cxx in search:
        try:
            if cxx is None:
                # gxx package is only available for Linux
                # according to https://anaconda.org/conda-forge/gxx/
                if sys.platform != "linux":
                    continue
                # Do not install GXX by default
                if not os.getenv("TORCH_INDUCTOR_INSTALL_GXX"):
                    continue
                from filelock import FileLock

                lock_dir = get_lock_dir()
                lock = FileLock(
                    os.path.join(lock_dir, "g++.lock"), timeout=LOCK_TIMEOUT
                )
                with lock:
                    cxx = install_gcc_via_conda()
            subprocess.check_output([cxx, "--version"])
            return cxx
        except (subprocess.SubprocessError, FileNotFoundError, ImportError):
            continue
    raise exc.InvalidCxxCompiler()
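
# Note (added for clarity): the entries of config.cpp.cxx are tried in order, and
# the first one that responds to `--version` wins.  A `None` entry is a request to
# fall back to a conda-installed g++ from conda-forge, which only happens on Linux
# and only when TORCH_INDUCTOR_INSTALL_GXX is set; otherwise that entry is skipped.
# If no candidate works, exc.InvalidCxxCompiler is raised.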


def install_gcc_via_conda():
    """On older systems, this is a quick way to get a modern compiler"""
    prefix = os.path.join(cache_dir(), "gcc")
    cxx_path = os.path.join(prefix, "bin", "g++")
    if not os.path.exists(cxx_path):
        log.info("Downloading GCC via conda")
        conda = os.environ.get("CONDA_EXE", "conda")
        if conda is None:
            conda = shutil.which("conda")
        if conda is not None:
            subprocess.check_call(
                [
                    conda,
                    "create",
                    f"--prefix={prefix}",
                    "--channel=conda-forge",
                    "--quiet",
                    "-y",
                    "python=3.8",
                    "gxx",
                ],
                stdout=subprocess.PIPE,
            )
    return cxx_path


def is_gcc():
    return re.search(r"(gcc|g\+\+)", cpp_compiler())


class VecISA(object):
    _bit_width: int
    _macro: str
    _arch_flags: str
    _dtype_nelements: Dict[torch.dtype, int]

    # TorchInductor CPU vectorization reuses PyTorch's vectorization utility functions,
    # so it depends on Sleef* to accelerate mathematical functions such as exp, pow,
    # sin, and cos.
    # However, PyTorch and TorchInductor may be built with different compilers. If the
    # PyTorch release package was built with gcc-7/g++-7, libtorch_cpu.so will not
    # expose the Sleef* AVX512 symbols, because gcc-7/g++-7 cannot pass the AVX512
    # check in CMake (FindAVX.cmake). TorchInductor, on the other hand, installs the
    # latest gcc/g++ by default, which does support AVX512 compilation. This can lead
    # to a Sleef version mismatch between PyTorch and TorchInductor. Hence, we
    # dry-compile the code below to check whether both the current HW platform and the
    # PyTorch build support AVX512 (or AVX2). ARM presumably needs the same logic.
    _avx_code = """
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2)
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#endif

__attribute__((aligned(64))) float in_out_ptr0[16] = {0.0};

extern "C" void __avx_chk_kernel() {
    auto tmp0 = at::vec::Vectorized<float>(1);
    auto tmp1 = tmp0.exp();
    tmp1.store(in_out_ptr0);
}
"""

    _avx_py_load = """
import torch
from ctypes import cdll
cdll.LoadLibrary("__lib_path__")
"""

    def bit_width(self):
        return self._bit_width

    def nelements(self, dtype: torch.dtype = torch.float):
        return self._dtype_nelements[dtype]

    def build_macro(self):
        return self._macro

    def build_arch_flags(self):
        return self._arch_flags

    def __hash__(self) -> int:
        return hash(str(self))

    @functools.lru_cache(None)
    def __bool__(self):
        key, input_path = write(VecISA._avx_code, "cpp", extra="")
        from filelock import FileLock

        lock_dir = get_lock_dir()
        lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
        with lock:
            output_path = input_path[:-3] + "so"
            build_cmd = cpp_compile_command(
                input_path, output_path, warning_all=False, vec_isa=self
            ).split(" ")
            try:
                # Check build result
                subprocess.check_output(build_cmd, stderr=subprocess.STDOUT)
                subprocess.check_call(
                    [
                        "python",
                        "-c",
                        VecISA._avx_py_load.replace("__lib_path__", output_path),
                    ],
                    stderr=subprocess.DEVNULL,
                )
            except Exception as e:
                return False

            return True
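
# How the dry-compile check above works (summary added for clarity): __bool__
# writes _avx_code to the cache, compiles it with this ISA's arch flags, and then
# loads the resulting .so in a separate `python -c` subprocess (_avx_py_load).
# Doing the load in a child process means that a missing Sleef* symbol, or any
# other load-time crash, just makes the check return False instead of taking down
# this process.  Because __bool__ is wrapped in functools.lru_cache, the probe
# runs at most once per VecISA instance, so `if isa:` is cheap after the first
# call.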


@dataclasses.dataclass
class VecAVX512(VecISA):
    _bit_width = 512
    _macro = "CPU_CAPABILITY_AVX512"
    _arch_flags = "-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma"
    _dtype_nelements = {torch.float: 16, torch.bfloat16: 32}

    def __str__(self) -> str:
        return "avx512"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


@dataclasses.dataclass
class VecAVX2(VecISA):
    _bit_width = 256
    _macro = "CPU_CAPABILITY_AVX2"
    _arch_flags = "-mavx2 -mfma"
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16}

    def __str__(self) -> str:
        return "avx2"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


class InvalidVecISA(VecISA):
    _bit_width = 0
    _macro = ""
    _arch_flags = ""
    _dtype_nelements = {}

    def __str__(self) -> str:
        return "INVALID_VEC_ISA"

    def __bool__(self):
        return False

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


invalid_vec_isa = InvalidVecISA()
supported_vec_isa_list = [VecAVX512(), VecAVX2()]


# Cache the ISA check to avoid repeated I/O. The raw cpuinfo content is mostly
# irrelevant to the ISA check, so we only cache the key ISA information rather
# than the cpuinfo itself.
@functools.lru_cache(None)
def valid_vec_isa_list():
    if sys.platform != "linux":
        return []

    isa_list = []
    with open("/proc/cpuinfo") as _cpu_info:
        _cpu_info_content = _cpu_info.read()
        for isa in supported_vec_isa_list:
            if str(isa) in _cpu_info_content and isa:
                isa_list.append(isa)
        return isa_list


def pick_vec_isa():
    _valid_vec_isa_list: List[VecISA] = valid_vec_isa_list()
    if not _valid_vec_isa_list:
        return invalid_vec_isa

    # If simdlen is None, determine the vectorization length automatically
    if config.cpp.simdlen is None:
        assert _valid_vec_isa_list
        return _valid_vec_isa_list[0]

    for isa in _valid_vec_isa_list:
        if config.cpp.simdlen == isa.bit_width():
            return isa

    return invalid_vec_isa
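
# Example of the selection logic above (added for clarity; assumes an
# AVX512-capable Linux machine where valid_vec_isa_list() yields
# [VecAVX512(), VecAVX2()]):
#
#   config.cpp.simdlen = None  -> VecAVX512()      (first, i.e. widest, valid ISA)
#   config.cpp.simdlen = 512   -> VecAVX512()
#   config.cpp.simdlen = 256   -> VecAVX2()
#   config.cpp.simdlen = 128   -> invalid_vec_isa  (no matching bit width)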


def get_shared(shared=True):
    return "-shared -fPIC" if shared else ""


def get_warning_all_flag(warning_all=True):
    return "-Wall" if warning_all else ""


def cpp_flags():
    return "-std=c++14 -Wno-unused-variable"


def optimization_flags():
    return "-march=native -O3 -ffast-math -fno-finite-math-only -fopenmp"


def use_custom_generated_macros():
    return "-D C10_USING_CUSTOM_GENERATED_MACROS"


def get_include_and_linking_paths(
    include_pytorch=False, vec_isa: VecISA = invalid_vec_isa
):
    if sys.platform == "linux" and (
        include_pytorch
        or vec_isa != invalid_vec_isa
        or config.cpp.enable_kernel_profile
    ):
        # Note - We include pytorch only on linux right now. There is more work
        # to do to enable OMP build on darwin where PyTorch is built with IOMP
        # and we need a way to link to what PyTorch links.
        ipaths = cpp_extension.include_paths() + [sysconfig.get_path("include")]
        lpaths = cpp_extension.library_paths() + [sysconfig.get_config_var("LIBDIR")]
        libs = ["c10", "torch", "torch_cpu", "torch_python", "gomp"]
        macros = vec_isa.build_macro()
        if macros:
            macros = f"-D{macros}"
    else:
        # Note - this is effectively a header-only inclusion. Using some of these
        # headers may result in "symbol not found" errors if they require a library.
        # For those cases, include the lpaths and libs as we do for pytorch above.
        # This approach allows us to only pay for what we use.
        ipaths = cpp_extension.include_paths() + [sysconfig.get_path("include")]
        lpaths = []
        libs = ["gomp"]
        macros = ""
    ipaths = " ".join(["-I" + p for p in ipaths])
    lpaths = " ".join(["-L" + p for p in lpaths])
    libs = " ".join(["-l" + p for p in libs])
    return ipaths, lpaths, libs, macros


def cpp_compile_command(
    input,
    output,
    warning_all=True,
    shared=True,
    include_pytorch=False,
    vec_isa: VecISA = invalid_vec_isa,
):
    ipaths, lpaths, libs, macros = get_include_and_linking_paths(
        include_pytorch, vec_isa
    )

    return re.sub(
        r"[ \n]+",
        " ",
        f"""
            {cpp_compiler()} {input} {get_shared(shared)} {get_warning_all_flag(warning_all)} {cpp_flags()}
            {ipaths} {lpaths} {libs} {macros}
            {optimization_flags()}
            {use_custom_generated_macros()}
            -o{output}
        """,
    ).strip()
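
# Roughly what cpp_compile_command() produces (an illustrative sketch; the exact
# compiler, paths and macros depend on the platform, config.cpp and the chosen
# VecISA, and the file names below are made up):
#
#   g++ /cache/3v/c3vqk....cpp -shared -fPIC -Wall -std=c++14 -Wno-unused-variable
#       -I<python/torch include dirs> -lgomp
#       -march=native -O3 -ffast-math -fno-finite-math-only -fopenmp
#       -D C10_USING_CUSTOM_GENERATED_MACROS -o/cache/3v/c3vqk....so
#
# (wrapped here for readability; the re.sub above collapses it to one line).
# When a real VecISA is picked, or include_pytorch is set, -L<lib dirs>, the torch
# libraries and a -DCPU_CAPABILITY_* macro are added as well.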


class CppCodeCache:
    cache = dict()
    clear = staticmethod(cache.clear)

    @staticmethod
    def _load_library(path):
        try:
            return cdll.LoadLibrary(path)
        except OSError as e:
            if "gomp" in str(e) and os.path.exists("/usr/lib64/libgomp.so.1"):
                # hacky workaround for fbcode/buck
                global _libgomp
                _libgomp = cdll.LoadLibrary("/usr/lib64/libgomp.so.1")
                return cdll.LoadLibrary(path)
            raise

    @classmethod
    def load(cls, source_code):
        picked_vec_isa = pick_vec_isa()
        key, input_path = write(
            source_code,
            "cpp",
            extra=cpp_compile_command("i", "o", vec_isa=picked_vec_isa),
        )
        if key not in cls.cache:
            from filelock import FileLock

            lock_dir = get_lock_dir()
            lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
            with lock:
                output_path = input_path[:-3] + "so"
                if not os.path.exists(output_path):
                    cmd = cpp_compile_command(
                        input=input_path, output=output_path, vec_isa=picked_vec_isa
                    ).split(" ")
                    try:
                        subprocess.check_output(cmd, stderr=subprocess.STDOUT)
                    except subprocess.CalledProcessError as e:
                        raise exc.CppCompileError(cmd, e.output) from e

                cls.cache[key] = cls._load_library(output_path)
                cls.cache[key].key = key

        return cls.cache[key]
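
# Summary of the flow above (added for clarity): the compile command itself is
# folded into the cache key via `extra=cpp_compile_command("i", "o", ...)`, so
# changing flags or the picked ISA yields a different cache entry.  A per-key
# file lock serializes concurrent builds of the same source: the first process
# compiles the .cpp into a .so next to it, and every process then loads the
# shared library via ctypes (see AsyncCompile.cpp(), which calls
# CppCodeCache.load(source_code).kernel).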


class PyCodeCache:
    cache = dict()
    clear = staticmethod(cache.clear)

    @classmethod
    def load(cls, source_code):
        key, path = write(source_code, "py")
        if key not in cls.cache:
            with open(path) as f:
                code = compile(f.read(), path, "exec")
                mod = types.ModuleType(f"{__name__}.{key}")
                mod.__file__ = path
                mod.key = key
                exec(code, mod.__dict__, mod.__dict__)
                # another thread might set this first
                cls.cache.setdefault(key, mod)
        return cls.cache[key]
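
# Note (added for clarity): PyCodeCache does not import the generated file by
# name; it writes the source to disk, compiles it with the cached path as the
# filename, and execs it into a fresh types.ModuleType whose __file__ points at
# that path, so tracebacks still show real source lines.  setdefault() keeps the
# first module created if two threads race on the same key.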


@functools.lru_cache(None)
def patch_triton_dir():
    os.environ["TRITON_CACHE_DIR"] = os.environ.get(
        "TRITON_CACHE_DIR", os.path.join(cache_dir(), "triton")
    )


class TritonCodeCache:
    @staticmethod
    def get_name(mod):
        (name,) = [n for n in dir(mod) if n.startswith("triton_")]
        return name

    @classmethod
    def load(cls, source_code):
        patch_triton_dir()
        mod = PyCodeCache.load(source_code)
        return getattr(mod, cls.get_name(mod))


def _worker_compile(source_code, cc, device):
    cuda_properties.set_compiler_worker_current_device(device)
    kernel = TritonCodeCache.load(source_code)
    kernel.precompile(warm_cache_only_with_cc=cc)


def _load_kernel(source_code):
    kernel = TritonCodeCache.load(source_code)
    kernel.precompile()
    return kernel


def _load_kernel_name(source_code):
    return TritonCodeCache.get_name(PyCodeCache.load(source_code))


class TritonFuture:
    def __init__(self, source_code, future):
        self.source_code = source_code
        self.future = future

    # @dynamo_utils.dynamo_timed
    def result(self):
        t0 = time()
        if hasattr(self, "kernel"):
            return self.kernel
        # If the worker failed this will throw an exception.
        self.future.result()
        kernel = self.kernel = _load_kernel(self.source_code)
        latency = time() - t0
        if latency > 50:
            name = _load_kernel_name(self.source_code)
            log.warning(
                f"Detected long compilation time of {latency} seconds for kernel name {name}"
            )
            log.warning(self.source_code)
        del self.source_code, self.future
        return kernel
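
# Note on the two-step Triton compile (added for clarity): when compile_threads
# > 1, _worker_compile() runs in a forked worker and only warms the on-disk
# Triton cache for the target compute capability (warm_cache_only_with_cc=cc);
# it does not return a kernel.  TritonFuture.result() then loads the kernel in
# the parent via _load_kernel(), which should mostly hit the cache the worker
# just populated, and stores it on `self.kernel` so later calls are free.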


class AsyncCompile:
    def __init__(self):
        self._context_keepalive = None

    @staticmethod
    @functools.lru_cache(1)
    def pool():
        assert config.compile_threads > 1
        return ThreadPoolExecutor(config.compile_threads)

    @staticmethod
    @functools.lru_cache(1)
    def process_pool():
        # ensure properties have been calculated before processes
        # are forked
        cuda_properties._properties()
        assert config.compile_threads > 1
        orig_ppid = os.getpid()

        # if this process dies abnormally (e.g. segfault)
        # it will not shut down the workers. Instead,
        # the workers will have their parent reassigned to the
        # init process. This launches a separate thread to
        # watch for the worker getting reassigned,
        # and cleans it up in this case.
        def init():
            def run():
                while True:
                    sleep(1)
                    if orig_ppid != os.getppid():
                        os.kill(os.getpid(), signal.SIGKILL)

            global _watchdog_thread
            _watchdog_thread = Thread(target=run, daemon=True)
            _watchdog_thread.start()

        # we rely on 'fork' because we cannot control whether users
        # have an `if __name__ == '__main__'` in their main process.
        fork_context = multiprocessing.get_context("fork")
        pool = ProcessPoolExecutor(
            config.compile_threads, mp_context=fork_context, initializer=init
        )
        # when this pool is created in a subprocess object, the normal exit handler
        # doesn't run, and we need to register our own handler.
        # exitpriority has to be high, because another one of the finalizers will
        # kill the worker thread that sends the shutdown message to the workers...
        multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize)
        return pool

    @classmethod
    def warm_pool(cls):
        if config.compile_threads <= 1:
            return
        _compile_start()
        pool = cls.process_pool()

        # We have to fork processes for compiler workers, but the more memory and other
        # resources that are loaded, the slower the os.fork time is, quite drastically.
        # It also holds the GIL, so we can't put it on another thread.

        # Examples:
        # A simple "x + x + x" script: 10ms in the middle of the program, 2ms at startup
        # tf_efficientnet_b0 benchmark: 50ms! in the middle of the program, 3ms at startup

        # So we want to start the workers early, when it is still cheap, and also to allow
        # the workers to get ready before we have work for them.

        # ProcessPoolExecutor also does not launch the workers until it finds a point when
        # all the workers are idle. But if we waited until then, the fork would be slow and
        # we would be waiting for the processes to initialize.

        # We force them to start here with some YOLOing of the internal methods.
        if hasattr(pool, "_start_queue_management_thread"):
            pool._start_queue_management_thread()
        else:
            for i in range(config.compile_threads):
                pool._adjust_process_count()
            pool._start_executor_manager_thread()
        _compile_end()

    @classmethod
    def submit(cls, task):
        if config.compile_threads <= 1:
            return task()
        return cls.pool().submit(task)

    @classmethod
    def map(cls, fn, seq):
        if config.compile_threads <= 1 or len(seq) <= 1:
            return list(map(fn, seq))
        return [t.result() for t in [cls.pool().submit(fn, x) for x in seq]]

    def triton(self, source_code):
        _compile_start()
        if self._context_keepalive is None:
            # Workaround `CUDA: Error- context is destroyed`
            self._context_keepalive = torch.tensor([1], device="cuda")

        if config.compile_threads > 1:
            major, minor = torch.cuda.get_device_capability()
            device = torch.cuda.current_device()
            cc = major * 10 + minor
            future = self.process_pool().submit(
                _worker_compile, source_code, cc, device
            )
            return TritonFuture(source_code, future)
        else:
            return _load_kernel(source_code)

    def cpp(self, source_code):
        def task():
            return CppCodeCache.load(source_code).kernel

        return self.submit(task)

    def wait(self, scope: Dict[str, Any]):
        if config.compile_threads > 1:
            for key, result in list(scope.items()):
                if isinstance(result, (Future, TritonFuture)):
                    scope[key] = result.result()

        _compile_end()
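
# Rough usage sketch of AsyncCompile (added for clarity; the variable names are
# hypothetical, but generated wrapper code drives it in a similar way):
#
#   async_compile = AsyncCompile()
#   scope = {}
#   scope["kernel0"] = async_compile.triton(triton_src)  # TritonFuture or kernel
#   scope["kernel1"] = async_compile.cpp(cpp_src)        # Future or loaded kernel
#   async_compile.wait(scope)  # replaces any futures in-place with real kernels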


AsyncCompile.warm_pool()