From ff7c06a01bec97c897374a685f33bb7253df11ec Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot <pytorchmergebot@users.noreply.github.com>
Date: Sat, 18 Nov 2023 16:14:08 +0000
Subject: [PATCH] Revert "limit fused kernel num args. (#113131)"

This reverts commit 7b442c2b0ae0d9c944a777d7352135f370837c15.

Reverted https://github.com/pytorch/pytorch/pull/113131 on behalf of https://github.com/albanD due to Breaks lint on trunk ([comment](https://github.com/pytorch/pytorch/pull/113131#issuecomment-1817548349))
---
 test/inductor/test_torchinductor.py | 18 ------------------
 torch/_inductor/codegen/cpp.py      | 25 -------------------------
 torch/_inductor/codegen/triton.py   |  3 ---
 torch/_inductor/scheduler.py        | 11 -----------
 4 files changed, 57 deletions(-)

diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 11ccbcda4bd..fbe6ef45a5b 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -7712,24 +7712,6 @@ class CommonTemplate:
         b = torch.randn(65, 2**24, device=self.device)
         fn(a, b)
 
-    def test_fuse_large_params(self):
-        def pt2_optimizer_step(optimizer):
-            @torch.compile()
-            def f():
-                optimizer.step()
-
-            f()
-
-        params = [
-            torch.rand(10, 10, dtype=torch.float32, device=self.device)
-            for _ in range(194)
-        ]
-        for p in params:
-            p.grad = torch.rand_like(p)
-
-        o = torch.optim.AdamW(params)
-        pt2_optimizer_step(o)
-
     def test_adaptive_avg_pool1d_argmax(self):
         # https://github.com/pytorch/pytorch/issues/113013
         def fn(x):
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 9216dc2573c..90f55ed8b55 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -2809,18 +2809,9 @@ class CppKernelProxy(CppKernel):
 
 
 class CppScheduling(BaseScheduling):
-    # ctypes limits the number of args to 1024, refer to:
-    # https://github.com/python/cpython/commit/a285af7e626d1b81cf09f8b2bf7656f100bc1237
-    # We set a conservative threshold here.
-    MAX_FUSED_KERNEL_ARGS_NUM = 500
-
     def __init__(self, scheduler):
         self.scheduler = scheduler
         self.get_kernel_group()
-        self._ready_to_flush = False
-
-    def _set_flush_status(self, status: bool):
-        self._ready_to_flush = status
 
     def group_fn(self, sizes):
         return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes)
@@ -2867,23 +2858,12 @@ class CppScheduling(BaseScheduling):
 
         kernel_group.finalize_kernel(cpp_kernel_proxy, nodes)
 
-        args_num = self._get_scheduled_num_args()
-        if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM:
-            self._set_flush_status(True)
-
-    def _get_scheduled_num_args(self):
-        return self.kernel_group.get_num_args()
-
-    def ready_to_flush(self):
-        return self._ready_to_flush
-
     def codegen_sync(self):
         pass
 
     def flush(self):
         self.kernel_group.codegen_define_and_call(V.graph.wrapper_code)
         self.get_kernel_group()
-        self._set_flush_status(False)
 
 
 class KernelGroup:
@@ -2905,11 +2885,6 @@ class KernelGroup:
         ws = self.ws
         new_kernel.codegen_loops(code, ws)
 
-    def get_num_args(self):
-        arg_defs, call_args, arg_types = self.args.cpp_argdefs()
-        args_num = len(arg_defs)
-        return args_num
-
     def codegen_define_and_call(self, wrapper):
         self.stack.close()
         if not self.scheduled_nodes:
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index b209b7b8db4..ee0fb32e051 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -2850,9 +2850,6 @@ class TritonScheduling(BaseScheduling):
     def flush(self):
         pass
 
-    def ready_to_flush(self) -> bool:
-        return False
-
     def benchmark_fused_nodes(self, nodes):
         _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
         node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index f8050de14c7..a40bd920953 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -2134,7 +2134,6 @@ class Scheduler:
     @dynamo_timed
     def codegen(self):
         for node in self.nodes:
-            device = None
             try:
                 log.debug(
                     "Generating code for node %s with estimated runtime %f",
@@ -2190,9 +2189,6 @@
 
             self.available_buffer_names.update(node.get_names())
 
-            if device is not None and self.get_backend(device).ready_to_flush():
-                self.flush()
-
         self.flush()
 
     def is_unaligned_buffer(self, buf_name):
@@ -2249,13 +2245,6 @@ class BaseScheduling:
         """
         raise NotImplementedError()
 
-    def ready_to_flush(self) -> bool:
-        """
-        Check whether the backend is request scheduler flush the generated kernel.
-        If not support, please return False.
-        """
-        return False
-
     def flush(self):
         """
         Flush the generated kernel and python wrapper code to the source code file.
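For context on what is being removed: the reverted change taught Inductor's C++ scheduling backend to track how many arguments the current fused kernel group had accumulated and to request an early flush once the count crossed a threshold, because ctypes rejects calls with more than 1024 arguments (the CPython commit cited in the removed comment). The standalone Python sketch below illustrates that flush-on-threshold pattern; the class and method names echo the patch, but KernelGroupSketch and CppSchedulingSketch are simplified stand-ins for illustration, not the actual Inductor implementation.

    # Standalone sketch (not Inductor code) of the flush-on-threshold pattern
    # that this patch reverts.

    MAX_FUSED_KERNEL_ARGS_NUM = 500  # conservative bound under ctypes' 1024-arg cap


    class KernelGroupSketch:
        """Accumulates argument definitions, loosely mirroring KernelGroup."""

        def __init__(self):
            self.arg_defs = []

        def add_kernel(self, num_args):
            # Stand-in for finalize_kernel(): each fused kernel adds arguments.
            start = len(self.arg_defs)
            self.arg_defs.extend(f"arg{start + i}" for i in range(num_args))

        def get_num_args(self):
            return len(self.arg_defs)


    class CppSchedulingSketch:
        """Loosely mirrors the reverted CppScheduling flush bookkeeping."""

        def __init__(self):
            self.kernel_group = KernelGroupSketch()
            self._ready_to_flush = False

        def codegen_nodes(self, num_args):
            self.kernel_group.add_kernel(num_args)
            # The heart of the reverted change: past the threshold, signal the
            # scheduler to flush instead of growing the argument list toward
            # the ctypes limit.
            if self.kernel_group.get_num_args() > MAX_FUSED_KERNEL_ARGS_NUM:
                self._ready_to_flush = True

        def ready_to_flush(self):
            return self._ready_to_flush

        def flush(self):
            print(f"flushing kernel group with {self.kernel_group.get_num_args()} args")
            self.kernel_group = KernelGroupSketch()
            self._ready_to_flush = False


    # The scheduler's codegen loop polls ready_to_flush() after each node, as
    # in the hunk removed from torch/_inductor/scheduler.py above.
    backend = CppSchedulingSketch()
    for num_args in (200, 200, 200):  # the third kernel pushes the count past 500
        backend.codegen_nodes(num_args)
        if backend.ready_to_flush():
            backend.flush()
    backend.flush()  # final flush at the end of codegen

Running the sketch flushes once after the third kernel (600 args > 500) and once more at the end of the loop, which is exactly the control flow the removed Scheduler.codegen hunk implemented via get_backend(device).ready_to_flush().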