diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index 1900c774399..0d01bf09166 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -156,6 +156,43 @@ class AOTInductorTestsTemplate:
             model, example_inputs, "AOTInductorModelRunMinimalArrayrefInterface(", 1
         )
 
+    @unittest.skipIf(
+        IS_FBCODE,
+        "toolchain doesn't support ptx to fatbin",
+    )
+    @skipIfRocm
+    @skipIfXpu
+    @common_utils.parametrize("embed_kernel_binary", [True, False])
+    def test_simple_multi_arch(self, embed_kernel_binary):
+        if self.device != GPU_TYPE:
+            raise unittest.SkipTest("requires GPU_TYPE")
+
+        class Model(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear = torch.nn.Linear(10, 16)
+
+            def forward(self, x, y):
+                return x + self.linear(y)
+
+        example_inputs = (
+            torch.randn(10, 16, device=self.device),
+            torch.randn(10, 10, device=self.device),
+        )
+        model = Model()
+        with config.patch(
+            {
+                "aot_inductor.embed_kernel_binary": embed_kernel_binary,
+                "aot_inductor.multi_arch_kernel_binary": True,
+            }
+        ):
+            self.check_model(model, example_inputs)
+            if not embed_kernel_binary:
+                _, code = run_and_get_cpp_code(
+                    AOTIRunnerUtil.compile, model, example_inputs
+                )
+                FileCheck().check(".fatbin").run(code)
+
     def test_small_constant(self):
         class Model(torch.nn.Module):
             def __init__(self) -> None:
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 2f9fb14c410..32fedafbcb2 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -407,9 +407,9 @@ def get_path(
 def get_hash(
     content: Union[str, bytes], extra: str = "", hash_type: str = "code"
 ) -> str:
-    if hash_type == "code":
+    if hash_type in {"amdgcn", "code", "ptx"}:
         return code_hash(content, extra)
-    if hash_type in ["cubin", "hsaco", "spv"]:
+    if hash_type in {"cubin", "hsaco", "spv"}:
         return code_hash(repr(content))
     raise AssertionError(f"Unknown hash type {hash_type}")
 
@@ -420,11 +420,13 @@ def write(
     extra: str = "",
     hash_type: str = "code",
     specified_dir: str = "",
+    key: Optional[str] = None,
 ) -> tuple[str, str]:
-    # use striped content to compute hash so we don't end up with different
-    # hashes just because the content begins/ends with different number of
-    # spaces.
-    key: str = get_hash(content.strip(), extra, hash_type)
+    if key is None:
+        # use striped content to compute hash so we don't end up with different
+        # hashes just because the content begins/ends with different number of
+        # spaces.
+        key = get_hash(content.strip(), extra, hash_type)
     basename, _subdir, path = get_path(key, extension, specified_dir)
     if not os.path.exists(path):
         write_atomic(path, content, make_dirs=True)
@@ -1544,28 +1546,62 @@ class CudaKernelParamCache:
     cache_clear = staticmethod(cache.clear)
 
     @classmethod
-    def set(cls, key: str, params: dict[str, str], cubin: str, bin_type: str) -> None:
-        _, path = write(
+    def set(
+        cls,
+        key: str,
+        params: dict[str, Optional[str]],
+        cubin: str,
+        bin_type: str,
+        asm: Optional[str] = None,
+        asm_type: Optional[str] = None,
+    ) -> None:
+        basename = None
+        if config.aot_inductor.package_cpp_only:
+            assert config.triton.unique_kernel_names, (
+                "package_cpp_only requires triton kernel names to be unique"
+            )
+            assert params["mangled_name"], "Missing kernel name"
+            basename = params["mangled_name"]
+
+        _, bin_path = write(
             cubin,
             bin_type,
             hash_type=bin_type,
             specified_dir=split_aot_inductor_output_path(
                 config.aot_inductor.output_path
             )[0],
+            key=basename,
         )
 
-        if config.aot_inductor.package_cpp_only:
-            assert config.triton.unique_kernel_names, (
-                "package_cpp_only requires triton kernel names to be unique"
+        # Retrieve the basename again in case it is a generated hashcode
+        basename, _ = get_name_and_dir_from_output_file_path(bin_path)
+
+        if config.aot_inductor.multi_arch_kernel_binary:
+            assert bin_type == "cubin", (
+                "multi_arch_kernel_binary only supported in CUDA"
             )
-            dir_name = os.path.dirname(path)
-            _, ext = os.path.splitext(path)
-            # Construct the new full path
-            new_path = os.path.join(dir_name, params["mangled_name"] + ext)
-            os.rename(path, new_path)
-            path = new_path
+            base_path, _ = os.path.splitext(bin_path)
+            bin_path = base_path + ".fatbin"
 
-        params[get_cpp_wrapper_cubin_path_name()] = path
+        asm_path: str = ""
+        if (
+            config.aot_inductor.multi_arch_kernel_binary
+            or config.aot_inductor.package_cpp_only
+        ):
+            assert asm, "Missing kernel assembly code"
+            assert asm_type, "Missing kernel assembly type"
+            _, asm_path = write(
+                asm,
+                asm_type,
+                hash_type=asm_type,
+                specified_dir=split_aot_inductor_output_path(
+                    config.aot_inductor.output_path
+                )[0],
+                # make sure asm file has the same basename
+                key=basename,
+            )
+        params[get_cpp_wrapper_cubin_path_name()] = bin_path
+        params["asm"] = asm_path
         cls.cache[key] = params
 
     @classmethod
@@ -2007,13 +2043,33 @@ class AotCodeCompiler:
                 for entry in gpu_codecache.cache.values()
                 if entry.output_path.endswith(".o")
             ]
+            if gpu_kernels_o:
+                assert not config.aot_inductor.multi_arch_kernel_binary, (
+                    "TODO: add multi_arch_kernel_binary support for cutlass kernels"
+                )
 
             cubins_o = []
-            if config.aot_inductor.embed_kernel_binary:
-                # Embed cubin files into .so using objcopy
-                ld, objcopy = get_ld_and_objcopy(use_relative_path)
-                for kernel_name, value in CudaKernelParamCache.cache.items():
-                    cubin_file = value[get_cpp_wrapper_cubin_path_name()]
+            asm_files = []
+            ld, objcopy = get_ld_and_objcopy(use_relative_path)
+            for kernel_name, value in CudaKernelParamCache.cache.items():
+                if asm_file := value["asm"]:
+                    asm_files.append(asm_file)
+
+                cubin_file = value[get_cpp_wrapper_cubin_path_name()]
+                if config.aot_inductor.multi_arch_kernel_binary:
+                    # Compile .ptx into .fatbin
+                    archs = OrderedSet(
+                        [cuda_env.get_cuda_arch(), "80", "86", "89", "90"]
+                    )
+                    cmd = f"{_cuda_compiler()} -fatbin {asm_file} -o {cubin_file}"
+                    for arch in archs:
+                        cmd += f" -gencode arch=compute_{arch},code=compute_{arch}"
+                    subprocess.run(
+                        cmd.split(), capture_output=True, text=True, check=True
+                    )
+
+                if config.aot_inductor.embed_kernel_binary:
+                    # Embed cubin files into model.so using objcopy
                     cubins_o.append(
                         convert_cubin_to_obj(cubin_file, kernel_name, ld, objcopy)
                     )
@@ -2061,7 +2117,6 @@ class AotCodeCompiler:
 
             # If we only want to package the cpp, then we need to save the
             # weights separately into a bin, and we also need to prevent compiling the so
-
             if use_mmap_weights:
                 weight_file = str(
                     wrapper_path_operator.with_name(
@@ -2073,11 +2128,20 @@ class AotCodeCompiler:
                     f_weights.write(struct.pack("q", magic_number))
 
                 generated_files.append(weight_file)
+            else:
+                # TODO: unify to alway use mmap_weights
+                generated_files.append(consts_o)
+                so_builder.save_src_to_cmake(cmake_path, consts_o)
+
+            if config.aot_inductor.multi_arch_kernel_binary:
+                # TODO: support multi-arch when package_cpp_only
+                pass
+            else:
+                obj_srcs = [*gpu_kernels_o, *cubins_o]
+                generated_files.extend(obj_srcs)
+                for obj in obj_srcs:
+                    so_builder.save_src_to_cmake(cmake_path, obj)
 
-            obj_srcs = [consts_o, *gpu_kernels_o, *cubins_o]
-            generated_files.extend(obj_srcs)
-            for obj in obj_srcs:
-                so_builder.save_src_to_cmake(cmake_path, obj)
             so_builder.save_link_cmd_to_cmake(cmake_path)
         else:
             so_builder.build()
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index 547a0d6568e..f4320b2a63b 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -1330,6 +1330,9 @@ class aot_inductor:
     # Embed generated kernel binary files into model.so
     embed_kernel_binary: bool = False
 
+    # Generate kernel binary files that support multiple archs
+    multi_arch_kernel_binary: bool = False
+
     # Custom ops that have implemented C shim wrappers, defined as an op to C shim declaration dict
     custom_ops_to_c_shims: dict[torch._ops.OpOverload, list[str]] = {}
     # custom op libs that have implemented C shim wrappers
diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py
index 64bc8580564..10df62f8e9f 100644
--- a/torch/_inductor/cpp_builder.py
+++ b/torch/_inductor/cpp_builder.py
@@ -182,11 +182,11 @@ def convert_cubin_to_obj(
     obj_file = cubin_file + ".o"
     # Convert .cubin to .o
     cmd = f"{ld} -r -b binary -z noexecstack -o {obj_file} {cubin_file}"
-    subprocess.run(cmd.split(), capture_output=True, text=True)
+    subprocess.run(cmd.split(), capture_output=True, text=True, check=True)
     os.remove(cubin_file)
     # Rename .data to .rodata
     cmd = f"{objcopy} --rename-section .data=.rodata,alloc,load,readonly,data,contents {obj_file}"
-    subprocess.run(cmd.split(), capture_output=True, text=True)
+    subprocess.run(cmd.split(), capture_output=True, text=True, check=True)
     # By default objcopy will create *_start, *_size, *_end symbols using the full path
     # Rename to use the unique kernel name
     file_name = re.sub(r"[\W]", "_", cubin_file)
@@ -197,7 +197,7 @@ def convert_cubin_to_obj(
         + f"--redefine-sym _binary_{file_name}_end=__{kernel_name}_end "
         + obj_file
     )
-    subprocess.run(cmd.split(), capture_output=True, text=True)
+    subprocess.run(cmd.split(), capture_output=True, text=True, check=True)
     return obj_file
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index ba94ac18731..172a8620354 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -1002,8 +1002,11 @@ class CachingAutotuner(KernelInterface):
         bin_type = {"hip": "hsaco", "xpu": "spv"}.get(self.device_props.type, "cubin")
         binary = launcher.bin.asm[bin_type]
-        CudaKernelParamCache.set(key, params, binary, bin_type)
+        # Also store asm code which can be used for debugging and generating cpp package
+        asm_type = {"hip": "amdgcn", "cuda": "ptx"}.get(self.device_props.type, None)
+        asm = launcher.bin.asm.get(asm_type, None)
+        CudaKernelParamCache.set(key, params, binary, bin_type, asm, asm_type)
 
         self.cuda_kernel_saved = True
 
     def coordinate_descent_tuning(self, launcher, *args, **kwargs):
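
Usage note (not part of the patch): a minimal sketch of how the two aot_inductor flags touched here can be combined, mirroring test_simple_multi_arch above. The torch._export.aot_compile entry point and the "cuda" device are illustrative assumptions; the test itself drives compilation through check_model/AOTIRunnerUtil.compile instead.

# Sketch only, assuming a CUDA build of PyTorch.
import torch
import torch._export
from torch._inductor import config


class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(10, 16)

    def forward(self, x, y):
        return x + self.linear(y)


example_inputs = (
    torch.randn(10, 16, device="cuda"),
    torch.randn(10, 10, device="cuda"),
)

with config.patch(
    {
        # Keep kernel binaries as standalone files rather than embedding them in model.so
        "aot_inductor.embed_kernel_binary": False,
        # Compile the saved PTX into a multi-arch .fatbin per kernel
        "aot_inductor.multi_arch_kernel_binary": True,
    }
):
    so_path = torch._export.aot_compile(Model().to("cuda"), example_inputs)

With embed_kernel_binary left False, the generated wrapper should reference the per-kernel .fatbin files on disk, which is what the FileCheck().check(".fatbin") assertion in the new test verifies.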