mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Optimize load inline via pch (#106696)
Add a Precompiled Header (PCH) to reduce load_inline build time. PCH is a built-in GCC mechanism: https://gcc.gnu.org/onlinedocs/gcc-4.0.4/gcc/Precompiled-Headers.html This PR adds a PCH for '#include <torch/extension.h>'. That header is used by every load_inline module, so all load_inline modules can benefit from this PR. Changes: 1. Add a PCH signature file to guarantee that the PCH (.gch) file actually takes effect. 2. Unify the functions that get the C++ compiler. 3. Unify the functions that get the build flags. Before this PR vs. with this PR: compile time is reduced from 14.06s to 7.36s. Pull Request resolved: https://github.com/pytorch/pytorch/pull/106696 Approved by: https://github.com/jgong5, https://github.com/jansel
This commit is contained in:
parent
24968383b5
commit
5ed60477a7
|
|
@ -18,12 +18,13 @@ import torch.utils.cpp_extension
|
|||
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
|
||||
from torch.testing._internal.common_utils import gradcheck
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from torch.utils.cpp_extension import _TORCH_PATH, remove_extension_h_precompiler_headers, get_cxx_compiler, check_compiler_is_gcc
|
||||
|
||||
TEST_CUDA = TEST_CUDA and CUDA_HOME is not None
|
||||
TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None
|
||||
TEST_MPS = torch.backends.mps.is_available()
|
||||
IS_WINDOWS = sys.platform == "win32"
|
||||
IS_LINUX = sys.platform.startswith('linux')
|
||||
|
||||
|
||||
def remove_build_path():
|
||||
|
|
@ -919,5 +920,42 @@ class TestCppExtensionJIT(common.TestCase):
|
|||
with self.assertRaisesRegex(RuntimeError, msg):
|
||||
torch.func.grad(identity_m.identity)(t)
|
||||
|
||||
|
||||
def test_gen_extension_h_pch(self):
    """Check that ``load_inline(use_pch=True)`` generates the extension.h PCH.

    The precompiled header and its signature file are first removed, then an
    inline extension is built with ``use_pch=True``. When the active C++
    compiler is detected as gcc, both files must exist afterwards.
    """
    if not IS_LINUX:
        # Report the test as skipped instead of silently passing without
        # having exercised anything.
        self.skipTest("extension.h precompiled headers are only built on Linux")

    source = """
    at::Tensor sin_add(at::Tensor x, at::Tensor y) {
        return x.sin() + y.sin();
    }
    """

    head_file_pch = os.path.join(_TORCH_PATH, "include", "torch", "extension.h.gch")
    head_file_signature = os.path.join(
        _TORCH_PATH, "include", "torch", "extension.h.sign"
    )

    # Start from a clean state so that the files found later were produced
    # by this very build.
    remove_extension_h_precompiler_headers()
    self.assertFalse(os.path.exists(head_file_pch))
    self.assertFalse(os.path.exists(head_file_signature))

    torch.utils.cpp_extension.load_inline(
        name="inline_extension_with_pch",
        cpp_sources=[source],
        functions=["sin_add"],
        verbose=True,
        use_pch=True,
    )

    # The PCH is only expected when the compiler is detected as gcc; for
    # any other compiler nothing is asserted about the files.
    compiler = get_cxx_compiler()
    if check_compiler_is_gcc(compiler):
        self.assertTrue(os.path.exists(head_file_pch))
        self.assertTrue(os.path.exists(head_file_signature))
|
||||
|
||||
if __name__ == "__main__":
|
||||
common.run_tests()
|
||||
|
|
|
|||
|
|
@ -1152,6 +1152,7 @@ class CppWrapperCodeCache:
|
|||
extra_cflags=[extra_cflags],
|
||||
extra_ldflags=[extra_ldflags],
|
||||
extra_include_paths=[extra_include_paths],
|
||||
use_pch=True,
|
||||
)
|
||||
log.debug("Cpp wrapper done building %s", filepath)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ import sys
|
|||
import sysconfig
|
||||
import warnings
|
||||
import collections
|
||||
from pathlib import Path
|
||||
import errno
|
||||
|
||||
import torch
|
||||
import torch._appdirs
|
||||
|
|
@ -74,7 +76,7 @@ CUDA_CLANG_VERSIONS: VersionMap = {
|
|||
|
||||
__all__ = ["get_default_build_root", "check_compiler_ok_for_platform", "get_compiler_abi_compatibility_and_version", "BuildExtension",
|
||||
"CppExtension", "CUDAExtension", "include_paths", "library_paths", "load", "load_inline", "is_ninja_available",
|
||||
"verify_ninja_availability"]
|
||||
"verify_ninja_availability", "remove_extension_h_precompiler_headers", "get_cxx_compiler", "check_compiler_is_gcc"]
|
||||
# Taken directly from python stdlib < 3.9
|
||||
# See https://github.com/pytorch/pytorch/issues/48617
|
||||
def _nt_quote_args(args: Optional[List[str]]) -> List[str]:
|
||||
|
|
@ -249,6 +251,12 @@ PLAT_TO_VCVARS = {
|
|||
'win-amd64' : 'x86_amd64',
|
||||
}
|
||||
|
||||
def get_cxx_compiler():
    """Return the C++ compiler command used to build extensions.

    The ``CXX`` environment variable takes precedence when it is set;
    otherwise the default is ``cl`` on Windows and ``c++`` elsewhere.
    """
    fallback = 'cl' if IS_WINDOWS else 'c++'
    return os.environ.get('CXX', fallback)
|
||||
|
||||
def _is_binary_build() -> bool:
|
||||
return not BUILT_FROM_SOURCE_VERSION_PATTERN.match(torch.version.__version__)
|
||||
|
|
@ -884,10 +892,8 @@ class BuildExtension(build_ext):
|
|||
# On some platforms, like Windows, compiler_cxx is not available.
|
||||
if hasattr(self.compiler, 'compiler_cxx'):
|
||||
compiler = self.compiler.compiler_cxx[0]
|
||||
elif IS_WINDOWS:
|
||||
compiler = os.environ.get('CXX', 'cl')
|
||||
else:
|
||||
compiler = os.environ.get('CXX', 'c++')
|
||||
compiler = get_cxx_compiler()
|
||||
_, version = get_compiler_abi_compatibility_and_version(compiler)
|
||||
# Warn user if VC env is activated but `DISTUILS_USE_SDK` is not set.
|
||||
if IS_WINDOWS and 'VSCMD_ARG_TGT_ARCH' in os.environ and 'DISTUTILS_USE_SDK' not in os.environ:
|
||||
|
|
@ -1313,6 +1319,180 @@ def load(name,
|
|||
is_standalone,
|
||||
keep_intermediates=keep_intermediates)
|
||||
|
||||
def _get_pybind11_abi_build_flags():
    # Note [Pybind11 ABI constants]
    #
    # Before version 2.4, pybind11 built its ABI string with this pattern:
    # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_BUILD_TYPE}__"
    # From 2.4 on, the compiler type, stdlib and build ABI are encoded as well:
    # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_COMPILER_TYPE}{PYBIND11_STDLIB}{PYBIND11_BUILD_ABI}{PYBIND11_BUILD_TYPE}__"
    #
    # The extra fields reduce the chance of a compiler ABI incompatibility,
    # which can cause hard-to-debug segfaults.
    # PyTorch extensions relax that restriction by passing along the compiler,
    # stdlib and ABI properties that were captured while the PyTorch native
    # library was compiled (torch/csrc/Module.cpp).

    flags = []
    for pname in ("COMPILER_TYPE", "STDLIB", "BUILD_ABI"):
        pval = getattr(torch._C, f"_PYBIND11_{pname}")
        # No define is emitted on Windows or when the property was not recorded.
        if IS_WINDOWS or pval is None:
            continue
        flags.append(f'-DPYBIND11_{pname}=\\"{pval}\\"')
    return flags
|
||||
|
||||
def _get_glibcxx_abi_build_flags():
|
||||
glibcxx_abi_cflags = ['-D_GLIBCXX_USE_CXX11_ABI=' + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))]
|
||||
return glibcxx_abi_cflags
|
||||
|
||||
def check_compiler_is_gcc(compiler):
    """Return True when ``compiler`` is recognised as a gcc ``c++`` driver.

    The compiler is run with ``-v`` and its output is inspected: exactly one
    ``COLLECT_GCC=`` line must be present, the resolved path on that line
    must have the basename ``c++``, and the output must contain
    ``gcc version``.

    Args:
        compiler: The compiler executable (name or path) to probe.

    Returns:
        bool: ``True`` if the checks above succeed. ``False`` otherwise,
        including on non-Linux platforms and when the compiler cannot be
        executed or exits with an error.
    """
    if not IS_LINUX:
        return False

    env = os.environ.copy()
    env['LC_ALL'] = 'C'  # Don't localize output
    try:
        raw_output = subprocess.check_output([compiler, '-v'], stderr=subprocess.STDOUT, env=env)
    except (OSError, subprocess.CalledProcessError):
        # The compiler is missing, not executable, or rejects `-v`. It cannot
        # be confirmed as gcc, and callers only use this check to decide on
        # an optional optimisation, so report False instead of raising.
        return False
    version_string = raw_output.decode(*SUBPROCESS_DECODE_ARGS)

    # gcc reports the driver it was invoked as on a `COLLECT_GCC=` line.
    matches = re.findall(r"^COLLECT_GCC=(.*)$", version_string, re.MULTILINE)
    if len(matches) != 1:
        return False

    compiler_path = os.path.realpath(matches[0].strip())
    # On RHEL/CentOS c++ is a gcc compiler wrapper
    # NOTE(review): only a driver whose resolved basename is `c++` is
    # accepted, so e.g. CXX=g++ yields False - confirm this is intended.
    return os.path.basename(compiler_path) == 'c++' and 'gcc version' in version_string
|
||||
|
||||
def _check_and_build_extension_h_precompiler_headers(
        extra_cflags,
        extra_include_paths,
        is_standalone=False):
    r'''
    Build, or rebuild when stale, the precompiled header for ``torch/extension.h``.

    Precompiled Headers (PCH) pre-build a shared header and reduce build time for pytorch load_inline modules.
    GCC official manual: https://gcc.gnu.org/onlinedocs/gcc-4.0.4/gcc/Precompiled-Headers.html
    A PCH only works when the PCH file (header.h.gch) and the build target use the same build parameters. A
    signature file therefore records the command used to build the PCH file. If the build parameters (signature)
    change, the PCH file is rebuilt.

    Note:
    1. Windows and MacOS have different PCH mechanisms. Only Linux is supported currently.
    2. It only works on GCC/G++.

    Args:
        extra_cflags: Optional list of extra compiler flags (strings).
        extra_include_paths: Optional list of include directories. Entries that already start with ``-``
            (pre-formatted flags such as ``-I/path``) are forwarded unchanged; any other entry is passed
            to the compiler as ``-I <entry>``.
        is_standalone: When False, ``-DTORCH_API_INCLUDE_EXTENSION_H`` is added to the build flags.

    Raises:
        RuntimeError: If the PCH compilation fails or the signature directory cannot be created.
    '''
    if not IS_LINUX:
        return

    compiler = get_cxx_compiler()
    if not check_compiler_is_gcc(compiler):
        return

    torch_header_dir = os.path.join(_TORCH_PATH, 'include', 'torch')
    head_file = os.path.join(torch_header_dir, 'extension.h')
    head_file_pch = os.path.join(torch_header_dir, 'extension.h.gch')
    head_file_signature = os.path.join(torch_header_dir, 'extension.h.sign')

    def _join_flags(flags):
        # Space-join a list of flags; None or an empty list yields ''.
        return ' '.join(flags) if flags else ''

    def _include_flags(paths):
        # Turn include entries into compiler flags. A bare directory handed to
        # the compiler would be treated as an input file, so it gets an `-I`
        # prefix; entries that are already flags are kept verbatim.
        flags = []
        for path in paths or ():
            text = str(path).strip()
            if not text:
                continue
            flags.append(text if text.startswith('-') else f'-I {text}')
        return ' '.join(flags)

    def _signature_matches(file_path, signature):
        # True only when the signature file exists and holds exactly `signature`.
        if not os.path.isfile(file_path):
            return False
        with open(file_path) as f:
            return f.read() == signature

    def _write_signature(file_path, signature):
        # Record the build command next to the PCH, creating the directory if needed.
        sign_dir = os.path.dirname(file_path)
        try:
            os.makedirs(sign_dir, exist_ok=True)
        except FileExistsError:
            # Something already occupies the path; let open() below decide.
            pass
        except OSError as exc:
            raise RuntimeError(f"Fail to create path {sign_dir}") from exc
        with open(file_path, "w") as f:
            f.write(signature)

    def _build_pch(cmd):
        # Run the PCH build command, surfacing the compiler output on failure.
        try:
            subprocess.check_output(cmd, shell=True, stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as e:
            output = e.output.decode(errors='replace') if e.output else ''
            raise RuntimeError(f"Compile PreCompile Header fail, command: {cmd}\n{output}") from e

    lib_include = os.path.join(_TORCH_PATH, 'include')
    torch_include_dirs = [
        f"-I {lib_include}",
        # Python.h
        "-I {}".format(sysconfig.get_path("include")),
        # torch/all.h
        "-I {}".format(os.path.join(lib_include, 'torch', 'csrc', 'api', 'include')),
    ]

    common_cflags = []
    if not is_standalone:
        common_cflags.append('-DTORCH_API_INCLUDE_EXTENSION_H')
    common_cflags += ['-std=c++17', '-fPIC']
    common_cflags += _get_pybind11_abi_build_flags()
    common_cflags += _get_glibcxx_abi_build_flags()

    # Collapse whitespace runs so the command, and hence the signature, does
    # not depend on how many optional parts are empty.
    pch_cmd = re.sub(
        r"[ \n]+",
        " ",
        f"{compiler} -x c++-header {head_file} -o {head_file_pch} "
        f"{_join_flags(torch_include_dirs)} {_include_flags(extra_include_paths)} "
        f"{_join_flags(extra_cflags)} {_join_flags(common_cflags)}",
    ).strip()
    pch_sign = pch_cmd.replace(' ', '_')

    # Rebuild when the PCH is missing or was built with different parameters.
    # The signature is written only after a successful build.
    pch_is_current = os.path.isfile(head_file_pch) and _signature_matches(head_file_signature, pch_sign)
    if not pch_is_current:
        _build_pch(pch_cmd)
        _write_signature(head_file_signature, pch_sign)
|
||||
|
||||
def remove_extension_h_precompiler_headers():
    """Delete the extension.h precompiled header and its signature file.

    Both files live in ``<_TORCH_PATH>/include/torch``; a file that does not
    exist is left alone.
    """
    header_dir = os.path.join(_TORCH_PATH, 'include', 'torch')
    for artifact in ('extension.h.gch', 'extension.h.sign'):
        artifact_path = os.path.join(header_dir, artifact)
        if os.path.exists(artifact_path):
            os.remove(artifact_path)
|
||||
|
||||
def load_inline(name,
|
||||
cpp_sources,
|
||||
|
|
@ -1327,7 +1507,8 @@ def load_inline(name,
|
|||
with_cuda=None,
|
||||
is_python_module=True,
|
||||
with_pytorch_error_handling=True,
|
||||
keep_intermediates=True):
|
||||
keep_intermediates=True,
|
||||
use_pch=False):
|
||||
r'''
|
||||
Loads a PyTorch C++ extension just-in-time (JIT) from string sources.
|
||||
|
||||
|
|
@ -1409,6 +1590,10 @@ def load_inline(name,
|
|||
|
||||
cpp_sources.insert(0, '#include <torch/extension.h>')
|
||||
|
||||
if use_pch is True:
|
||||
# Using PreCompile Header('torch/extension.h') to reduce compile time.
|
||||
_check_and_build_extension_h_precompiler_headers(extra_cflags, extra_include_paths)
|
||||
|
||||
# If `functions` is supplied, we create the pybind11 bindings for the user.
|
||||
# Here, `functions` is (or becomes, after some processing) a map from
|
||||
# function names to function docstrings.
|
||||
|
|
@ -1561,10 +1746,9 @@ def _write_ninja_file_and_compile_objects(
|
|||
verbose: bool,
|
||||
with_cuda: Optional[bool]) -> None:
|
||||
verify_ninja_availability()
|
||||
if IS_WINDOWS:
|
||||
compiler = os.environ.get('CXX', 'cl')
|
||||
else:
|
||||
compiler = os.environ.get('CXX', 'c++')
|
||||
|
||||
compiler = get_cxx_compiler()
|
||||
|
||||
get_compiler_abi_compatibility_and_version(compiler)
|
||||
if with_cuda is None:
|
||||
with_cuda = any(map(_is_cuda_file, sources))
|
||||
|
|
@ -1605,10 +1789,9 @@ def _write_ninja_file_and_build_library(
|
|||
with_cuda: Optional[bool],
|
||||
is_standalone: bool = False) -> None:
|
||||
verify_ninja_availability()
|
||||
if IS_WINDOWS:
|
||||
compiler = os.environ.get('CXX', 'cl')
|
||||
else:
|
||||
compiler = os.environ.get('CXX', 'c++')
|
||||
|
||||
compiler = get_cxx_compiler()
|
||||
|
||||
get_compiler_abi_compatibility_and_version(compiler)
|
||||
if with_cuda is None:
|
||||
with_cuda = any(map(_is_cuda_file, sources))
|
||||
|
|
@ -1993,27 +2176,12 @@ def _write_ninja_file_to_build_library(path,
|
|||
common_cflags.append(f'-DTORCH_EXTENSION_NAME={name}')
|
||||
common_cflags.append('-DTORCH_API_INCLUDE_EXTENSION_H')
|
||||
|
||||
# Note [Pybind11 ABI constants]
|
||||
#
|
||||
# Pybind11 before 2.4 used to build an ABI strings using the following pattern:
|
||||
# f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_BUILD_TYPE}__"
|
||||
# Since 2.4 compier type, stdlib and build abi parameters are also encoded like this:
|
||||
# f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_COMPILER_TYPE}{PYBIND11_STDLIB}{PYBIND11_BUILD_ABI}{PYBIND11_BUILD_TYPE}__"
|
||||
#
|
||||
# This was done in order to further narrow down the chances of compiler ABI incompatibility
|
||||
# that can cause a hard to debug segfaults.
|
||||
# For PyTorch extensions we want to relax those restrictions and pass compiler, stdlib and abi properties
|
||||
# captured during PyTorch native library compilation in torch/csrc/Module.cpp
|
||||
|
||||
for pname in ["COMPILER_TYPE", "STDLIB", "BUILD_ABI"]:
|
||||
pval = getattr(torch._C, f"_PYBIND11_{pname}")
|
||||
if pval is not None and not IS_WINDOWS:
|
||||
common_cflags.append(f'-DPYBIND11_{pname}=\\"{pval}\\"')
|
||||
common_cflags += [f"{x}" for x in _get_pybind11_abi_build_flags()]
|
||||
|
||||
common_cflags += [f'-I{include}' for include in user_includes]
|
||||
common_cflags += [f'-isystem {include}' for include in system_includes]
|
||||
|
||||
common_cflags += ['-D_GLIBCXX_USE_CXX11_ABI=' + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))]
|
||||
common_cflags += [f"{x}" for x in _get_glibcxx_abi_build_flags()]
|
||||
|
||||
if IS_WINDOWS:
|
||||
cflags = common_cflags + COMMON_MSVC_FLAGS + ['/std:c++17'] + extra_cflags
|
||||
|
|
@ -2125,10 +2293,7 @@ def _write_ninja_file(path,
|
|||
assert len(sources) == len(objects)
|
||||
assert len(sources) > 0
|
||||
|
||||
if IS_WINDOWS:
|
||||
compiler = os.environ.get('CXX', 'cl')
|
||||
else:
|
||||
compiler = os.environ.get('CXX', 'c++')
|
||||
compiler = get_cxx_compiler()
|
||||
|
||||
# Version 1.3 is required for the `deps` directive.
|
||||
config = ['ninja_required_version = 1.3']
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user