mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[Inductor] Add aot_mode UT to new cpp_builder. (#130105)
Changes: 1. Add `aot_mode` parameter to `validate_new_cpp_commands` UT. 2. Switch AotCodeCompiler vec isa command gen to new cpp_builder. Pull Request resolved: https://github.com/pytorch/pytorch/pull/130105 Approved by: https://github.com/jgong5, https://github.com/jansel
This commit is contained in:
parent
d496145534
commit
21eeedb455
|
|
@ -76,6 +76,7 @@ from torch._inductor.cpp_builder import (
|
|||
CppOptions,
|
||||
CppTorchCudaOptions,
|
||||
get_compiler_version_info,
|
||||
get_name_and_dir_from_output_file_path,
|
||||
)
|
||||
from torch._inductor.cpu_vec_isa import invalid_vec_isa, pick_vec_isa, VecISA
|
||||
from torch._inductor.runtime.compile_tasks import (
|
||||
|
|
@ -1595,15 +1596,22 @@ class AotCodeCompiler:
|
|||
cuda: bool,
|
||||
) -> str:
|
||||
picked_vec_isa = pick_vec_isa()
|
||||
cpp_command = repr(
|
||||
cpp_compile_command(
|
||||
"i",
|
||||
"o",
|
||||
vec_isa_cmd_gen = CppBuilder(
|
||||
name="o",
|
||||
sources="i",
|
||||
BuildOption=CppTorchCudaOptions(
|
||||
vec_isa=picked_vec_isa,
|
||||
cuda=cuda,
|
||||
aot_mode=graph.aot_mode,
|
||||
),
|
||||
)
|
||||
)
|
||||
# write function will calc source_code hash, the same source code with different
|
||||
# ISA level should be generate different hash.
|
||||
# So we need get a command_line which contains isa related parameter as a part of hash key.
|
||||
# And then pass the command_line to below write function as extra parameter to
|
||||
# guarantee the source code hash contains ISA difference.
|
||||
cpp_command = repr(vec_isa_cmd_gen.get_command_line())
|
||||
|
||||
fbcode_aot_cpu_re = False
|
||||
use_absolute_path = False
|
||||
if config.is_fbcode():
|
||||
|
|
@ -1853,7 +1861,6 @@ class AotCodeCompiler:
|
|||
"linux": _compile_consts_linux,
|
||||
"darwin": _compile_consts_darwin,
|
||||
}[sys.platform](aot_constants)
|
||||
|
||||
link_cmd = cpp_compile_command(
|
||||
input=[output_o, consts_o],
|
||||
output=output_so,
|
||||
|
|
@ -2051,8 +2058,6 @@ class CppCodeCache:
|
|||
|
||||
_set_gpu_runtime_env() # cpp_extension consults the env
|
||||
|
||||
from torch._inductor.cpp_builder import CppBuilder, CppTorchCudaOptions
|
||||
|
||||
command_gen = CppBuilder(
|
||||
name="o", sources="i", BuildOption=CppTorchCudaOptions(**compile_command)
|
||||
)
|
||||
|
|
@ -2363,35 +2368,48 @@ def _do_validate_cpp_commands(
|
|||
compile_only: bool,
|
||||
mmap_weights: bool,
|
||||
use_absolute_path: bool,
|
||||
aot_mode: bool,
|
||||
):
|
||||
# PreCI will failed if test machine can't run cuda.
|
||||
temp_dir = tempfile.TemporaryDirectory()
|
||||
test_dir_path = temp_dir.name
|
||||
test_cuda = torch.cuda.is_available() and cuda
|
||||
input_path = os.path.join(test_dir_path, "dummy_input.cpp")
|
||||
output_path = os.path.join(test_dir_path, "dummy_output.so")
|
||||
input_path = os.path.join(test_dir_path, "dummy_file.cpp")
|
||||
output_path = os.path.join(test_dir_path, "dummy_file.so")
|
||||
extra_flags = ["-D TEST_EXTRA_FLAGS"]
|
||||
if compile_only:
|
||||
output_path = os.path.join(test_dir_path, "dummy_output.o")
|
||||
output_path = os.path.join(test_dir_path, "dummy_file.o")
|
||||
picked_isa = pick_vec_isa()
|
||||
|
||||
# Simulate fb_code env:
|
||||
if not (aot_mode and not use_absolute_path):
|
||||
input_path = os.path.basename(input_path)
|
||||
output_path = os.path.basename(output_path)
|
||||
|
||||
# Fix test_new_cpp_build_logical failed on MacOS
|
||||
if sys.platform != "linux":
|
||||
aot_mode = False
|
||||
|
||||
old_cmd = cpp_compile_command(
|
||||
input=input_path,
|
||||
output=output_path,
|
||||
include_pytorch=include_pytorch,
|
||||
vec_isa=picked_isa,
|
||||
cuda=test_cuda,
|
||||
aot_mode=False,
|
||||
aot_mode=aot_mode,
|
||||
compile_only=compile_only,
|
||||
use_absolute_path=use_absolute_path,
|
||||
use_mmap_weights=mmap_weights,
|
||||
extra_flags=extra_flags,
|
||||
).split(" ")
|
||||
|
||||
name, dir = get_name_and_dir_from_output_file_path(input_path)
|
||||
|
||||
dummy_build_option = CppTorchCudaOptions(
|
||||
vec_isa=picked_isa,
|
||||
include_pytorch=include_pytorch,
|
||||
cuda=test_cuda,
|
||||
aot_mode=aot_mode,
|
||||
compile_only=compile_only,
|
||||
use_absolute_path=use_absolute_path,
|
||||
use_mmap_weights=mmap_weights,
|
||||
|
|
@ -2399,10 +2417,10 @@ def _do_validate_cpp_commands(
|
|||
)
|
||||
|
||||
dummy_builder = CppBuilder(
|
||||
name="dummy_output",
|
||||
name=name,
|
||||
sources=input_path,
|
||||
output_dir=dir,
|
||||
BuildOption=dummy_build_option,
|
||||
output_dir=test_dir_path,
|
||||
)
|
||||
new_cmd = dummy_builder.get_command_line().split(" ")
|
||||
|
||||
|
|
@ -2419,14 +2437,17 @@ def validate_new_cpp_commands():
|
|||
compile_only = [True, False]
|
||||
include_pytorch = [True, False]
|
||||
use_absolute_path = [True, False]
|
||||
aot_mode = [False, True]
|
||||
|
||||
for x in cuda:
|
||||
for y in use_mmap_weights:
|
||||
for z in compile_only:
|
||||
for m in include_pytorch:
|
||||
for n in use_absolute_path:
|
||||
for o in aot_mode:
|
||||
print(
|
||||
f"!!! cuda:{x}, use_mmap_weights:{y}, compile_only:{z}, include_pytorch:{m}, use_absolute_path:{n}"
|
||||
f"!!! cuda:{x}, use_mmap_weights:{y}, compile_only:{z}, include_pytorch:{m},"
|
||||
f" use_absolute_path:{n}, aot_mode:{o}"
|
||||
)
|
||||
_do_validate_cpp_commands(
|
||||
include_pytorch=m,
|
||||
|
|
@ -2434,6 +2455,7 @@ def validate_new_cpp_commands():
|
|||
mmap_weights=y,
|
||||
compile_only=z,
|
||||
use_absolute_path=n,
|
||||
aot_mode=o,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -540,20 +540,6 @@ def _setup_standard_sys_libs(
|
|||
return cflags, include_dirs, passthough_args
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def _cpp_prefix_path() -> str:
|
||||
from torch._inductor.codecache import write # TODO
|
||||
|
||||
path = Path(Path(__file__).parent).parent / "codegen/cpp_prefix.h"
|
||||
with path.open() as f:
|
||||
content = f.read()
|
||||
_, filename = write(
|
||||
content,
|
||||
"h",
|
||||
)
|
||||
return filename
|
||||
|
||||
|
||||
def _get_build_args_of_chosen_isa(vec_isa: VecISA):
|
||||
macros = []
|
||||
build_flags = []
|
||||
|
|
@ -939,14 +925,17 @@ def get_cpp_torch_cuda_options(cuda: bool, aot_mode: bool = False):
|
|||
libraries_dirs: List[str] = []
|
||||
libraries: List[str] = []
|
||||
passthough_args: List[str] = []
|
||||
|
||||
"""
|
||||
if (
|
||||
config.is_fbcode()
|
||||
and "CUDA_HOME" not in os.environ
|
||||
and "CUDA_PATH" not in os.environ
|
||||
):
|
||||
os.environ["CUDA_HOME"] = build_paths.cuda()
|
||||
"""
|
||||
from torch._inductor.codecache import _set_gpu_runtime_env, cpp_prefix_path
|
||||
|
||||
_set_gpu_runtime_env()
|
||||
from torch.utils import cpp_extension
|
||||
|
||||
include_dirs = cpp_extension.include_paths(cuda)
|
||||
|
|
@ -971,7 +960,8 @@ def get_cpp_torch_cuda_options(cuda: bool, aot_mode: bool = False):
|
|||
libraries += ["c10_cuda", "cuda", "torch_cuda"]
|
||||
|
||||
if aot_mode:
|
||||
cpp_prefix_include_dir = [f"{os.path.dirname(_cpp_prefix_path())}"]
|
||||
if config.is_fbcode():
|
||||
cpp_prefix_include_dir = [f"{os.path.dirname(cpp_prefix_path())}"]
|
||||
include_dirs += cpp_prefix_include_dir
|
||||
|
||||
if cuda and torch.version.hip is None:
|
||||
|
|
@ -1061,15 +1051,26 @@ class CppTorchCudaOptions(CppTorchOptions):
|
|||
|
||||
|
||||
def get_name_and_dir_from_output_file_path(
|
||||
aot_mode: bool, use_absolute_path: bool, file_path: str
|
||||
file_path: str,
|
||||
):
|
||||
"""
|
||||
This function help prepare parameters to new cpp_builder.
|
||||
Example:
|
||||
input_code: /tmp/tmpof1n5g7t/5c/c5crkkcdvhdxpktrmjxbqkqyq5hmxpqsfza4pxcf3mwk42lphygc.cpp
|
||||
name, dir = get_name_and_dir_from_output_file_path(input_code)
|
||||
Run result:
|
||||
name = c5crkkcdvhdxpktrmjxbqkqyq5hmxpqsfza4pxcf3mwk42lphygc
|
||||
dir = /tmp/tmpof1n5g7t/5c/
|
||||
|
||||
put 'name' and 'dir' to CppBuilder's 'name' and 'output_dir'.
|
||||
CppBuilder --> get_target_file_path will format output path accoding OS:
|
||||
Linux: /tmp/tmppu87g3mm/zh/czhwiz4z7ca7ep3qkxenxerfjxy42kehw6h5cjk6ven4qu4hql4i.so
|
||||
Windows: [Windows temp path]/tmppu87g3mm/zh/czhwiz4z7ca7ep3qkxenxerfjxy42kehw6h5cjk6ven4qu4hql4i.dll
|
||||
"""
|
||||
name_and_ext = os.path.basename(file_path)
|
||||
name, ext = os.path.splitext(name_and_ext)
|
||||
dir = os.path.dirname(file_path)
|
||||
|
||||
if config.is_fbcode():
|
||||
if not (aot_mode and not use_absolute_path):
|
||||
dir = "."
|
||||
return name, dir
|
||||
|
||||
|
||||
|
|
@ -1118,17 +1119,23 @@ class CppBuilder:
|
|||
self._target_file = ""
|
||||
|
||||
self._use_absolute_path: bool = False
|
||||
self._aot_mode: bool = False
|
||||
|
||||
self._name = name
|
||||
|
||||
# Code start here, initial self internal veriables firstly.
|
||||
self._compiler = BuildOption.get_compiler()
|
||||
self._use_absolute_path = BuildOption.get_use_absolute_path()
|
||||
self._aot_mode = BuildOption.get_aot_mode()
|
||||
|
||||
"""
|
||||
TODO: validate and remove:
|
||||
if len(output_dir) == 0:
|
||||
self._output_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
else:
|
||||
self._output_dir = output_dir
|
||||
"""
|
||||
self._output_dir = output_dir
|
||||
|
||||
self._compile_only = BuildOption.get_compile_only()
|
||||
file_ext = (
|
||||
|
|
@ -1142,7 +1149,7 @@ class CppBuilder:
|
|||
sources = [sources]
|
||||
|
||||
if config.is_fbcode():
|
||||
if BuildOption.get_aot_mode() and not self._use_absolute_path:
|
||||
if self._aot_mode and not self._use_absolute_path:
|
||||
inp_name = sources
|
||||
# output process @ get_name_and_dir_from_output_file_path
|
||||
else:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user