[AOTI] Add a multi_arch_kernel_binary option (#154413)

Summary: CUDA supports multi-architecture binaries via the fatbin format. Add a multi_arch_kernel_binary option so that the compiled model binary can run across different GPU architectures.
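For context, a minimal usage sketch (not part of this PR): the aot_inductor.* config keys below come from this diff, while the aoti_compile_and_package packaging entry point and its exact signature are assumptions about the surrounding AOTI API.

# Minimal sketch, assuming the existing AOTI packaging entry point.
import torch
from torch._inductor import aoti_compile_and_package

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(10, 16)

    def forward(self, x, y):
        return x + self.linear(y)

example_inputs = (
    torch.randn(10, 16, device="cuda"),
    torch.randn(10, 10, device="cuda"),
)
ep = torch.export.export(M().to("cuda"), example_inputs)

# multi_arch_kernel_binary keeps Triton kernels as PTX and compiles them into
# .fatbin files covering several SM architectures, so the packaged model is
# not tied to the GPU arch of the build machine.
package_path = aoti_compile_and_package(
    ep,
    inductor_configs={
        "aot_inductor.multi_arch_kernel_binary": True,
        # Optionally also embed the kernel binaries into the shared library.
        "aot_inductor.embed_kernel_binary": True,
    },
)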

Differential Revision: [D75452094](https://our.internmc.facebook.com/intern/diff/D75452094)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154413
Approved by: https://github.com/angelayi
ghstack dependencies: #154412
Authored by Bin Bao on 2025-05-27 10:34:59 -07:00, committed by PyTorch MergeBot
parent 4d8f3d537a
commit cde82d25b7
5 changed files with 139 additions and 32 deletions

@@ -156,6 +156,43 @@ class AOTInductorTestsTemplate:
model, example_inputs, "AOTInductorModelRunMinimalArrayrefInterface(", 1
)
@unittest.skipIf(
IS_FBCODE,
"toolchain doesn't support ptx to fatbin",
)
@skipIfRocm
@skipIfXpu
@common_utils.parametrize("embed_kernel_binary", [True, False])
def test_simple_multi_arch(self, embed_kernel_binary):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU_TYPE")
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 16)
def forward(self, x, y):
return x + self.linear(y)
example_inputs = (
torch.randn(10, 16, device=self.device),
torch.randn(10, 10, device=self.device),
)
model = Model()
with config.patch(
{
"aot_inductor.embed_kernel_binary": embed_kernel_binary,
"aot_inductor.multi_arch_kernel_binary": True,
}
):
self.check_model(model, example_inputs)
if not embed_kernel_binary:
_, code = run_and_get_cpp_code(
AOTIRunnerUtil.compile, model, example_inputs
)
FileCheck().check(".fatbin").run(code)
def test_small_constant(self):
class Model(torch.nn.Module):
def __init__(self) -> None:

@@ -407,9 +407,9 @@ def get_path(
def get_hash(
content: Union[str, bytes], extra: str = "", hash_type: str = "code"
) -> str:
if hash_type == "code":
if hash_type in {"amdgcn", "code", "ptx"}:
return code_hash(content, extra)
if hash_type in ["cubin", "hsaco", "spv"]:
if hash_type in {"cubin", "hsaco", "spv"}:
return code_hash(repr(content))
raise AssertionError(f"Unknown hash type {hash_type}")
@@ -420,11 +420,13 @@ def write(
extra: str = "",
hash_type: str = "code",
specified_dir: str = "",
key: Optional[str] = None,
) -> tuple[str, str]:
# use stripped content to compute hash so we don't end up with different
# hashes just because the content begins/ends with different number of
# spaces.
key: str = get_hash(content.strip(), extra, hash_type)
if key is None:
# use stripped content to compute hash so we don't end up with different
# hashes just because the content begins/ends with different number of
# spaces.
key = get_hash(content.strip(), extra, hash_type)
basename, _subdir, path = get_path(key, extension, specified_dir)
if not os.path.exists(path):
write_atomic(path, content, make_dirs=True)
@@ -1544,28 +1546,62 @@ class CudaKernelParamCache:
cache_clear = staticmethod(cache.clear)
@classmethod
def set(cls, key: str, params: dict[str, str], cubin: str, bin_type: str) -> None:
_, path = write(
def set(
cls,
key: str,
params: dict[str, Optional[str]],
cubin: str,
bin_type: str,
asm: Optional[str] = None,
asm_type: Optional[str] = None,
) -> None:
basename = None
if config.aot_inductor.package_cpp_only:
assert config.triton.unique_kernel_names, (
"package_cpp_only requires triton kernel names to be unique"
)
assert params["mangled_name"], "Missing kernel name"
basename = params["mangled_name"]
_, bin_path = write(
cubin,
bin_type,
hash_type=bin_type,
specified_dir=split_aot_inductor_output_path(
config.aot_inductor.output_path
)[0],
key=basename,
)
if config.aot_inductor.package_cpp_only:
assert config.triton.unique_kernel_names, (
"package_cpp_only requires triton kernel names to be unique"
# Retrieve the basename again in case it is a generated hashcode
basename, _ = get_name_and_dir_from_output_file_path(bin_path)
if config.aot_inductor.multi_arch_kernel_binary:
assert bin_type == "cubin", (
"multi_arch_kernel_binary only supported in CUDA"
)
dir_name = os.path.dirname(path)
_, ext = os.path.splitext(path)
# Construct the new full path
new_path = os.path.join(dir_name, params["mangled_name"] + ext)
os.rename(path, new_path)
path = new_path
base_path, _ = os.path.splitext(bin_path)
bin_path = base_path + ".fatbin"
params[get_cpp_wrapper_cubin_path_name()] = path
asm_path: str = ""
if (
config.aot_inductor.multi_arch_kernel_binary
or config.aot_inductor.package_cpp_only
):
assert asm, "Missing kernel assembly code"
assert asm_type, "Missing kernel assembly type"
_, asm_path = write(
asm,
asm_type,
hash_type=asm_type,
specified_dir=split_aot_inductor_output_path(
config.aot_inductor.output_path
)[0],
# make sure asm file has the same basename
key=basename,
)
params[get_cpp_wrapper_cubin_path_name()] = bin_path
params["asm"] = asm_path
cls.cache[key] = params
@classmethod
@@ -2007,13 +2043,33 @@ class AotCodeCompiler:
for entry in gpu_codecache.cache.values()
if entry.output_path.endswith(".o")
]
if gpu_kernels_o:
assert not config.aot_inductor.multi_arch_kernel_binary, (
"TODO: add multi_arch_kernel_binary support for cutlass kernels"
)
cubins_o = []
if config.aot_inductor.embed_kernel_binary:
# Embed cubin files into .so using objcopy
ld, objcopy = get_ld_and_objcopy(use_relative_path)
for kernel_name, value in CudaKernelParamCache.cache.items():
cubin_file = value[get_cpp_wrapper_cubin_path_name()]
asm_files = []
ld, objcopy = get_ld_and_objcopy(use_relative_path)
for kernel_name, value in CudaKernelParamCache.cache.items():
if asm_file := value["asm"]:
asm_files.append(asm_file)
cubin_file = value[get_cpp_wrapper_cubin_path_name()]
if config.aot_inductor.multi_arch_kernel_binary:
# Compile .ptx into .fatbin
archs = OrderedSet(
[cuda_env.get_cuda_arch(), "80", "86", "89", "90"]
)
cmd = f"{_cuda_compiler()} -fatbin {asm_file} -o {cubin_file}"
for arch in archs:
cmd += f" -gencode arch=compute_{arch},code=compute_{arch}"
subprocess.run(
cmd.split(), capture_output=True, text=True, check=True
)
if config.aot_inductor.embed_kernel_binary:
# Embed cubin files into model.so using objcopy
cubins_o.append(
convert_cubin_to_obj(cubin_file, kernel_name, ld, objcopy)
)
@@ -2061,7 +2117,6 @@ class AotCodeCompiler:
# If we only want to package the cpp, then we need to save the
# weights separately into a bin, and we also need to prevent compiling the so
if use_mmap_weights:
weight_file = str(
wrapper_path_operator.with_name(
@@ -2073,11 +2128,20 @@ class AotCodeCompiler:
f_weights.write(struct.pack("q", magic_number))
generated_files.append(weight_file)
else:
# TODO: unify to always use mmap_weights
generated_files.append(consts_o)
so_builder.save_src_to_cmake(cmake_path, consts_o)
if config.aot_inductor.multi_arch_kernel_binary:
# TODO: support multi-arch when package_cpp_only
pass
else:
obj_srcs = [*gpu_kernels_o, *cubins_o]
generated_files.extend(obj_srcs)
for obj in obj_srcs:
so_builder.save_src_to_cmake(cmake_path, obj)
obj_srcs = [consts_o, *gpu_kernels_o, *cubins_o]
generated_files.extend(obj_srcs)
for obj in obj_srcs:
so_builder.save_src_to_cmake(cmake_path, obj)
so_builder.save_link_cmd_to_cmake(cmake_path)
else:
so_builder.build()

@@ -1330,6 +1330,9 @@ class aot_inductor:
# Embed generated kernel binary files into model.so
embed_kernel_binary: bool = False
# Generate kernel binary files that support multiple archs
multi_arch_kernel_binary: bool = False
# Custom ops that have implemented C shim wrappers, defined as an op to C shim declaration dict
custom_ops_to_c_shims: dict[torch._ops.OpOverload, list[str]] = {}
# custom op libs that have implemented C shim wrappers

@@ -182,11 +182,11 @@ def convert_cubin_to_obj(
obj_file = cubin_file + ".o"
# Convert .cubin to .o
cmd = f"{ld} -r -b binary -z noexecstack -o {obj_file} {cubin_file}"
subprocess.run(cmd.split(), capture_output=True, text=True)
subprocess.run(cmd.split(), capture_output=True, text=True, check=True)
os.remove(cubin_file)
# Rename .data to .rodata
cmd = f"{objcopy} --rename-section .data=.rodata,alloc,load,readonly,data,contents {obj_file}"
subprocess.run(cmd.split(), capture_output=True, text=True)
subprocess.run(cmd.split(), capture_output=True, text=True, check=True)
# By default objcopy will create *_start, *_size, *_end symbols using the full path
# Rename to use the unique kernel name
file_name = re.sub(r"[\W]", "_", cubin_file)
@@ -197,7 +197,7 @@ def convert_cubin_to_obj(
+ f"--redefine-sym _binary_{file_name}_end=__{kernel_name}_end "
+ obj_file
)
subprocess.run(cmd.split(), capture_output=True, text=True)
subprocess.run(cmd.split(), capture_output=True, text=True, check=True)
return obj_file

@@ -1002,8 +1002,11 @@ class CachingAutotuner(KernelInterface):
bin_type = {"hip": "hsaco", "xpu": "spv"}.get(self.device_props.type, "cubin")
binary = launcher.bin.asm[bin_type]
CudaKernelParamCache.set(key, params, binary, bin_type)
# Also store the asm code, which can be used for debugging and for generating the cpp package
asm_type = {"hip": "amdgcn", "cuda": "ptx"}.get(self.device_props.type, None)
asm = launcher.bin.asm.get(asm_type, None)
CudaKernelParamCache.set(key, params, binary, bin_type, asm, asm_type)
self.cuda_kernel_saved = True
def coordinate_descent_tuning(self, launcher, *args, **kwargs):