[inductor] Add kernel_hash_key to ChoiceCaller (#154470)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154470
Approved by: https://github.com/mlazos
This commit is contained in:
henrylhtsang 2025-05-27 17:48:11 -07:00 committed by PyTorch MergeBot
parent bd10ea4e6c
commit 7a79de1c0f
3 changed files with 23 additions and 2 deletions

View File

@@ -605,7 +605,10 @@ class CUDATemplateCaller(ChoiceCaller):
def call_name(self) -> str:
    """Fully-qualified name used to reference this kernel in generated code."""
    namespace = "cuda_template_kernels"
    return f"{namespace}.{self.name}"
def hash_key(self) -> str:
def kernel_hash_key(self) -> str:
"""
Return kernel hash key that does not depend on swizzle.
"""
return "-".join(
[
self.category,
@@ -613,6 +616,17 @@ class CUDATemplateCaller(ChoiceCaller):
]
)
def hash_key(self) -> str:
    """
    Return the hash key for this choice, including the swizzle.

    Unlike ``kernel_hash_key``, this key DOES depend on the swizzle runtime
    parameter: it is the kernel hash key plus the swizzle value from
    ``info_dict``.  (The original docstring claimed the opposite — it was
    copy-pasted from ``kernel_hash_key``.)
    """
    return "-".join(
        [
            self.kernel_hash_key(),
            # str() so a missing "swizzle" entry (None) still joins cleanly.
            str(self.info_dict().get("swizzle")),
        ]
    )
def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]:
"""Information returned here is logged to the autotune log file when that is enabled."""
if self.info_kwargs is not None and "op" in self.info_kwargs:

View File

@@ -4709,6 +4709,13 @@ class ChoiceCaller:
# Abstract hook: subclasses return a callable that executes this choice.
def to_callable(self): # type: ignore[no-untyped-def]
    raise NotImplementedError
def kernel_hash_key(self) -> str:
    """
    Key identifying the underlying kernel, independent of runtime params.

    The base implementation assumes a choice carries no runtime parameters,
    so the kernel key is simply the choice's own ``hash_key``; subclasses
    with runtime knobs (e.g. swizzle) are expected to override this.
    """
    key = self.hash_key()
    return key
def hash_key(self) -> str:
    # Abstract: must return a string uniquely identifying this choice.
    raise NotImplementedError

View File

@@ -2438,7 +2438,7 @@ class AlgorithmSelectorCache(PersistentCache):
log.debug("Skipping already seen choice: %s", c)
continue
else:
seen_choices.add(c.hash_key())
seen_choices.add(c.kernel_hash_key())
if hasattr(c, "precompile"):
triton_cuda_choice = isinstance(c, TritonTemplateCaller) and isinstance(