mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[AOTI] Support multi-arch when using package_cpp_only (#154414)"
This reverts commit a84d8c4a1c.
Reverted https://github.com/pytorch/pytorch/pull/154414 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing ROCm trunk job ([comment](https://github.com/pytorch/pytorch/pull/154414#issuecomment-2915597821))
This commit is contained in:
parent
853958f82c
commit
fdc339003b
|
|
@ -20,12 +20,7 @@ from torch._inductor.package import AOTICompiledModel, load_package, package_aot
|
|||
from torch._inductor.test_case import TestCase
|
||||
from torch._inductor.utils import fresh_inductor_cache
|
||||
from torch.export import Dim
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_FBCODE,
|
||||
skipIfRocm,
|
||||
skipIfXpu,
|
||||
TEST_CUDA,
|
||||
)
|
||||
from torch.testing._internal.common_utils import IS_FBCODE, skipIfXpu, TEST_CUDA
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
|
||||
|
||||
|
||||
|
|
@ -253,73 +248,6 @@ class TestAOTInductorPackage(TestCase):
|
|||
actual = optimized(*example_inputs)
|
||||
self.assertTrue(torch.allclose(actual, expected))
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode")
|
||||
@skipIfRocm # doesn't support multi-arch binary
|
||||
@skipIfXpu # doesn't support multi-arch binary
|
||||
def test_compile_after_package_multi_arch(self):
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("Only meant to test GPU_TYPE")
|
||||
if not self.package_cpp_only:
|
||||
raise unittest.SkipTest("Only meant to test cpp package")
|
||||
if shutil.which("cmake") is None:
|
||||
raise unittest.SkipTest("cmake is not available")
|
||||
if shutil.which("make") is None:
|
||||
raise unittest.SkipTest("make is not available")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(10, 10)
|
||||
|
||||
def forward(self, x, y):
|
||||
return x + self.linear(y)
|
||||
|
||||
with torch.no_grad():
|
||||
example_inputs = (
|
||||
torch.randn(10, 10, device=self.device),
|
||||
torch.randn(10, 10, device=self.device),
|
||||
)
|
||||
model = Model().to(device=self.device)
|
||||
expected = model(*example_inputs)
|
||||
|
||||
options = {
|
||||
"aot_inductor.package_cpp_only": self.package_cpp_only,
|
||||
# Expect kernel to be embeded in the final binary.
|
||||
# We will make it the default behavior for the standalone mode.
|
||||
"aot_inductor.multi_arch_kernel_binary": True,
|
||||
"aot_inductor.embed_kernel_binary": True,
|
||||
}
|
||||
ep = torch.export.export(model, example_inputs)
|
||||
package_path = torch._inductor.aoti_compile_and_package(
|
||||
ep, inductor_configs=options
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir, zipfile.ZipFile(
|
||||
package_path, "r"
|
||||
) as zip_ref:
|
||||
filenames = zip_ref.namelist()
|
||||
prefix = filenames[0].split("/")[0]
|
||||
zip_ref.extractall(tmp_dir)
|
||||
tmp_path = Path(tmp_dir) / prefix / "data" / "aotinductor" / "model"
|
||||
self.assertTrue(tmp_path.exists())
|
||||
# Create a build directory to run cmake
|
||||
build_path = tmp_path / "build"
|
||||
build_path.mkdir()
|
||||
custom_env = os.environ.copy()
|
||||
custom_env["CMAKE_PREFIX_PATH"] = str(Path(torch.__file__).parent)
|
||||
subprocess.run(
|
||||
["cmake", ".."],
|
||||
cwd=build_path,
|
||||
env=custom_env,
|
||||
)
|
||||
subprocess.run(["make"], cwd=build_path)
|
||||
|
||||
# Check if the .so file was build successfully
|
||||
so_path = build_path / "libaoti_model.so"
|
||||
self.assertTrue(so_path.exists())
|
||||
optimized = torch._export.aot_load(str(so_path), self.device)
|
||||
actual = optimized(*example_inputs)
|
||||
self.assertTrue(torch.allclose(actual, expected))
|
||||
|
||||
def test_metadata(self):
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
|
|
|
|||
|
|
@ -1975,7 +1975,7 @@ class AotCodeCompiler:
|
|||
)
|
||||
wrapper_build_options.save_flags_to_json(compile_flags)
|
||||
generated_files.append(compile_flags)
|
||||
wrapper_builder.save_compile_cmd_to_cmake(cmake_path, device_type)
|
||||
wrapper_builder.save_compile_cmd_to_cmake(cmake_path)
|
||||
wrapper_builder.save_src_to_cmake(cmake_path, wrapper_path)
|
||||
generated_files.append(cmake_path)
|
||||
else:
|
||||
|
|
@ -2134,8 +2134,8 @@ class AotCodeCompiler:
|
|||
so_builder.save_src_to_cmake(cmake_path, consts_o)
|
||||
|
||||
if config.aot_inductor.multi_arch_kernel_binary:
|
||||
so_builder.save_kernel_asm_to_cmake(cmake_path, asm_files)
|
||||
generated_files.extend(asm_files)
|
||||
# TODO: support multi-arch when package_cpp_only
|
||||
pass
|
||||
else:
|
||||
obj_srcs = [*gpu_kernels_o, *cubins_o]
|
||||
generated_files.extend(obj_srcs)
|
||||
|
|
|
|||
|
|
@ -1719,13 +1719,7 @@ class CppBuilder:
|
|||
def save_compile_cmd_to_cmake(
|
||||
self,
|
||||
cmake_path: str,
|
||||
device_type: str,
|
||||
) -> None:
|
||||
"""
|
||||
Save global cmake settings here, e.g. compiler options.
|
||||
If targeting CUDA, also emit a custom function to embed CUDA kernels.
|
||||
"""
|
||||
|
||||
definitions = " ".join(self._build_option.get_definitions())
|
||||
contents = textwrap.dedent(
|
||||
f"""
|
||||
|
|
@ -1749,68 +1743,6 @@ class CppBuilder:
|
|||
|
||||
"""
|
||||
)
|
||||
if device_type == "cuda":
|
||||
contents += textwrap.dedent(
|
||||
"""
|
||||
find_package(CUDA REQUIRED)
|
||||
|
||||
find_program(OBJCOPY_EXECUTABLE objcopy)
|
||||
if(NOT OBJCOPY_EXECUTABLE)
|
||||
message(FATAL_ERROR "objcopy not found. Cannot embed fatbin as object file")
|
||||
endif()
|
||||
|
||||
set(KERNEL_TARGETS "")
|
||||
set(KERNEL_OBJECT_FILES "")
|
||||
# Function to embed a single kernel
|
||||
function(embed_gpu_kernel KERNEL_NAME PTX_FILE)
|
||||
set(FATBIN_BASENAME ${KERNEL_NAME}.fatbin)
|
||||
set(FATBIN_FILE ${CMAKE_CURRENT_BINARY_DIR}/${FATBIN_BASENAME})
|
||||
set(OBJECT_BASENAME ${KERNEL_NAME}.fatbin.o)
|
||||
set(OBJECT_FILE ${CMAKE_CURRENT_BINARY_DIR}/${OBJECT_BASENAME})
|
||||
|
||||
# --- Define UNIQUE C symbol names ---
|
||||
set(SYMBOL_START __${KERNEL_NAME}_start)
|
||||
set(SYMBOL_END __${KERNEL_NAME}_end)
|
||||
set(SYMBOL_SIZE __${KERNEL_NAME}_size)
|
||||
string(REGEX REPLACE "[^a-zA-Z0-9]" "_" MANGLED_BASENAME ${FATBIN_FILE})
|
||||
set(OBJCOPY_START_SYM _binary_${MANGLED_BASENAME}_start)
|
||||
set(OBJCOPY_END_SYM _binary_${MANGLED_BASENAME}_end)
|
||||
set(OBJCOPY_SIZE_SYM _binary_${MANGLED_BASENAME}_size)
|
||||
|
||||
# --- PTX to FATBIN Command & Target ---
|
||||
add_custom_command(
|
||||
OUTPUT ${FATBIN_FILE}
|
||||
COMMAND ${CUDA_NVCC_EXECUTABLE} --fatbin ${PTX_FILE} -o ${FATBIN_FILE} ${NVCC_GENCODE_FLAGS}
|
||||
-gencode arch=compute_80,code=compute_80
|
||||
-gencode arch=compute_86,code=compute_86
|
||||
-gencode arch=compute_89,code=compute_89
|
||||
-gencode arch=compute_90,code=compute_90
|
||||
DEPENDS ${PTX_FILE}
|
||||
)
|
||||
|
||||
# --- FATBIN to Object File (.o) Command ---
|
||||
add_custom_command(
|
||||
OUTPUT ${OBJECT_FILE}
|
||||
COMMAND ${CMAKE_LINKER} -r -b binary -z noexecstack -o ${OBJECT_FILE} ${FATBIN_FILE}
|
||||
COMMAND ${OBJCOPY_EXECUTABLE} --rename-section .data=.rodata,alloc,load,readonly,data,contents
|
||||
${OBJECT_FILE}
|
||||
COMMAND ${OBJCOPY_EXECUTABLE}
|
||||
--redefine-sym ${OBJCOPY_START_SYM}=${SYMBOL_START}
|
||||
--redefine-sym ${OBJCOPY_END_SYM}=${SYMBOL_END}
|
||||
--redefine-sym ${OBJCOPY_SIZE_SYM}=${SYMBOL_SIZE}
|
||||
${OBJECT_FILE}
|
||||
DEPENDS ${FATBIN_FILE}
|
||||
)
|
||||
add_custom_target(build_kernel_object_${KERNEL_NAME} DEPENDS ${OBJECT_FILE})
|
||||
|
||||
# --- Add to a list for linking later ---
|
||||
set(KERNEL_TARGETS ${KERNEL_TARGETS} build_kernel_object_${KERNEL_NAME} PARENT_SCOPE)
|
||||
set(KERNEL_OBJECT_FILES ${KERNEL_OBJECT_FILES} ${OBJECT_FILE} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
"""
|
||||
)
|
||||
|
||||
with open(cmake_path, "w") as f:
|
||||
f.write(contents)
|
||||
|
||||
|
|
@ -1820,23 +1752,6 @@ class CppBuilder:
|
|||
with open(cmake_path, "a") as f:
|
||||
f.write(f"target_sources(aoti_model PRIVATE {src_path})\n")
|
||||
|
||||
def save_kernel_asm_to_cmake(self, cmake_path: str, asm_files: list[str]) -> None:
|
||||
# TODO: make this work beyond CUDA
|
||||
with open(cmake_path, "a") as f:
|
||||
for asm_file in asm_files:
|
||||
kernel_name = Path(asm_file).name.split(".")[0]
|
||||
asm_file = f"${{CMAKE_CURRENT_SOURCE_DIR}}/{Path(asm_file).name}"
|
||||
contents = textwrap.dedent(
|
||||
f"""
|
||||
embed_gpu_kernel({kernel_name} {asm_file})
|
||||
"""
|
||||
)
|
||||
f.write(contents)
|
||||
f.write("add_dependencies(aoti_model ${KERNEL_TARGETS})\n")
|
||||
f.write(
|
||||
"target_link_libraries(aoti_model PRIVATE ${KERNEL_OBJECT_FILES})\n"
|
||||
)
|
||||
|
||||
def save_link_cmd_to_cmake(self, cmake_path: str) -> None:
|
||||
lflags = " ".join(self._build_option.get_ldflags())
|
||||
libs = " ".join(self._build_option.get_libraries())
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user