[inductor] Remove _get_grid_fn_str (#146800)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146800
Approved by: https://github.com/yanboliang
This commit is contained in:
Jason Ansel 2025-02-09 11:28:27 -08:00 committed by PyTorch MergeBot
parent 0d5fb0941f
commit c2bf3be011
3 changed files with 4 additions and 11 deletions

View File

@ -3615,9 +3615,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
if tree.prefix == "x" and self.no_x_dim:
code.writeline("XBLOCK: tl.constexpr = 1")
def _get_grid_fn_str(self):
return self._get_grid_fn().__name__
def _get_grid_fn(self):
if self.cooperative_reduction:
return cooperative_reduction_grid
@ -3648,9 +3645,8 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
for ws in self.args.workspace_args:
wrapper.generate_workspace_allocation(ws)
grid = wrapper.generate_default_grid(
name, grid, grid_callable=self._get_grid_fn()
)
grid_fn = self._get_grid_fn()
grid = wrapper.generate_default_grid(name, grid, grid_callable=grid_fn)
wrapper.generate_kernel_call(
name,
call_args,
@ -3659,7 +3655,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
gpu=current_device.type != "cpu",
triton=True,
arg_types=arg_types,
grid_fn=self._get_grid_fn_str(),
grid_fn=grid_fn.__name__,
triton_meta=self.triton_meta,
)

View File

@ -203,8 +203,5 @@ class TritonSplitScanKernel(TritonKernel):
def _get_heuristic(self):
return "split_scan"
def _get_grid_fn_str(self):
return "split_scan_grid"
def _get_grid_fn(self):
return split_scan_grid