[AOTI] Add a multi_arch_kernel_binary option (#154413)
Summary: CUDA supports multi-arch binaries via the fatbin format. Add a multi_arch_kernel_binary option so that the compiled model binary can run across different GPU archs.

Differential Revision: [D75452094](https://our.internmc.facebook.com/intern/diff/D75452094)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154413
Approved by: https://github.com/angelayi
ghstack dependencies: #154412
commit cde82d25b7
parent 4d8f3d537a
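For orientation, a minimal sketch of how the new option is meant to be driven from user code, assuming the config key added in this PR and the torch._export.aot_compile entry point (the exact AOTI API surface varies by release, so treat this as illustrative rather than canonical):

    import torch
    from torch._inductor import config

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + 1

    model = M().cuda()
    example_inputs = (torch.randn(8, 8, device="cuda"),)

    # With multi_arch_kernel_binary=True the Triton kernels are packaged as
    # .fatbin, so the resulting .so can run across multiple GPU archs.
    with config.patch({"aot_inductor.multi_arch_kernel_binary": True}):
        so_path = torch._export.aot_compile(model, example_inputs)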
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -156,6 +156,43 @@ class AOTInductorTestsTemplate:
             model, example_inputs, "AOTInductorModelRunMinimalArrayrefInterface(", 1
         )
 
+    @unittest.skipIf(
+        IS_FBCODE,
+        "toolchain doesn't support ptx to fatbin",
+    )
+    @skipIfRocm
+    @skipIfXpu
+    @common_utils.parametrize("embed_kernel_binary", [True, False])
+    def test_simple_multi_arch(self, embed_kernel_binary):
+        if self.device != GPU_TYPE:
+            raise unittest.SkipTest("requires GPU_TYPE")
+
+        class Model(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear = torch.nn.Linear(10, 16)
+
+            def forward(self, x, y):
+                return x + self.linear(y)
+
+        example_inputs = (
+            torch.randn(10, 16, device=self.device),
+            torch.randn(10, 10, device=self.device),
+        )
+        model = Model()
+        with config.patch(
+            {
+                "aot_inductor.embed_kernel_binary": embed_kernel_binary,
+                "aot_inductor.multi_arch_kernel_binary": True,
+            }
+        ):
+            self.check_model(model, example_inputs)
+            if not embed_kernel_binary:
+                _, code = run_and_get_cpp_code(
+                    AOTIRunnerUtil.compile, model, example_inputs
+                )
+                FileCheck().check(".fatbin").run(code)
+
     def test_small_constant(self):
         class Model(torch.nn.Module):
             def __init__(self) -> None:
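The FileCheck assertion at the end of the new test can be exercised in isolation. A toy stand-in for the generated wrapper code (in the real test that string comes from run_and_get_cpp_code):

    from torch.testing import FileCheck

    # When the kernel is not embedded into the .so, the generated cpp wrapper
    # should reference the kernel by a .fatbin path; this string is a made-up
    # stand-in for that generated code.
    generated_code = 'load_kernel("/tmp/kernels/triton_poi_fused_add_0.fatbin")'
    FileCheck().check(".fatbin").run(generated_code)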
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -407,9 +407,9 @@ def get_path(
 def get_hash(
     content: Union[str, bytes], extra: str = "", hash_type: str = "code"
 ) -> str:
-    if hash_type == "code":
+    if hash_type in {"amdgcn", "code", "ptx"}:
         return code_hash(content, extra)
-    if hash_type in ["cubin", "hsaco", "spv"]:
+    if hash_type in {"cubin", "hsaco", "spv"}:
         return code_hash(repr(content))
     raise AssertionError(f"Unknown hash type {hash_type}")
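The new branch routes the text formats (ptx, amdgcn) through the same content hash as source code, while the binary formats keep hashing repr(content), since they may not be valid text. A toy illustration, assuming the torch._inductor.codecache module path:

    from torch._inductor.codecache import get_hash

    # Text artifacts hash their content directly ...
    h_ptx = get_hash("// fake ptx for illustration", hash_type="ptx")
    # ... while binary artifacts hash repr(content).
    h_bin = get_hash(b"\x00\x01fake-cubin", hash_type="cubin")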
@@ -420,11 +420,13 @@ def write(
     extra: str = "",
     hash_type: str = "code",
     specified_dir: str = "",
+    key: Optional[str] = None,
 ) -> tuple[str, str]:
-    # use striped content to compute hash so we don't end up with different
-    # hashes just because the content begins/ends with different number of
-    # spaces.
-    key: str = get_hash(content.strip(), extra, hash_type)
+    if key is None:
+        # use striped content to compute hash so we don't end up with different
+        # hashes just because the content begins/ends with different number of
+        # spaces.
+        key = get_hash(content.strip(), extra, hash_type)
     basename, _subdir, path = get_path(key, extension, specified_dir)
     if not os.path.exists(path):
         write_atomic(path, content, make_dirs=True)
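The new key parameter lets a caller pin the artifact's basename instead of deriving it from the content hash; CudaKernelParamCache uses this below so that a kernel's binary and its assembly land next to each other under the same basename. A toy illustration, assuming the same module path:

    from torch._inductor.codecache import write

    ptx_text = "// fake ptx for illustration\n"
    # Default behavior: the basename is a hash of the (stripped) content.
    _, hashed_path = write(ptx_text, "ptx", hash_type="ptx")
    # New behavior: an explicit key pins the basename.
    _, pinned_path = write(ptx_text, "ptx", hash_type="ptx", key="my_kernel")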
@@ -1544,28 +1546,62 @@ class CudaKernelParamCache:
     cache_clear = staticmethod(cache.clear)
 
     @classmethod
-    def set(cls, key: str, params: dict[str, str], cubin: str, bin_type: str) -> None:
-        _, path = write(
+    def set(
+        cls,
+        key: str,
+        params: dict[str, Optional[str]],
+        cubin: str,
+        bin_type: str,
+        asm: Optional[str] = None,
+        asm_type: Optional[str] = None,
+    ) -> None:
+        basename = None
+        if config.aot_inductor.package_cpp_only:
+            assert config.triton.unique_kernel_names, (
+                "package_cpp_only requires triton kernel names to be unique"
+            )
+            assert params["mangled_name"], "Missing kernel name"
+            basename = params["mangled_name"]
+
+        _, bin_path = write(
             cubin,
             bin_type,
             hash_type=bin_type,
             specified_dir=split_aot_inductor_output_path(
                 config.aot_inductor.output_path
             )[0],
+            key=basename,
         )
-        if config.aot_inductor.package_cpp_only:
-            assert config.triton.unique_kernel_names, (
-                "package_cpp_only requires triton kernel names to be unique"
-            )
-            dir_name = os.path.dirname(path)
-            _, ext = os.path.splitext(path)
-            # Construct the new full path
-            new_path = os.path.join(dir_name, params["mangled_name"] + ext)
-            os.rename(path, new_path)
-            path = new_path
-
-        params[get_cpp_wrapper_cubin_path_name()] = path
+        # Retrieve the basename again in case it is a generated hashcode
+        basename, _ = get_name_and_dir_from_output_file_path(bin_path)
+
+        if config.aot_inductor.multi_arch_kernel_binary:
+            assert bin_type == "cubin", (
+                "multi_arch_kernel_binary only supported in CUDA"
+            )
+            base_path, _ = os.path.splitext(bin_path)
+            bin_path = base_path + ".fatbin"
+
+        asm_path: str = ""
+        if (
+            config.aot_inductor.multi_arch_kernel_binary
+            or config.aot_inductor.package_cpp_only
+        ):
+            assert asm, "Missing kernel assembly code"
+            assert asm_type, "Missing kernel assembly type"
+            _, asm_path = write(
+                asm,
+                asm_type,
+                hash_type=asm_type,
+                specified_dir=split_aot_inductor_output_path(
+                    config.aot_inductor.output_path
+                )[0],
+                # make sure asm file has the same basename
+                key=basename,
+            )
+
+        params[get_cpp_wrapper_cubin_path_name()] = bin_path
+        params["asm"] = asm_path
         cls.cache[key] = params
 
     @classmethod
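After this change a cache entry records both the binary path and the assembly path. A sketch of the resulting params dict for one CUDA kernel; the "mangled_name" and "asm" keys come from this hunk, but "cubin_path" is a guess at what get_cpp_wrapper_cubin_path_name() resolves to, and all paths are made up:

    entry = {
        "mangled_name": "triton_poi_fused_add_0",
        # .fatbin instead of .cubin when multi_arch_kernel_binary is set
        "cubin_path": "/tmp/output/cabc123.fatbin",
        # "" unless multi_arch_kernel_binary or package_cpp_only is set
        "asm": "/tmp/output/cabc123.ptx",
    }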
@@ -2007,13 +2043,33 @@ class AotCodeCompiler:
             for entry in gpu_codecache.cache.values()
             if entry.output_path.endswith(".o")
         ]
+        if gpu_kernels_o:
+            assert not config.aot_inductor.multi_arch_kernel_binary, (
+                "TODO: add multi_arch_kernel_binary support for cutlass kernels"
+            )
+
         cubins_o = []
-        if config.aot_inductor.embed_kernel_binary:
-            # Embed cubin files into .so using objcopy
-            ld, objcopy = get_ld_and_objcopy(use_relative_path)
-            for kernel_name, value in CudaKernelParamCache.cache.items():
-                cubin_file = value[get_cpp_wrapper_cubin_path_name()]
+        asm_files = []
+        ld, objcopy = get_ld_and_objcopy(use_relative_path)
+        for kernel_name, value in CudaKernelParamCache.cache.items():
+            if asm_file := value["asm"]:
+                asm_files.append(asm_file)
+
+            cubin_file = value[get_cpp_wrapper_cubin_path_name()]
+            if config.aot_inductor.multi_arch_kernel_binary:
+                # Compile .ptx into .fatbin
+                archs = OrderedSet(
+                    [cuda_env.get_cuda_arch(), "80", "86", "89", "90"]
+                )
+                cmd = f"{_cuda_compiler()} -fatbin {asm_file} -o {cubin_file}"
+                for arch in archs:
+                    cmd += f" -gencode arch=compute_{arch},code=compute_{arch}"
+                subprocess.run(
+                    cmd.split(), capture_output=True, text=True, check=True
+                )
+
+            if config.aot_inductor.embed_kernel_binary:
+                # Embed cubin files into model.so using objcopy
                 cubins_o.append(
                     convert_cubin_to_obj(cubin_file, kernel_name, ld, objcopy)
                 )
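For each kernel the loop above assembles a single nvcc invocation. An illustration with made-up paths (the arch list per this hunk is the current device arch plus 80/86/89/90, deduplicated by OrderedSet); note that -gencode with compute_NN for both arch and code embeds PTX rather than SASS, which is what lets the fatbin JIT-compile on newer archs beyond the listed ones:

    archs = ["80", "86", "89", "90"]
    cmd = "nvcc -fatbin /tmp/kernel.ptx -o /tmp/kernel.fatbin"
    for arch in archs:
        cmd += f" -gencode arch=compute_{arch},code=compute_{arch}"
    print(cmd)
    # nvcc -fatbin /tmp/kernel.ptx -o /tmp/kernel.fatbin
    #   -gencode arch=compute_80,code=compute_80 ... -gencode arch=compute_90,code=compute_90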
@@ -2061,7 +2117,6 @@ class AotCodeCompiler:
 
         # If we only want to package the cpp, then we need to save the
         # weights separately into a bin, and we also need to prevent compiling the so
-
         if use_mmap_weights:
             weight_file = str(
                 wrapper_path_operator.with_name(
@@ -2073,11 +2128,20 @@ class AotCodeCompiler:
                     f_weights.write(struct.pack("q", magic_number))
 
                 generated_files.append(weight_file)
+            else:
+                # TODO: unify to alway use mmap_weights
+                generated_files.append(consts_o)
+                so_builder.save_src_to_cmake(cmake_path, consts_o)
+
+            if config.aot_inductor.multi_arch_kernel_binary:
+                # TODO: support multi-arch when package_cpp_only
+                pass
+            else:
+                obj_srcs = [*gpu_kernels_o, *cubins_o]
+                generated_files.extend(obj_srcs)
+                for obj in obj_srcs:
+                    so_builder.save_src_to_cmake(cmake_path, obj)
-            obj_srcs = [consts_o, *gpu_kernels_o, *cubins_o]
-            generated_files.extend(obj_srcs)
-            for obj in obj_srcs:
-                so_builder.save_src_to_cmake(cmake_path, obj)
 
             so_builder.save_link_cmd_to_cmake(cmake_path)
         else:
             so_builder.build()
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -1330,6 +1330,9 @@ class aot_inductor:
     # Embed generated kernel binary files into model.so
    embed_kernel_binary: bool = False
 
+    # Generate kernel binary files that support multiple archs
+    multi_arch_kernel_binary: bool = False
+
     # Custom ops that have implemented C shim wrappers, defined as an op to C shim declaration dict
     custom_ops_to_c_shims: dict[torch._ops.OpOverload, list[str]] = {}
     # custom op libs that have implemented C shim wrappers
--- a/torch/_inductor/cpp_builder.py
+++ b/torch/_inductor/cpp_builder.py
@@ -182,11 +182,11 @@ def convert_cubin_to_obj(
     obj_file = cubin_file + ".o"
     # Convert .cubin to .o
     cmd = f"{ld} -r -b binary -z noexecstack -o {obj_file} {cubin_file}"
-    subprocess.run(cmd.split(), capture_output=True, text=True)
+    subprocess.run(cmd.split(), capture_output=True, text=True, check=True)
     os.remove(cubin_file)
     # Rename .data to .rodata
     cmd = f"{objcopy} --rename-section .data=.rodata,alloc,load,readonly,data,contents {obj_file}"
-    subprocess.run(cmd.split(), capture_output=True, text=True)
+    subprocess.run(cmd.split(), capture_output=True, text=True, check=True)
     # By default objcopy will create *_start, *_size, *_end symbols using the full path
     # Rename to use the unique kernel name
     file_name = re.sub(r"[\W]", "_", cubin_file)
@@ -197,7 +197,7 @@ def convert_cubin_to_obj(
         + f"--redefine-sym _binary_{file_name}_end=__{kernel_name}_end "
         + obj_file
     )
-    subprocess.run(cmd.split(), capture_output=True, text=True)
+    subprocess.run(cmd.split(), capture_output=True, text=True, check=True)
     return obj_file
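The check=True additions make previously silent ld/objcopy failures loud: with capture_output and no check, a failing command just returns a CompletedProcess and the broken .o only surfaces much later at link time. A self-contained illustration of the difference:

    import subprocess

    try:
        subprocess.run(
            ["objcopy", "--rename-section", ".data=.rodata", "no_such_file.o"],
            capture_output=True, text=True, check=True,
        )
    except subprocess.CalledProcessError as e:
        # Fail fast with the tool's stderr instead of producing a bad object.
        print(e.returncode, e.stderr)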
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -1002,8 +1002,11 @@ class CachingAutotuner(KernelInterface):
 
         bin_type = {"hip": "hsaco", "xpu": "spv"}.get(self.device_props.type, "cubin")
         binary = launcher.bin.asm[bin_type]
-        CudaKernelParamCache.set(key, params, binary, bin_type)
+        # Also store asm code which can be used for debugging and generating cpp package
+        asm_type = {"hip": "amdgcn", "cuda": "ptx"}.get(self.device_props.type, None)
+        asm = launcher.bin.asm.get(asm_type, None)
+
+        CudaKernelParamCache.set(key, params, binary, bin_type, asm, asm_type)
         self.cuda_kernel_saved = True
 
     def coordinate_descent_tuning(self, launcher, *args, **kwargs):