mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[inductor] Add kernel_hash_key to ChoiceCaller (#154470)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154470 Approved by: https://github.com/mlazos
This commit is contained in:
parent
bd10ea4e6c
commit
7a79de1c0f
|
|
@ -605,7 +605,10 @@ class CUDATemplateCaller(ChoiceCaller):
|
|||
def call_name(self) -> str:
|
||||
return f"cuda_template_kernels.{self.name}"
|
||||
|
||||
def hash_key(self) -> str:
|
||||
def kernel_hash_key(self) -> str:
|
||||
"""
|
||||
Return kernel hash key that does not depend on swizzle.
|
||||
"""
|
||||
return "-".join(
|
||||
[
|
||||
self.category,
|
||||
|
|
@ -613,6 +616,17 @@ class CUDATemplateCaller(ChoiceCaller):
|
|||
]
|
||||
)
|
||||
|
||||
def hash_key(self) -> str:
|
||||
"""
|
||||
Return kernel hash key that does not depend on swizzle.
|
||||
"""
|
||||
return "-".join(
|
||||
[
|
||||
self.kernel_hash_key(),
|
||||
str(self.info_dict().get("swizzle")),
|
||||
]
|
||||
)
|
||||
|
||||
def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]:
|
||||
"""Information returned here is logged to the autotune log file when that is enabled."""
|
||||
if self.info_kwargs is not None and "op" in self.info_kwargs:
|
||||
|
|
|
|||
|
|
@ -4709,6 +4709,13 @@ class ChoiceCaller:
|
|||
def to_callable(self): # type: ignore[no-untyped-def]
|
||||
raise NotImplementedError
|
||||
|
||||
def kernel_hash_key(self) -> str:
|
||||
"""
|
||||
Hash key for the underlying kernel. By default, we assume there are no
|
||||
runtime params, so kernel hash key defaults to choice caller's hash key.
|
||||
"""
|
||||
return self.hash_key()
|
||||
|
||||
def hash_key(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
|||
|
|
@ -2438,7 +2438,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
|||
log.debug("Skipping already seen choice: %s", c)
|
||||
continue
|
||||
else:
|
||||
seen_choices.add(c.hash_key())
|
||||
seen_choices.add(c.kernel_hash_key())
|
||||
|
||||
if hasattr(c, "precompile"):
|
||||
triton_cuda_choice = isinstance(c, TritonTemplateCaller) and isinstance(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user