From c2bf3be0112e1c43cd8185dc563fea834ef1af8f Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sun, 9 Feb 2025 11:28:27 -0800 Subject: [PATCH] [inductor] Remove _get_grid_fn_str (#146800) Pull Request resolved: https://github.com/pytorch/pytorch/pull/146800 Approved by: https://github.com/yanboliang --- tools/build_with_debinfo.py | 2 +- torch/_inductor/codegen/triton.py | 10 +++------- torch/_inductor/codegen/triton_split_scan.py | 3 --- 3 files changed, 4 insertions(+), 11 deletions(-) diff --git a/tools/build_with_debinfo.py b/tools/build_with_debinfo.py index 0c9553b963e..73c9dba0090 100755 --- a/tools/build_with_debinfo.py +++ b/tools/build_with_debinfo.py @@ -113,7 +113,7 @@ def main() -> None: print("More than 100 items needs to be rebuild, run `ninja torch_python` first") sys.exit(-1) for idx, (name, cmd) in enumerate(build_plan): - print(f"[{idx + 1 } / {len(build_plan)}] Building {name}") + print(f"[{idx + 1} / {len(build_plan)}] Building {name}") if args.verbose: print(cmd) subprocess.check_call(["sh", "-c", cmd], cwd=BUILD_DIR) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index a8837baea39..0922297e208 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -3615,9 +3615,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): if tree.prefix == "x" and self.no_x_dim: code.writeline("XBLOCK: tl.constexpr = 1") - def _get_grid_fn_str(self): - return self._get_grid_fn().__name__ - def _get_grid_fn(self): if self.cooperative_reduction: return cooperative_reduction_grid @@ -3648,9 +3645,8 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): for ws in self.args.workspace_args: wrapper.generate_workspace_allocation(ws) - grid = wrapper.generate_default_grid( - name, grid, grid_callable=self._get_grid_fn() - ) + grid_fn = self._get_grid_fn() + grid = wrapper.generate_default_grid(name, grid, grid_callable=grid_fn) wrapper.generate_kernel_call( name, call_args, @@ -3659,7 +3655,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): gpu=current_device.type != "cpu", triton=True, arg_types=arg_types, - grid_fn=self._get_grid_fn_str(), + grid_fn=grid_fn.__name__, triton_meta=self.triton_meta, ) diff --git a/torch/_inductor/codegen/triton_split_scan.py b/torch/_inductor/codegen/triton_split_scan.py index 165a33e691f..d025a8120d1 100644 --- a/torch/_inductor/codegen/triton_split_scan.py +++ b/torch/_inductor/codegen/triton_split_scan.py @@ -203,8 +203,5 @@ class TritonSplitScanKernel(TritonKernel): def _get_heuristic(self): return "split_scan" - def _get_grid_fn_str(self): - return "split_scan_grid" - def _get_grid_fn(self): return split_scan_grid