Revert "limit fused kernel num args. (#113131)"
This reverts commit 7b442c2b0a.
Reverted https://github.com/pytorch/pytorch/pull/113131 on behalf of https://github.com/albanD due to Breaks lint on trunk ([comment](https://github.com/pytorch/pytorch/pull/113131#issuecomment-1817548349))
This commit is contained in: parent b53d47a719, commit ff7c06a01b

test/inductor/test_torchinductor.py
@@ -7712,24 +7712,6 @@ class CommonTemplate:
         b = torch.randn(65, 2**24, device=self.device)
         fn(a, b)
 
-    def test_fuse_large_params(self):
-        def pt2_optimizer_step(optimizer):
-            @torch.compile()
-            def f():
-                optimizer.step()
-
-            f()
-
-        params = [
-            torch.rand(10, 10, dtype=torch.float32, device=self.device)
-            for _ in range(194)
-        ]
-        for p in params:
-            p.grad = torch.rand_like(p)
-
-        o = torch.optim.AdamW(params)
-        pt2_optimizer_step(o)
-
     def test_adaptive_avg_pool1d_argmax(self):
         # https://github.com/pytorch/pytorch/issues/113013
         def fn(x):

torch/_inductor/codegen/cpp.py
@@ -2809,18 +2809,9 @@ class CppKernelProxy(CppKernel):
 
 
 class CppScheduling(BaseScheduling):
-    # ctypes limits the number of args to 1024, refer to:
-    # https://github.com/python/cpython/commit/a285af7e626d1b81cf09f8b2bf7656f100bc1237
-    # We set a conservative threshold here.
-    MAX_FUSED_KERNEL_ARGS_NUM = 500
-
     def __init__(self, scheduler):
         self.scheduler = scheduler
         self.get_kernel_group()
-        self._ready_to_flush = False
-
-    def _set_flush_status(self, status: bool):
-        self._ready_to_flush = status
 
     def group_fn(self, sizes):
         return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes)
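
The ctypes ceiling cited in the removed comment is easy to reproduce. Below is a minimal sketch, assuming a POSIX system (ctypes.CDLL(None) resolves symbols from the current process) and a CPython build that includes the referenced commit; the argument-count check fires before any argument conversion, so printf is never actually invoked:

import ctypes

libc = ctypes.CDLL(None)  # POSIX: look up symbols in the current process

try:
    # 1 format string + 1025 variadic values = 1026 arguments in total,
    # which exceeds the hard cap of 1024 per ctypes call.
    libc.printf(b"%d\n", *([0] * 1025))
except ctypes.ArgumentError as exc:
    print(exc)  # e.g. "too many arguments (1026), maximum is 1024"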

@@ -2867,23 +2858,12 @@ class CppScheduling(BaseScheduling):
 
         kernel_group.finalize_kernel(cpp_kernel_proxy, nodes)
 
-        args_num = self._get_scheduled_num_args()
-        if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM:
-            self._set_flush_status(True)
-
-    def _get_scheduled_num_args(self):
-        return self.kernel_group.get_num_args()
-
-    def ready_to_flush(self):
-        return self._ready_to_flush
-
     def codegen_sync(self):
         pass
 
     def flush(self):
         self.kernel_group.codegen_define_and_call(V.graph.wrapper_code)
         self.get_kernel_group()
-        self._set_flush_status(False)
 
 
 class KernelGroup:
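
For readers skimming the revert, the CppScheduling machinery deleted above reduces to a small pattern: after each fused kernel is finalized, count the kernel group's arguments and ask the scheduler to flush once a conservative threshold is crossed. A minimal, self-contained sketch — ThresholdFlushBackend and its methods are hypothetical stand-ins, not Inductor's real API:

MAX_FUSED_KERNEL_ARGS_NUM = 500  # conservative: well under ctypes' 1024 cap

class ThresholdFlushBackend:
    """Hypothetical, simplified stand-in for the reverted CppScheduling logic."""

    def __init__(self):
        self._pending_args = 0
        self._ready_to_flush = False

    def finalize_kernel(self, num_args: int) -> None:
        # Accumulate the argument count of the kernel group and raise the
        # flush flag once it crosses the threshold.
        self._pending_args += num_args
        if self._pending_args > MAX_FUSED_KERNEL_ARGS_NUM:
            self._ready_to_flush = True

    def ready_to_flush(self) -> bool:
        return self._ready_to_flush

    def flush(self) -> None:
        # Emit the buffered kernel group, then reset the bookkeeping.
        print(f"flushing kernel group ({self._pending_args} args)")
        self._pending_args = 0
        self._ready_to_flush = False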

@@ -2905,11 +2885,6 @@ class KernelGroup:
         ws = self.ws
         new_kernel.codegen_loops(code, ws)
 
-    def get_num_args(self):
-        arg_defs, call_args, arg_types = self.args.cpp_argdefs()
-        args_num = len(arg_defs)
-        return args_num
-
     def codegen_define_and_call(self, wrapper):
         self.stack.close()
         if not self.scheduled_nodes:

torch/_inductor/codegen/triton.py
@@ -2850,9 +2850,6 @@ class TritonScheduling(BaseScheduling):
     def flush(self):
         pass
 
-    def ready_to_flush(self) -> bool:
-        return False
-
     def benchmark_fused_nodes(self, nodes):
         _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
         node_schedule = self.generate_node_schedule(nodes, numel, rnumel)

torch/_inductor/scheduler.py
@@ -2134,7 +2134,6 @@ class Scheduler:
     @dynamo_timed
     def codegen(self):
         for node in self.nodes:
-            device = None
            try:
                log.debug(
                    "Generating code for node %s with estimated runtime %f",

@@ -2190,9 +2189,6 @@
 
             self.available_buffer_names.update(node.get_names())
 
-            if device is not None and self.get_backend(device).ready_to_flush():
-                self.flush()
-
         self.flush()
 
     def is_unaligned_buffer(self, buf_name):
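
Using the hypothetical ThresholdFlushBackend sketched earlier, the scheduler-side check removed by this hunk amounts to polling after every node plus one unconditional flush at the end of codegen:

backend = ThresholdFlushBackend()
for num_args in (120, 200, 250, 90):  # made-up per-kernel argument counts
    backend.finalize_kernel(num_args)
    if backend.ready_to_flush():      # poll after each node, as the
        backend.flush()               # removed code did; fires at 570 args
backend.flush()                       # end-of-codegen flush always runs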

@@ -2249,13 +2245,6 @@ class BaseScheduling:
         """
         raise NotImplementedError()
 
-    def ready_to_flush(self) -> bool:
-        """
-        Check whether the backend wants the scheduler to flush the generated kernel.
-        If not supported, return False.
-        """
-        return False
-
     def flush(self):
         """
         Flush the generated kernel and python wrapper code to the source code file.
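
Conversely, the deleted base-class default let a backend opt out of incremental flushing altogether, which is what the removed TritonScheduling override spelled out explicitly. A sketch of that inert behavior, again under a hypothetical name:

class InertBackend:
    """Hypothetical backend that never requests a mid-codegen flush."""

    def ready_to_flush(self) -> bool:
        # Mirrors the removed BaseScheduling default (and TritonScheduling):
        # the scheduler only ever performs its final end-of-codegen flush.
        return False

    def flush(self) -> None:
        # Nothing is buffered incrementally, so there is nothing to emit.
        pass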