Revert "limit fused kernel num args. (#113131)"

This reverts commit 7b442c2b0a.

Reverted https://github.com/pytorch/pytorch/pull/113131 on behalf of https://github.com/albanD due to Breaks lint on trunk ([comment](https://github.com/pytorch/pytorch/pull/113131#issuecomment-1817548349))
Author: PyTorch MergeBot
Date:   2023-11-18 16:14:08 +00:00
parent  b53d47a719
commit  ff7c06a01b
4 changed files with 0 additions and 57 deletions

test/inductor/test_torchinductor.py

@@ -7712,24 +7712,6 @@ class CommonTemplate:
         b = torch.randn(65, 2**24, device=self.device)
         fn(a, b)
 
-    def test_fuse_large_params(self):
-        def pt2_optimizer_step(optimizer):
-            @torch.compile()
-            def f():
-                optimizer.step()
-
-            f()
-
-        params = [
-            torch.rand(10, 10, dtype=torch.float32, device=self.device)
-            for _ in range(194)
-        ]
-        for p in params:
-            p.grad = torch.rand_like(p)
-
-        o = torch.optim.AdamW(params)
-        pt2_optimizer_step(o)
-
     def test_adaptive_avg_pool1d_argmax(self):
        # https://github.com/pytorch/pytorch/issues/113013
        def fn(x):
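(For reference, outside the diff: the reverted test boils down to the standalone repro below. This is a sketch, not part of the commit; the 194-parameter count and the torch.compile-around-optimizer pattern come straight from the test, while running on CPU is an assumption.)

# Standalone sketch of the reverted test_fuse_large_params repro.
import torch

def pt2_optimizer_step(optimizer):
    @torch.compile()
    def f():
        optimizer.step()

    f()

# 194 params; with AdamW's per-param state tensors the fused optimizer
# kernel's argument count can approach the ctypes cap discussed below.
params = [torch.rand(10, 10, dtype=torch.float32) for _ in range(194)]
for p in params:
    p.grad = torch.rand_like(p)

opt = torch.optim.AdamW(params)
pt2_optimizer_step(opt)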

torch/_inductor/codegen/cpp.py

@@ -2809,18 +2809,9 @@ class CppKernelProxy(CppKernel):
 
 
 class CppScheduling(BaseScheduling):
-    # ctypes limits the number of args to 1024, refer to:
-    # https://github.com/python/cpython/commit/a285af7e626d1b81cf09f8b2bf7656f100bc1237
-    # We set a conservative threshold here.
-    MAX_FUSED_KERNEL_ARGS_NUM = 500
-
     def __init__(self, scheduler):
         self.scheduler = scheduler
         self.get_kernel_group()
-        self._ready_to_flush = False
-
-    def _set_flush_status(self, status: bool):
-        self._ready_to_flush = status
 
     def group_fn(self, sizes):
         return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes)
@@ -2867,23 +2858,12 @@ class CppScheduling(BaseScheduling):
 
         kernel_group.finalize_kernel(cpp_kernel_proxy, nodes)
 
-        args_num = self._get_scheduled_num_args()
-        if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM:
-            self._set_flush_status(True)
-
-    def _get_scheduled_num_args(self):
-        return self.kernel_group.get_num_args()
-
-    def ready_to_flush(self):
-        return self._ready_to_flush
-
     def codegen_sync(self):
         pass
 
     def flush(self):
         self.kernel_group.codegen_define_and_call(V.graph.wrapper_code)
         self.get_kernel_group()
-        self._set_flush_status(False)
 
 
 class KernelGroup:
@@ -2905,11 +2885,6 @@ class KernelGroup:
         ws = self.ws
         new_kernel.codegen_loops(code, ws)
 
-    def get_num_args(self):
-        arg_defs, call_args, arg_types = self.args.cpp_argdefs()
-        args_num = len(arg_defs)
-        return args_num
-
     def codegen_define_and_call(self, wrapper):
         self.stack.close()
         if not self.scheduled_nodes:
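(Aside, not part of the commit: the 1024-argument ceiling cited in the reverted comment is easy to observe directly. A minimal sketch, assuming a POSIX system where ctypes.CDLL(None) exposes libc's printf and a CPython recent enough to include the commit linked above:)

# Sketch: CPython's ctypes caps foreign calls at 1024 arguments, which is
# what the conservative MAX_FUSED_KERNEL_ARGS_NUM = 500 stayed well under.
import ctypes

libc = ctypes.CDLL(None)  # assumption: POSIX, libc loadable this way
try:
    libc.printf(b"%d\n", *([0] * 1025))  # 1026 arguments in total
except ctypes.ArgumentError as err:
    print(err)  # e.g. "too many arguments (1026), maximum is 1024"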

torch/_inductor/codegen/triton.py

@@ -2850,9 +2850,6 @@ class TritonScheduling(BaseScheduling):
     def flush(self):
         pass
 
-    def ready_to_flush(self) -> bool:
-        return False
-
     def benchmark_fused_nodes(self, nodes):
         _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
         node_schedule = self.generate_node_schedule(nodes, numel, rnumel)

torch/_inductor/scheduler.py

@@ -2134,7 +2134,6 @@ class Scheduler:
     @dynamo_timed
     def codegen(self):
         for node in self.nodes:
-            device = None
             try:
                 log.debug(
                     "Generating code for node %s with estimated runtime %f",
@@ -2190,9 +2189,6 @@ class Scheduler:
 
             self.available_buffer_names.update(node.get_names())
 
-            if device is not None and self.get_backend(device).ready_to_flush():
-                self.flush()
-
         self.flush()
 
     def is_unaligned_buffer(self, buf_name):
@@ -2249,13 +2245,6 @@ class BaseScheduling:
         """
         raise NotImplementedError()
 
-    def ready_to_flush(self) -> bool:
-        """
-        Check whether the backend is requesting the scheduler to flush the generated kernel.
-        If not supported, return False.
-        """
-        return False
-
     def flush(self):
         """
         Flush the generated kernel and python wrapper code to the source code file.
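(Taken together, the pieces this commit reverts implemented a small handshake between Scheduler.codegen and the CPU backend: the backend counts the scheduled kernel's arguments, raises a ready_to_flush flag once the threshold is exceeded, and the scheduler flushes early instead of fusing further. Below is a schematic sketch with a simplified stand-in class, not the real Inductor types:)

# Schematic sketch of the reverted flush handshake.
class FakeCppBackend:
    MAX_FUSED_KERNEL_ARGS_NUM = 500  # threshold from the reverted change

    def __init__(self):
        self._ready_to_flush = False
        self._args_num = 0

    def codegen_node(self, node_args):
        self._args_num += node_args  # pretend each node contributes args
        if self._args_num > self.MAX_FUSED_KERNEL_ARGS_NUM:
            self._ready_to_flush = True

    def ready_to_flush(self):
        return self._ready_to_flush

    def flush(self):
        print(f"flush fused kernel with {self._args_num} args")
        self._args_num = 0
        self._ready_to_flush = False


backend = FakeCppBackend()
for node_args in [200, 200, 200]:   # fake per-node argument counts
    backend.codegen_node(node_args)
    if backend.ready_to_flush():    # the per-node check this revert removes
        backend.flush()
backend.flush()                     # unconditional final flush, as in Scheduler.codegen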