[AOTI] Generate unique cubin file names when package_cpp_only (#153948)
Summary:
* When package_cpp_only is specified, generate kernel files with unique kernel names so that the final packaged files are more readable. Assert on unique_kernel_names in case it was somehow explicitly set to False.
* Fix a ROCm test skip, see https://github.com/pytorch/pytorch/pull/153828

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153948
Approved by: https://github.com/angelayi, https://github.com/yushangdi
This commit is contained in:
parent
8cabd23b3d
commit
2c2524f74b
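
For context, here is a minimal sketch of how the two config knobs touched by this change fit together when compiling a model with AOTInductor. Only the flags aot_inductor.package_cpp_only and triton.unique_kernel_names come from this commit; the model, the inputs, and the aoti_compile_and_package call with inductor_configs are illustrative assumptions about the surrounding API.

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return x * 2 + 1

ep = torch.export.export(M(), (torch.randn(4, device="cuda"),))
# package_cpp_only packages C++ sources plus kernel object files instead of a
# prebuilt .so; this commit additionally requires unique_kernel_names.
package_path = torch._inductor.aoti_compile_and_package(
    ep,
    inductor_configs={
        "aot_inductor.package_cpp_only": True,
        "triton.unique_kernel_names": True,
    },
)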
test/inductor/test_aot_inductor_package.py
@@ -15,16 +15,12 @@ from typing import Callable
 from parameterized import parameterized_class

 import torch
+from torch._inductor.codecache import get_kernel_bin_format
 from torch._inductor.package import AOTICompiledModel, load_package, package_aoti
 from torch._inductor.test_case import TestCase
 from torch._inductor.utils import fresh_inductor_cache
 from torch.export import Dim
-from torch.testing._internal.common_utils import (
-    IS_FBCODE,
-    skipIfRocm,
-    skipIfXpu,
-    TEST_CUDA,
-)
+from torch.testing._internal.common_utils import IS_FBCODE, skipIfXpu, TEST_CUDA
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU

@@ -183,7 +179,6 @@ class TestAOTInductorPackage(TestCase):
         self.check_model(Model(), example_inputs)

     @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode")
-    @skipIfRocm  # build system may be different
     @skipIfXpu  # build system may be different
     def test_compile_after_package(self):
         if not self.package_cpp_only:
@@ -225,8 +220,10 @@
             tmp_path = Path(tmp_dir) / "data" / "aotinductor" / "model"
             self.assertTrue(tmp_path.exists())
             if self.device == GPU_TYPE:
-                self.assertTrue(not list(tmp_path.glob("*.cubin")))
-                self.assertTrue(list(tmp_path.glob("*.cubin.o")))
+                kernel_bin = get_kernel_bin_format(self.device)
+                self.assertTrue(not list(tmp_path.glob(f"*.{kernel_bin}")))
+                # Check if .cubin.o files exist and use unique kernel names
+                self.assertTrue(list(tmp_path.glob(f"triton_*.{kernel_bin}.o")))

             build_path = tmp_path / "build"
             self.assertTrue(not build_path.exists())
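
To see what the tightened assertions accept, the following self-contained sketch fakes the packaged directory layout; the file name is invented but follows inductor's triton_* mangled-name pattern, and kernel_bin is hard-coded to the CUDA value ("hsaco" on ROCm, "spv" on XPU).

import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as d:
    tmp_path = Path(d)
    # A kernel object named after its mangled Triton kernel, as produced
    # with package_cpp_only + unique_kernel_names.
    (tmp_path / "triton_poi_fused_add_0.cubin.o").touch()

    kernel_bin = "cubin"  # stand-in for get_kernel_bin_format(self.device)
    assert not list(tmp_path.glob(f"*.{kernel_bin}"))       # no raw binaries left
    assert list(tmp_path.glob(f"triton_*.{kernel_bin}.o"))  # objects use unique names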
torch/_inductor/codecache.py
@@ -162,6 +162,15 @@ def get_cpp_wrapper_cubin_path_name() -> str:
     return "cubin_path" if torch.version.hip is None else "hsaco_path"


+def get_kernel_bin_format(device: str) -> str:
+    if device == "cuda":
+        return "cubin" if torch.version.hip is None else "hsaco"
+    elif device == "xpu":
+        return "spv"
+    else:
+        return ""
+
+
 @functools.lru_cache(None)
 def get_global_cache_path_impl(global_cache_dir: str) -> Optional[Path]:
     return (
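
As a quick illustration of the new helper's contract (the return values are taken directly from the function body above; the asserts assume a torch build that includes this commit):

from torch._inductor.codecache import get_kernel_bin_format

assert get_kernel_bin_format("cuda") in ("cubin", "hsaco")  # "hsaco" under ROCm
assert get_kernel_bin_format("xpu") == "spv"                # SPIR-V modules
assert get_kernel_bin_format("cpu") == ""                   # no kernel binary format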
@@ -1527,6 +1536,17 @@ class CudaKernelParamCache:
                 config.aot_inductor.output_path
             )[0],
         )
+        if config.aot_inductor.package_cpp_only:
+            assert config.triton.unique_kernel_names, (
+                "package_cpp_only requires triton kernel names to be unique"
+            )
+            dir_name = os.path.dirname(path)
+            _, ext = os.path.splitext(path)
+            # Construct the new full path
+            new_path = os.path.join(dir_name, params["mangled_name"] + ext)
+            os.rename(path, new_path)
+            path = new_path
+
         params[get_cpp_wrapper_cubin_path_name()] = path

         cls.cache[key] = params
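
To make the rename concrete, here is a worked example using only the standard library; the paths and the mangled name are made up but follow the same shape as a real build.

import os

path = "/tmp/model/cf3aydqiqmm.cubin"  # hash-based name written by the cache
params = {"mangled_name": "triton_poi_fused_add_0"}

dir_name = os.path.dirname(path)  # "/tmp/model"
_, ext = os.path.splitext(path)   # ".cubin"
new_path = os.path.join(dir_name, params["mangled_name"] + ext)
assert new_path == "/tmp/model/triton_poi_fused_add_0.cubin"
# os.rename(path, new_path) would then move the file on a real build.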