[AOTI] add zero size consts asm handler (#159225)

Add `get_zero_consts_asm_code` to handle zero size consts to object.
This function is used to handle zero consts situation. Because cpp standard does not allow zero size array:
https://stackoverflow.com/questions/9722632/what-happens-if-i-define-a-0-size-array-in-c-c
1. On Windows, MSVC will report error C2466:
https://learn.microsoft.com/en-us/cpp/error-messages/compiler-errors-1/compiler-error-c2466?view=msvc-170
So, we can use assmbely compiler to handle this situation.
2. On Windows, why not use Win32 asm to handle all path? Because ml64 only supports up to align `16`, it is
not aligned to pytorch's `64`. Reference: https://learn.microsoft.com/en-us/cpp/assembler/masm/ml-and-ml64-command-line-reference?view=msvc-170
```
Packs structures on the specified byte boundary. The alignment can be 1, 2, 4, 8, or 16.
```
3. It function can handle zero size case on both Windows and Linux, as that:
    A. On Linux, we added `-pedantic` to disable zero size array on C++ compiler. 8e07c9870d/torch/_inductor/cpp_builder.py (L580)
    B. On Windows, msvc is not support zero size array by default.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159225
Approved by: https://github.com/desertfire
This commit is contained in:
Xu Han 2025-07-31 22:46:33 +00:00 committed by PyTorch MergeBot
parent 490cb3f1a4
commit 7e00f2ec9d
3 changed files with 111 additions and 6 deletions

View File

@ -6740,6 +6740,25 @@ class AOTInductorLoggingTest(LoggingTestCase):
torch._inductor.aot_compile(ep.module(), inputs)
self.assertEqual([r.msg == "create_env" for r in records].count(True), 1)
@make_logging_test(dynamic=logging.DEBUG)
def test_shape_env_reuse_zero_consts_use_consts_asm_false(self, records):
# make sure ShapeEnv is only created once and reused afterwards
class Foo(torch.nn.Module):
def forward(self, x):
return x + 2
inputs = (torch.randn(4, 4),)
dynamic_shapes = {
"x": {0: Dim.AUTO, 1: Dim.AUTO},
}
ep = export(Foo(), inputs, dynamic_shapes=dynamic_shapes, strict=False)
with (
torch.no_grad(),
config.patch({"aot_inductor.use_consts_asm_build": False}),
):
torch._inductor.aot_compile(ep.module(), inputs)
self.assertEqual([r.msg == "create_env" for r in records].count(True), 1)
class TestAOTInductorConfig(TestCase):
def test_no_compile_standalone(self):

View File

@ -75,6 +75,7 @@ from torch._inductor.cpp_builder import (
get_ld_and_objcopy,
get_name_and_dir_from_output_file_path,
normalize_path_separator,
run_asm_build_object,
)
from torch._inductor.cpu_vec_isa import pick_vec_isa
from torch._inductor.custom_graph_pass import (
@ -1862,8 +1863,9 @@ class AotCodeCompiler:
use_asm_build = False
is_large_consts = len(consts) > 1024
is_zero_size_consts = len(consts) == 0
def format_consts_to_asm(
def format_consts_to_gnu_asm(
consts: bytes,
align_bytes: int,
symbol_prefix: str,
@ -1912,14 +1914,65 @@ ATTRIBUTE_NO_SANITIZE_ADDRESS\t\n"""
const_cpp += f"alignas({align_bytes}) extern unsigned char * {symbol_prefix}_binary_constants_bin_end;\t\n"
return const_cpp, "cpp"
def get_zero_consts_asm_code(
align_bytes: int,
symbol_prefix: str,
) -> tuple[str, str]:
"""
This function handles zero-sized constants because the C++ standard prohibits zero-length arrays:
https://stackoverflow.com/questions/9722632/what-happens-if-i-define-a-0-size-array-in-c-c
On Windows (MSVC):
The compiler reports error C2466 for zero-sized arrays:
https://learn.microsoft.com/en-us/cpp/error-messages/compiler-errors-1/compiler-error-c2466
Solution: Use assembly compilation to handle this case.
Why not use Win32 assembly for all paths?
ml64 only supports alignment up to 16 bytes, which isn't optimal for performance.
Cross-platform implementation:
Linux: Added '-pedantic' to disable zero-sized arrays in C++ compiler
Windows: MSVC naturally rejects zero-sized arrays by default
"""
if _IS_WINDOWS:
# Windows ml64 is max support align to 16, but it is no effect to zero size data.
asm_code = """
option casemap:none
.data
?_binary_constants_bin_start@@3PAEA:
align 16
?_binary_constants_bin_end@@3PAEA:
align 16
public ?_binary_constants_bin_start@@3PAEA
public ?_binary_constants_bin_end@@3PAEA
end
"""
asm_ext = "asm"
else:
asm_code = f"\t.section\t{section_attr}\n"
asm_code += f"\t.balign {align_bytes}\n"
asm_code += (
f"\t.globl\t{symbol_prefix}_binary_constants_bin_start\n"
)
asm_code += f"{symbol_prefix}_binary_constants_bin_start:\n"
asm_code += f".globl\t{symbol_prefix}_binary_constants_bin_end\n"
asm_code += f"{symbol_prefix}_binary_constants_bin_end:\n"
asm_ext = "S"
return asm_code, asm_ext
if use_asm_build:
consts_code, code_ext = format_consts_to_asm(
consts_code, code_ext = format_consts_to_gnu_asm(
consts, ALIGN_BYTES, symbol_prefix, is_large_consts
)
else:
consts_code, code_ext = format_consts_to_cpp(
consts, ALIGN_BYTES, symbol_prefix
)
if is_zero_size_consts:
consts_code, code_ext = get_zero_consts_asm_code(
ALIGN_BYTES, symbol_prefix
)
else:
consts_code, code_ext = format_consts_to_cpp(
consts, ALIGN_BYTES, symbol_prefix
)
_, consts_s = write(
consts_code,
@ -1940,7 +1993,10 @@ ATTRIBUTE_NO_SANITIZE_ADDRESS\t\n"""
BuildOption=object_build_options,
)
consts_o = object_builder.get_target_file_path()
object_builder.build()
if use_asm_build is False and is_zero_size_consts:
run_asm_build_object(str(consts_s), consts_o, str(consts_s.parent))
else:
object_builder.build()
if is_large_consts and use_asm_build:
with open(consts_o, "r+b") as f:

View File

@ -1905,3 +1905,33 @@ class CppBuilder:
)
with open(cmake_path, "a") as f:
f.write(contents)
def run_asm_build_object(src: str, target: str, cwd: str) -> None:
def get_asm_compiler() -> str:
if _IS_WINDOWS:
ASM_CC = "ml64"
else:
ASM_CC = get_cpp_compiler()
# Intel compiler is not support to compile asm, switch to gcc.
if _is_intel_compiler(ASM_CC):
ASM_CC = "gcc"
return ASM_CC
def get_command_line(asm_cc: str, src: str, target: str) -> str:
if _IS_WINDOWS:
# Format reference:
# https://learn.microsoft.com/en-us/cpp/assembler/masm/ml-and-ml64-command-line-reference?view=msvc-170
cmd = f"{asm_cc} {src} /c /Fo {target}" # codespell:ignore /Fo
else:
cmd = f"{asm_cc} -c {src} -o {target}"
return cmd
asm_cc = get_asm_compiler()
cmd = get_command_line(
asm_cc=asm_cc,
src=normalize_path_separator(src),
target=normalize_path_separator(target),
)
run_compile_cmd(cmd, cwd=normalize_path_separator(cwd))