diff --git a/test/inductor/test_select_algorithm.py b/test/inductor/test_select_algorithm.py index 75bcd534f65..b564af83043 100644 --- a/test/inductor/test_select_algorithm.py +++ b/test/inductor/test_select_algorithm.py @@ -9,6 +9,8 @@ import torch._inductor.select_algorithm as select_algorithm import torch.nn.functional as F from torch._dynamo.test_case import run_tests, TestCase from torch._dynamo.utils import counters +from torch._inductor.autotune_process import BenchmarkRequest + from torch.testing._internal.common_utils import IS_LINUX from torch.testing._internal.inductor_utils import HAS_CUDA @@ -273,6 +275,28 @@ class TestSelectAlgorithm(TestCase): # Autotuning checks correctness of each version self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + def test_TritonTemplateCaller_str(self): + """ + Make sure str(TritonTemplateCaller) does not raise exceptions. + """ + module_path = "abc.py" + bmreq = BenchmarkRequest( + module_path=module_path, + module_cache_key=None, + kernel_name=None, + grid=None, + extra_args=None, + num_stages=None, + num_warps=None, + input_tensors=None, + output_tensor=None, + ) + caller = select_algorithm.TritonTemplateCaller( + None, None, None, None, "extra", bmreq + ) + caller_str = str(caller) + self.assertEqual(caller_str, f"TritonTemplateCaller({module_path}, extra)") + if __name__ == "__main__": from torch._inductor.utils import is_big_gpu diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 0a7d7f29842..4d2040daf1b 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -552,9 +552,7 @@ class TritonTemplateCaller(ChoiceCaller): return self.bmreq.benchmark(*args, output_tensor=out) def __str__(self): - return ( - f"TritonTemplateCaller({self.to_callable().__file__}, {self.debug_extra})" - ) + return f"TritonTemplateCaller({self.bmreq.module_path}, {self.debug_extra})" def call_name(self): return f"template_kernels.{self.name}"