mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[inductor] fix TritonTemplateCaller.__str__ (#97578)
We remove TritonTemplateCaller.to_callable previously. But this method is still used in `TritonTemplateCaller.__str__` . The to_callable method in the base class will be used and raise an exception. This PR fix TritonTemplateCaller.__str__ to return the string representation without calling to_callable. Pull Request resolved: https://github.com/pytorch/pytorch/pull/97578 Approved by: https://github.com/nmacchioni, https://github.com/ngimel
This commit is contained in:
parent
c905251f9f
commit
c681c52e01
|
|
@ -9,6 +9,8 @@ import torch._inductor.select_algorithm as select_algorithm
|
|||
import torch.nn.functional as F
|
||||
from torch._dynamo.test_case import run_tests, TestCase
|
||||
from torch._dynamo.utils import counters
|
||||
from torch._inductor.autotune_process import BenchmarkRequest
|
||||
|
||||
from torch.testing._internal.common_utils import IS_LINUX
|
||||
from torch.testing._internal.inductor_utils import HAS_CUDA
|
||||
|
||||
|
|
@ -273,6 +275,28 @@ class TestSelectAlgorithm(TestCase):
|
|||
# Autotuning checks correctness of each version
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
|
||||
def test_TritonTemplateCaller_str(self):
|
||||
"""
|
||||
Make sure str(TritonTemplateCaller) does not raise exceptions.
|
||||
"""
|
||||
module_path = "abc.py"
|
||||
bmreq = BenchmarkRequest(
|
||||
module_path=module_path,
|
||||
module_cache_key=None,
|
||||
kernel_name=None,
|
||||
grid=None,
|
||||
extra_args=None,
|
||||
num_stages=None,
|
||||
num_warps=None,
|
||||
input_tensors=None,
|
||||
output_tensor=None,
|
||||
)
|
||||
caller = select_algorithm.TritonTemplateCaller(
|
||||
None, None, None, None, "extra", bmreq
|
||||
)
|
||||
caller_str = str(caller)
|
||||
self.assertEqual(caller_str, f"TritonTemplateCaller({module_path}, extra)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._inductor.utils import is_big_gpu
|
||||
|
|
|
|||
|
|
@ -552,9 +552,7 @@ class TritonTemplateCaller(ChoiceCaller):
|
|||
return self.bmreq.benchmark(*args, output_tensor=out)
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"TritonTemplateCaller({self.to_callable().__file__}, {self.debug_extra})"
|
||||
)
|
||||
return f"TritonTemplateCaller({self.bmreq.module_path}, {self.debug_extra})"
|
||||
|
||||
def call_name(self):
|
||||
return f"template_kernels.{self.name}"
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user