mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[inductor] Remove _get_grid_fn_str (#146800)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146800 Approved by: https://github.com/yanboliang
This commit is contained in:
parent
0d5fb0941f
commit
c2bf3be011
|
|
@ -113,7 +113,7 @@ def main() -> None:
|
|||
print("More than 100 items needs to be rebuild, run `ninja torch_python` first")
|
||||
sys.exit(-1)
|
||||
for idx, (name, cmd) in enumerate(build_plan):
|
||||
print(f"[{idx + 1 } / {len(build_plan)}] Building {name}")
|
||||
print(f"[{idx + 1} / {len(build_plan)}] Building {name}")
|
||||
if args.verbose:
|
||||
print(cmd)
|
||||
subprocess.check_call(["sh", "-c", cmd], cwd=BUILD_DIR)
|
||||
|
|
|
|||
|
|
@ -3615,9 +3615,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
|||
if tree.prefix == "x" and self.no_x_dim:
|
||||
code.writeline("XBLOCK: tl.constexpr = 1")
|
||||
|
||||
def _get_grid_fn_str(self):
|
||||
return self._get_grid_fn().__name__
|
||||
|
||||
def _get_grid_fn(self):
|
||||
if self.cooperative_reduction:
|
||||
return cooperative_reduction_grid
|
||||
|
|
@ -3648,9 +3645,8 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
|||
for ws in self.args.workspace_args:
|
||||
wrapper.generate_workspace_allocation(ws)
|
||||
|
||||
grid = wrapper.generate_default_grid(
|
||||
name, grid, grid_callable=self._get_grid_fn()
|
||||
)
|
||||
grid_fn = self._get_grid_fn()
|
||||
grid = wrapper.generate_default_grid(name, grid, grid_callable=grid_fn)
|
||||
wrapper.generate_kernel_call(
|
||||
name,
|
||||
call_args,
|
||||
|
|
@ -3659,7 +3655,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
|||
gpu=current_device.type != "cpu",
|
||||
triton=True,
|
||||
arg_types=arg_types,
|
||||
grid_fn=self._get_grid_fn_str(),
|
||||
grid_fn=grid_fn.__name__,
|
||||
triton_meta=self.triton_meta,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -203,8 +203,5 @@ class TritonSplitScanKernel(TritonKernel):
|
|||
def _get_heuristic(self):
|
||||
return "split_scan"
|
||||
|
||||
def _get_grid_fn_str(self):
|
||||
return "split_scan_grid"
|
||||
|
||||
def _get_grid_fn(self):
|
||||
return split_scan_grid
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user