[inductor] fix TritonTemplateCaller.__str__ (#97578)

We remove TritonTemplateCaller.to_callable previously. But this method is still used in `TritonTemplateCaller.__str__` . The to_callable method in the base class will be used and raise an exception.

This PR fix TritonTemplateCaller.__str__ to return the string representation without calling to_callable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97578
Approved by: https://github.com/nmacchioni, https://github.com/ngimel
This commit is contained in:
Shunting Zhang 2023-03-28 23:33:52 +00:00 committed by PyTorch MergeBot
parent c905251f9f
commit c681c52e01
2 changed files with 25 additions and 3 deletions

View File

@ -9,6 +9,8 @@ import torch._inductor.select_algorithm as select_algorithm
import torch.nn.functional as F
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import counters
from torch._inductor.autotune_process import BenchmarkRequest
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA
@ -273,6 +275,28 @@ class TestSelectAlgorithm(TestCase):
# Autotuning checks correctness of each version
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
def test_TritonTemplateCaller_str(self):
"""
Make sure str(TritonTemplateCaller) does not raise exceptions.
"""
module_path = "abc.py"
bmreq = BenchmarkRequest(
module_path=module_path,
module_cache_key=None,
kernel_name=None,
grid=None,
extra_args=None,
num_stages=None,
num_warps=None,
input_tensors=None,
output_tensor=None,
)
caller = select_algorithm.TritonTemplateCaller(
None, None, None, None, "extra", bmreq
)
caller_str = str(caller)
self.assertEqual(caller_str, f"TritonTemplateCaller({module_path}, extra)")
if __name__ == "__main__":
from torch._inductor.utils import is_big_gpu

View File

@ -552,9 +552,7 @@ class TritonTemplateCaller(ChoiceCaller):
return self.bmreq.benchmark(*args, output_tensor=out)
def __str__(self):
return (
f"TritonTemplateCaller({self.to_callable().__file__}, {self.debug_extra})"
)
return f"TritonTemplateCaller({self.bmreq.module_path}, {self.debug_extra})"
def call_name(self):
return f"template_kernels.{self.name}"