Optimize load_inline via PCH (#106696)

Add a precompiled header (PCH) to reduce `load_inline` build time.
PCH is a built-in GCC mechanism: https://gcc.gnu.org/onlinedocs/gcc-4.0.4/gcc/Precompiled-Headers.html

Add a PCH for `#include <torch/extension.h>`. This header is included by every `load_inline` module, so all `load_inline` modules can benefit from this PR.
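
For illustration, a minimal sketch of opting in to the PCH from user code (the `sin_add` source mirrors the new test added below; `use_pch` is the flag this PR introduces):

```python
import torch
import torch.utils.cpp_extension

# Any inline extension that includes <torch/extension.h> can reuse the PCH.
source = """
at::Tensor sin_add(at::Tensor x, at::Tensor y) {
  return x.sin() + y.sin();
}
"""

module = torch.utils.cpp_extension.load_inline(
    name="inline_extension_with_pch",
    cpp_sources=[source],
    functions=["sin_add"],
    use_pch=True,  # build (or reuse) extension.h.gch before compiling
)
print(module.sin_add(torch.randn(3), torch.randn(3)))
```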

Changes:
1. Add a PCH signature to guarantee that the PCH (`.gch`) file takes effect (see the sketch after this list).
2. Unify the functions that resolve the C++ compiler.
3. Unify the functions that assemble build flags.
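
Roughly, the signature mechanism from item 1 works as in the sketch below. It is simplified from the new `_check_and_build_extension_h_precompiler_headers` in the diff; the helper name `build_pch_if_needed` is illustrative only:

```python
import os
import subprocess

def build_pch_if_needed(compiler, header, flags):
    # GCC compiles a header with `-x c++-header` into header.gch and then picks
    # it up automatically, but only when the including translation unit is
    # built with the same flags.
    pch_cmd = f"{compiler} -x c++-header {header} -o {header}.gch {' '.join(flags)}"
    # The signature is just the canonicalized build command.
    signature = pch_cmd.replace(' ', '_')
    sign_file = header + ".sign"

    old_signature = ""
    if os.path.isfile(sign_file):
        with open(sign_file) as f:
            old_signature = f.read()

    # Rebuild the PCH if it is missing or was built with different parameters.
    if not os.path.isfile(header + ".gch") or old_signature != signature:
        subprocess.check_output(pch_cmd, shell=True, stderr=subprocess.STDOUT)
        with open(sign_file, "w") as f:
            f.write(signature)
```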

Before this PR:
![image](https://github.com/pytorch/pytorch/assets/8433590/f190cdcb-236c-4312-b165-d419a7efafe3)

After this PR:
![image](https://github.com/pytorch/pytorch/assets/8433590/b45c5ad3-e902-4fc8-b450-743cf73505a4)

Compile time is reduced from 14.06s to 7.36s.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106696
Approved by: https://github.com/jgong5, https://github.com/jansel
Commit: 5ed60477a7 (parent 24968383b5)
Author: Han, Xu
Date: 2023-08-21 10:08:27 +00:00 (committed by PyTorch MergeBot)
3 changed files with 239 additions and 35 deletions

test/test_cpp_extensions_jit.py

@@ -18,12 +18,13 @@ import torch.utils.cpp_extension
 from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
 from torch.testing._internal.common_utils import gradcheck
 import torch.multiprocessing as mp
+from torch.utils.cpp_extension import _TORCH_PATH, remove_extension_h_precompiler_headers, get_cxx_compiler, check_compiler_is_gcc

 TEST_CUDA = TEST_CUDA and CUDA_HOME is not None
 TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None
 TEST_MPS = torch.backends.mps.is_available()
 IS_WINDOWS = sys.platform == "win32"
 IS_LINUX = sys.platform.startswith('linux')


 def remove_build_path():
@@ -919,5 +920,42 @@ class TestCppExtensionJIT(common.TestCase):
             with self.assertRaisesRegex(RuntimeError, msg):
                 torch.func.grad(identity_m.identity)(t)

+    def test_gen_extension_h_pch(self):
+        if not IS_LINUX:
+            return
+
+        source = """
+        at::Tensor sin_add(at::Tensor x, at::Tensor y) {
+          return x.sin() + y.sin();
+        }
+        """
+
+        head_file_pch = os.path.join(_TORCH_PATH, "include", "torch", "extension.h.gch")
+        head_file_signature = os.path.join(
+            _TORCH_PATH, "include", "torch", "extension.h.sign"
+        )
+
+        remove_extension_h_precompiler_headers()
+        pch_exist = os.path.exists(head_file_pch)
+        signature_exist = os.path.exists(head_file_signature)
+        self.assertEqual(pch_exist, False)
+        self.assertEqual(signature_exist, False)
+
+        torch.utils.cpp_extension.load_inline(
+            name="inline_extension_with_pch",
+            cpp_sources=[source],
+            functions=["sin_add"],
+            verbose=True,
+            use_pch=True,
+        )
+        pch_exist = os.path.exists(head_file_pch)
+        signature_exist = os.path.exists(head_file_signature)
+
+        compiler = get_cxx_compiler()
+        if check_compiler_is_gcc(compiler):
+            self.assertEqual(pch_exist, True)
+            self.assertEqual(signature_exist, True)
+

 if __name__ == "__main__":
     common.run_tests()

torch/_inductor/codecache.py

@@ -1152,6 +1152,7 @@ class CppWrapperCodeCache:
                     extra_cflags=[extra_cflags],
                     extra_ldflags=[extra_ldflags],
                     extra_include_paths=[extra_include_paths],
+                    use_pch=True,
                 )
                 log.debug("Cpp wrapper done building %s", filepath)
             else:

torch/utils/cpp_extension.py

@@ -12,6 +12,8 @@ import sys
 import sysconfig
 import warnings
 import collections
+from pathlib import Path
+import errno

 import torch
 import torch._appdirs
@@ -74,7 +76,7 @@ CUDA_CLANG_VERSIONS: VersionMap = {
 __all__ = ["get_default_build_root", "check_compiler_ok_for_platform", "get_compiler_abi_compatibility_and_version", "BuildExtension",
            "CppExtension", "CUDAExtension", "include_paths", "library_paths", "load", "load_inline", "is_ninja_available",
-           "verify_ninja_availability"]
+           "verify_ninja_availability", "remove_extension_h_precompiler_headers", "get_cxx_compiler", "check_compiler_is_gcc"]

 # Taken directly from python stdlib < 3.9
 # See https://github.com/pytorch/pytorch/issues/48617
 def _nt_quote_args(args: Optional[List[str]]) -> List[str]:
@@ -249,6 +251,12 @@ PLAT_TO_VCVARS = {
     'win-amd64' : 'x86_amd64',
 }

+def get_cxx_compiler():
+    if IS_WINDOWS:
+        compiler = os.environ.get('CXX', 'cl')
+    else:
+        compiler = os.environ.get('CXX', 'c++')
+    return compiler

 def _is_binary_build() -> bool:
     return not BUILT_FROM_SOURCE_VERSION_PATTERN.match(torch.version.__version__)
@@ -884,10 +892,8 @@ class BuildExtension(build_ext):
         # On some platforms, like Windows, compiler_cxx is not available.
         if hasattr(self.compiler, 'compiler_cxx'):
             compiler = self.compiler.compiler_cxx[0]
-        elif IS_WINDOWS:
-            compiler = os.environ.get('CXX', 'cl')
         else:
-            compiler = os.environ.get('CXX', 'c++')
+            compiler = get_cxx_compiler()
         _, version = get_compiler_abi_compatibility_and_version(compiler)
         # Warn user if VC env is activated but `DISTUTILS_USE_SDK` is not set.
         if IS_WINDOWS and 'VSCMD_ARG_TGT_ARCH' in os.environ and 'DISTUTILS_USE_SDK' not in os.environ:
@@ -1313,6 +1319,180 @@ def load(name,
         is_standalone,
         keep_intermediates=keep_intermediates)

+
+def _get_pybind11_abi_build_flags():
+    # Note [Pybind11 ABI constants]
+    #
+    # Pybind11 before 2.4 used to build ABI strings using the following pattern:
+    # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_BUILD_TYPE}__"
+    # Since 2.4, compiler type, stdlib and build abi parameters are also encoded like this:
+    # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_COMPILER_TYPE}{PYBIND11_STDLIB}{PYBIND11_BUILD_ABI}{PYBIND11_BUILD_TYPE}__"
+    #
+    # This was done in order to further narrow down the chances of compiler ABI incompatibility
+    # that can cause hard-to-debug segfaults.
+    # For PyTorch extensions we want to relax those restrictions and pass compiler, stdlib and abi properties
+    # captured during PyTorch native library compilation in torch/csrc/Module.cpp
+
+    abi_cflags = []
+    for pname in ["COMPILER_TYPE", "STDLIB", "BUILD_ABI"]:
+        pval = getattr(torch._C, f"_PYBIND11_{pname}")
+        if pval is not None and not IS_WINDOWS:
+            abi_cflags.append(f'-DPYBIND11_{pname}=\\"{pval}\\"')
+    return abi_cflags
+
+def _get_glibcxx_abi_build_flags():
+    glibcxx_abi_cflags = ['-D_GLIBCXX_USE_CXX11_ABI=' + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))]
+    return glibcxx_abi_cflags
+
+def check_compiler_is_gcc(compiler):
+    if not IS_LINUX:
+        return False
+
+    env = os.environ.copy()
+    env['LC_ALL'] = 'C'  # Don't localize output
+    version_string = subprocess.check_output([compiler, '-v'], stderr=subprocess.STDOUT, env=env).decode(*SUBPROCESS_DECODE_ARGS)
+
+    # Check for 'gcc' or 'g++' under an sccache wrapper
+    pattern = re.compile("^COLLECT_GCC=(.*)$", re.MULTILINE)
+    results = re.findall(pattern, version_string)
+    if len(results) != 1:
+        return False
+    compiler_path = os.path.realpath(results[0].strip())
+    # On RHEL/CentOS c++ is a gcc compiler wrapper
+    if os.path.basename(compiler_path) == 'c++' and 'gcc version' in version_string:
+        return True
+    return False
+
+def _check_and_build_extension_h_precompiler_headers(
+        extra_cflags,
+        extra_include_paths,
+        is_standalone=False):
+    r'''
+    Precompiled headers (PCH) can pre-build the common headers and reduce the build time of PyTorch load_inline modules.
+    GCC official manual: https://gcc.gnu.org/onlinedocs/gcc-4.0.4/gcc/Precompiled-Headers.html
+
+    A PCH only takes effect when the built PCH file (header.h.gch) and the build target share the same build
+    parameters, so we add a signature file that records the PCH file's build parameters. If the build parameters
+    (signature) change, the PCH file is rebuilt.
+
+    Note:
+    1. Windows and macOS have different PCH mechanisms. We only support Linux currently.
+    2. It only works on GCC/G++.
+    '''
+    if not IS_LINUX:
+        return
+
+    compiler = get_cxx_compiler()
+
+    b_is_gcc = check_compiler_is_gcc(compiler)
+    if b_is_gcc is False:
+        return
+
+    head_file = os.path.join(_TORCH_PATH, 'include', 'torch', 'extension.h')
+    head_file_pch = os.path.join(_TORCH_PATH, 'include', 'torch', 'extension.h.gch')
+    head_file_signature = os.path.join(_TORCH_PATH, 'include', 'torch', 'extension.h.sign')
+
+    def listToString(s):
+        # Join the list elements into a single space-separated string.
+        string = ""
+        if s is None:
+            return string
+
+        for element in s:
+            string += (element + ' ')
+        return string
+
+    def format_precompiler_header_cmd(compiler, head_file, head_file_pch, common_cflags, torch_include_dirs, extra_cflags, extra_include_paths):
+        return re.sub(
+            r"[ \n]+",
+            " ",
+            f"""
+                {compiler} -x c++-header {head_file} -o {head_file_pch} {torch_include_dirs} {extra_include_paths} {extra_cflags} {common_cflags}
+            """,
+        ).strip()
+
+    def command_to_signature(cmd):
+        signature = cmd.replace(' ', '_')
+        return signature
+
+    def check_pch_signature_in_file(file_path, signature):
+        b_exist = os.path.isfile(file_path)
+        if b_exist is False:
+            return False
+
+        with open(file_path) as file:
+            # Compare the recorded signature against the current build parameters.
+            content = file.read()
+            if signature == content:
+                return True
+            else:
+                return False
+
+    def _create_if_not_exist(path_dir):
+        if not os.path.exists(path_dir):
+            try:
+                Path(path_dir).mkdir(parents=True, exist_ok=True)
+            except OSError as exc:  # Guard against race condition
+                if exc.errno != errno.EEXIST:
+                    raise RuntimeError(f"Failed to create path {path_dir}")
+
+    def write_pch_signature_to_file(file_path, pch_sign):
+        _create_if_not_exist(os.path.dirname(file_path))
+        with open(file_path, "w") as f:
+            f.write(pch_sign)
+            f.close()
+
+    def build_precompile_header(pch_cmd):
+        try:
+            subprocess.check_output(pch_cmd, shell=True, stderr=subprocess.STDOUT)
+        except subprocess.CalledProcessError as e:
+            raise RuntimeError(f"Compiling the precompiled header failed, command: {pch_cmd}")
+
+    extra_cflags_str = listToString(extra_cflags)
+    extra_include_paths_str = listToString(extra_include_paths)
+
+    lib_include = os.path.join(_TORCH_PATH, 'include')
+    torch_include_dirs = [
+        f"-I {lib_include}",
+        # Python.h
+        "-I {}".format(sysconfig.get_path("include")),
+        # torch/all.h
+        "-I {}".format(os.path.join(lib_include, 'torch', 'csrc', 'api', 'include')),
+    ]
+
+    torch_include_dirs_str = listToString(torch_include_dirs)
+
+    common_cflags = []
+    if not is_standalone:
+        common_cflags += ['-DTORCH_API_INCLUDE_EXTENSION_H']
+
+    common_cflags += ['-std=c++17', '-fPIC']
+    common_cflags += [f"{x}" for x in _get_pybind11_abi_build_flags()]
+    common_cflags += [f"{x}" for x in _get_glibcxx_abi_build_flags()]
+    common_cflags_str = listToString(common_cflags)
+
+    pch_cmd = format_precompiler_header_cmd(compiler, head_file, head_file_pch, common_cflags_str, torch_include_dirs_str, extra_cflags_str, extra_include_paths_str)
+    pch_sign = command_to_signature(pch_cmd)
+
+    if os.path.isfile(head_file_pch) is not True:
+        build_precompile_header(pch_cmd)
+        write_pch_signature_to_file(head_file_signature, pch_sign)
+    else:
+        b_same_sign = check_pch_signature_in_file(head_file_signature, pch_sign)
+        if b_same_sign is False:
+            build_precompile_header(pch_cmd)
+            write_pch_signature_to_file(head_file_signature, pch_sign)
+
+def remove_extension_h_precompiler_headers():
+    def _remove_if_file_exists(path_file):
+        if os.path.exists(path_file):
+            os.remove(path_file)
+
+    head_file_pch = os.path.join(_TORCH_PATH, 'include', 'torch', 'extension.h.gch')
+    head_file_signature = os.path.join(_TORCH_PATH, 'include', 'torch', 'extension.h.sign')
+
+    _remove_if_file_exists(head_file_pch)
+    _remove_if_file_exists(head_file_signature)
+

 def load_inline(name,
                 cpp_sources,
@@ -1327,7 +1507,8 @@ def load_inline(name,
                 with_cuda=None,
                 is_python_module=True,
                 with_pytorch_error_handling=True,
-                keep_intermediates=True):
+                keep_intermediates=True,
+                use_pch=False):
     r'''
     Loads a PyTorch C++ extension just-in-time (JIT) from string sources.
@@ -1409,6 +1590,10 @@
     cpp_sources.insert(0, '#include <torch/extension.h>')

+    if use_pch is True:
+        # Using PreCompiled Header('torch/extension.h') to reduce compile time.
+        _check_and_build_extension_h_precompiler_headers(extra_cflags, extra_include_paths)
+
     # If `functions` is supplied, we create the pybind11 bindings for the user.
     # Here, `functions` is (or becomes, after some processing) a map from
     # function names to function docstrings.
@@ -1561,10 +1746,9 @@ def _write_ninja_file_and_compile_objects(
         verbose: bool,
         with_cuda: Optional[bool]) -> None:
     verify_ninja_availability()
-    if IS_WINDOWS:
-        compiler = os.environ.get('CXX', 'cl')
-    else:
-        compiler = os.environ.get('CXX', 'c++')
+
+    compiler = get_cxx_compiler()
+
     get_compiler_abi_compatibility_and_version(compiler)
     if with_cuda is None:
         with_cuda = any(map(_is_cuda_file, sources))
@@ -1605,10 +1789,9 @@ def _write_ninja_file_and_build_library(
         with_cuda: Optional[bool],
         is_standalone: bool = False) -> None:
     verify_ninja_availability()
-    if IS_WINDOWS:
-        compiler = os.environ.get('CXX', 'cl')
-    else:
-        compiler = os.environ.get('CXX', 'c++')
+
+    compiler = get_cxx_compiler()
+
     get_compiler_abi_compatibility_and_version(compiler)
     if with_cuda is None:
         with_cuda = any(map(_is_cuda_file, sources))
@@ -1993,27 +2176,12 @@ def _write_ninja_file_to_build_library(path,
     common_cflags.append(f'-DTORCH_EXTENSION_NAME={name}')
     common_cflags.append('-DTORCH_API_INCLUDE_EXTENSION_H')

-    # Note [Pybind11 ABI constants]
-    #
-    # Pybind11 before 2.4 used to build an ABI strings using the following pattern:
-    # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_BUILD_TYPE}__"
-    # Since 2.4 compier type, stdlib and build abi parameters are also encoded like this:
-    # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_COMPILER_TYPE}{PYBIND11_STDLIB}{PYBIND11_BUILD_ABI}{PYBIND11_BUILD_TYPE}__"
-    #
-    # This was done in order to further narrow down the chances of compiler ABI incompatibility
-    # that can cause a hard to debug segfaults.
-    # For PyTorch extensions we want to relax those restrictions and pass compiler, stdlib and abi properties
-    # captured during PyTorch native library compilation in torch/csrc/Module.cpp
-    for pname in ["COMPILER_TYPE", "STDLIB", "BUILD_ABI"]:
-        pval = getattr(torch._C, f"_PYBIND11_{pname}")
-        if pval is not None and not IS_WINDOWS:
-            common_cflags.append(f'-DPYBIND11_{pname}=\\"{pval}\\"')
+    common_cflags += [f"{x}" for x in _get_pybind11_abi_build_flags()]

     common_cflags += [f'-I{include}' for include in user_includes]
     common_cflags += [f'-isystem {include}' for include in system_includes]

-    common_cflags += ['-D_GLIBCXX_USE_CXX11_ABI=' + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))]
+    common_cflags += [f"{x}" for x in _get_glibcxx_abi_build_flags()]

     if IS_WINDOWS:
         cflags = common_cflags + COMMON_MSVC_FLAGS + ['/std:c++17'] + extra_cflags
@@ -2125,10 +2293,7 @@ def _write_ninja_file(path,
     assert len(sources) == len(objects)
     assert len(sources) > 0

-    if IS_WINDOWS:
-        compiler = os.environ.get('CXX', 'cl')
-    else:
-        compiler = os.environ.get('CXX', 'c++')
+    compiler = get_cxx_compiler()

     # Version 1.3 is required for the `deps` directive.
     config = ['ninja_required_version = 1.3']